UPDATE: 2022-09-03 16:06:41

はじめに

このノートでは比例ハザードモデルのパラメタを最尤法で計算する方法についてまとめておく。

比例ハザードモデル

ここでの比例ハザードモデルにはワイブル分布を仮定した話をまとめる。まず確率変数\(T\)\(t\)よりも大きい値を取る確率を生存関数\(S(t)\)という。\(F(t)\)\(T\)のCDFである。

\[ S(t) = P(T > t) = 1-F(t) \]

\(T\)を生存している時間とすれば、\(t\)までは生存しているという事象=その時点でまだ生きている確率を返す関数が生存関数。

ワイブル分布の生存関数は下記の通り定義される。

\[ S(t) = exp \left[ -\frac{t}{\phi}^{\gamma} \right], t>0, \phi>0,\gamma>0 \]

次に危険度関数をまとめておく。危険率(Hazard rate)は、

\[ h(t) = \lim_{\Delta t → 0} \frac{P(t < T \le t + \Delta t| T > t)}{\Delta t} \] であり、時点\(t\)まで生存した\(T > t\)という条件のもとで与えられた個体が\(T = t\)で死亡する確率。\(T\)を生存している時間とすれば、\(t\)時間まで死ななかった人が\(t\)を少し超えたときに寿命が来て死亡したということ。危険度関数\(h(t)\)は、\(T\)のpdfを\(f(t)\)とすれば、

\[ h(t) = \frac{f(t)}{S(t)} \] と定義される。ワイブル分布の危険度関数は、

\[ h(t) = \frac{\gamma}{\phi}(\frac{t}{\phi})^{\gamma-1} \]

であり、危険率は\(\gamma=1\)のとき一定、\(\gamma \gt 1\)のとき\(t\)の増加関数、\(0 \lt \gamma \lt 1\)のとき\(t\)の減少関数となる。

ここで、ワイブル分布の危険度関数において、

\[ \theta = \phi^{-\gamma} = \frac{1}{\phi^{\gamma}},\quad log \theta = \boldsymbol{ x^{T} \beta }\Leftrightarrow \theta = exp \left[\boldsymbol{ x^{T} \beta } \right] \]

と定式化すると、

\[ h(t) = \gamma t^{\gamma-1} exp \left[\boldsymbol{ x^{T} \beta } \right] \]

と表すことができる。ここで$ \(とすると\)h_{0}(t) = t^{}$より、

\[ h(t) = h_{0}(t) exp \left[\boldsymbol{ x^{T} \beta } \right] \]

ワイブル分布の比例危険度モデルとなる。ワイブル分布の生存関数は、

\[ S(t) = exp[-\theta t^{\gamma}] = exp[-exp \left[\boldsymbol{ x^{T} \beta } \right] t^{\gamma}] \] となるため、下記の関係が得られる。

\[ logS(t) = -exp \left[\boldsymbol{ x^{T} \beta } \right] t^{\gamma} \\ log[-logS(t)] = \gamma log t + \boldsymbol{ x^{T} \beta} \]

基本的な情報をまとめたので、ここからワイブル分布の比例危険度モデルの尤度関数は導出する。生存時間\(t\)、打ち切り\(\delta\)を下記の通りとする。打ち切りではない場合は\(1\)、打ち切りの場合\(0\)とする。

\[ t_{i}, i>0,\quad \delta_{i} =0,1 \]

また、これまでの情報をまとめると、

\[ \begin{eqnarray} f(t) &=& \theta \gamma t^{\gamma - 1} exp[-\theta t^{\gamma}]\\ S(t) &=& \int_{t}^{\inf}f(s)ds = exp[-\theta t^{\gamma}] \\ h(t) &=& \theta \gamma t^{\gamma - 1} \\ \theta &=& exp[\boldsymbol{ x^{T} \beta}] \end{eqnarray} \]

であり、\(n\)個のデータについて、\(i\)番目のサンプルの生存時間を\(t_{i}\)とするとき、サンプル\(i\)がが非打ち切りデータならば、その尤度は確率密度関数\(f(t_{i}\)と等しくなるため、尤度を下記のように考えられる。

\[ L = \prod \left\{ f(t_{i})^{\delta_{i}} S(t_{i})^{1-\delta_{i}}\right\} \]

対数尤度関数は、

\[ \begin{eqnarray} logL &=& \sum \left\{ \delta_{i}log(f(t_{i})) + (1-\delta_{i}) log S(t_{i}) \right\} \\ &=& \sum \left\{ \delta_{i}[ log \gamma + (\gamma - 1)log t_{i} +\boldsymbol{ x_{i} \beta} ] -exp[\boldsymbol{ x_{i}^{T} \beta} ] t_{i}^{\gamma} \right\} \end{eqnarray} \] となる。

\[ \begin{eqnarray} \frac{ \partial log L(\boldsymbol{\beta, \gamma}) }{ \partial \beta_{j}} &=& \sum \delta_{i}x_{ij} - x_{ij}exp[\boldsymbol{ x_{i}^{T} \beta}]t_{i}^{\gamma} = 0 \\ \frac{ \partial log L(\boldsymbol{\beta, \gamma}) }{ \partial \gamma} &=& \sum \delta_{i} \left( \frac{1}{\gamma} + logt_{i} \right) - exp[\boldsymbol{ x_{i}^{T} \beta}] t_{i}^{\gamma} logt_{i}= 0 \end{eqnarray} \]

ここからはRで実装していく。サンプルデータは下記を参照した。

library(tidyverse)
library(eha)

# From: https://statisticalhorizons.com/resources/data-sets
# library(foreign) 
# recid.wide <- foreign:: read.dta("recid.dta")
# work1~work52は削除したrecid_small.csvを利用する。
df <- read_csv("~/Desktop/recid_small.csv")
head(df)
## # A tibble: 6 × 10
##    week arrest   fin   age  race  wexp   mar  paro  prio  educ
##   <dbl>  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1    20      1     0    27     1     0     0     1     3     3
## 2    17      1     0    18     1     0     0     1     8     4
## 3    25      1     0    19     0     1     0     1    13     3
## 4    52      0     1    23     1     1     1     1     1     5
## 5    52      0     0    19     0     1     0     1     3     3
## 6    52      0     0    24     1     1     0     0     2     4

まずは、ehaパッケージのweibreg関数でワイブル分布の比例危険度モデルを実行し、計算したいパラメタの値を確認しておく。survivalパッケージのsurvreg関数はAccelarated Failure Time Modelsなので、ここで想定しているモデルとは異なるので注意。

# phreg() is same
wfit <- weibreg(Surv(week, arrest) ~ age + educ + fin, data = df)
wfit
## Call:
## weibreg(formula = Surv(week, arrest) ~ age + educ + fin, data = df)
## 
## Covariate           Mean       Coef Exp(Coef)  se(Coef)    Wald p
## age                24.765    -0.069     0.933     0.021     0.001 
## educ                3.490    -0.254     0.775     0.127     0.045 
## fin                 0.511    -0.309     0.734     0.190     0.104 
## 
## log(scale)                    2.893    18.053     0.460     0.000 
## log(shape)                    0.325     1.384     0.090     0.000 
## 
## Events                    
## Total time at risk         19809 
## Max. log. likelihood      -685.35 
## LR test statistic         22.6 
## Degrees of freedom        3 
## Overall p-value           5.00642e-05

まずは、変数をベクトルと行列に格納する。

t <- df %>% pull(week)
delta <- df %>% pull(arrest)
X <- df %>% 
  mutate(x0 = 1) %>% 
  select(x0, x1 = age, x2 = educ, x3 = fin) %>% 
  as.matrix()

ここでは準ニュートン法でワイブル分布の比例危険度モデルのパラメタを計算する。

f <- function(x){
  gamma <- x[1]
  b <- x[-1]
  f1 <- sum(delta * (1 / gamma + log(t)) - exp(X %*% c(b)) * log(t) * t^gamma)
  f2 <- apply(X * c(delta), 2, sum) - apply(X * c(exp(X %*% b) * t^gamma), 2, sum)
  return(c(f1, f2))
}

iter <- 10^6       # パラメータの初期値
eta <- 10^(-10)   # 学習率
B <- rep(0.01, 5) # パラメタの初期値(γ, intercept, age, educ, fin)
H <- diag(f(B))   # ヤコビアンの初期値

for(i in 1:iter){
  eta <- eta + 10^(-11) # 学習率を更新数に応じて増加修正
  B_pre <- B
  B <- B - eta * H %*% f(B)
  s <- B - B_pre
  y <- f(B) - f(B_pre)
  # ヤコビアン近似行列を更新
  H <- H + ((s - H %*% y) / as.numeric(t(s) %*% H %*% y)) %*% t(s) %*% H
  # alpha <- B[1]
  # beta  <- B[-1]
  # loglik <- sum(delta * (X %*% c(beta) + log(rep(alpha,length(t))) + (rep(alpha, length(t)) - 1) * log(t))) - sum(exp(X %*% c(beta)) * t^alpha)
  # print(sprintf("%d times: (γ=%2.5f, x0=%2.5f, x1=%2.5f, x2=%2.5f, x3=%2.5f)", i, B[1,1], B[2,1], B[3,1], B[4,1], B[5,1]))
  if(i %% 50000 == 0) {
    print(sprintf("%d times: (γ=%2.5f, x0=%2.5f, x1=%2.5f, x2=%2.5f, x3=%2.5f)",
                  i, B[1,1], B[2,1], B[3,1], B[4,1], B[5,1]))
  }
}
## [1] "50000 times: (γ=0.03342, x0=0.00948, x1=-0.02634, x2=0.00817, x3=0.00850)"
## [1] "100000 times: (γ=0.03467, x0=0.00835, x1=-0.02681, x2=0.00475, x3=0.00402)"
## [1] "150000 times: (γ=0.03686, x0=0.00736, x1=-0.02759, x2=-0.00101, x3=-0.00350)"
## [1] "200000 times: (γ=0.04015, x0=0.00770, x1=-0.02871, x2=-0.00921, x3=-0.01412)"
## [1] "250000 times: (γ=0.04479, x0=0.01068, x1=-0.03018, x2=-0.01994, x3=-0.02787)"
## [1] "300000 times: (γ=0.05117, x0=0.01739, x1=-0.03203, x2=-0.03325, x3=-0.04470)"
## [1] "350000 times: (γ=0.05983, x0=0.02810, x1=-0.03425, x2=-0.04912, x3=-0.06447)"
## [1] "400000 times: (γ=0.07152, x0=0.04155, x1=-0.03682, x2=-0.06736, x3=-0.08680)"
## [1] "450000 times: (γ=0.08726, x0=0.05414, x1=-0.03972, x2=-0.08756, x3=-0.11112)"
## [1] "500000 times: (γ=0.10837, x0=0.05946, x1=-0.04286, x2=-0.10908, x3=-0.13661)"
## [1] "550000 times: (γ=0.13649, x0=0.04790, x1=-0.04614, x2=-0.13106, x3=-0.16227)"
## [1] "600000 times: (γ=0.17404, x0=0.00561, x1=-0.04943, x2=-0.15249, x3=-0.18708)"
## [1] "650000 times: (γ=0.22454, x0=-0.08590, x1=-0.05260, x2=-0.17246, x3=-0.21019)"
## [1] "700000 times: (γ=0.29208, x0=-0.24682, x1=-0.05556, x2=-0.19022, x3=-0.23094)"
## [1] "750000 times: (γ=0.38047, x0=-0.49529, x1=-0.05825, x2=-0.20539, x3=-0.24897)"
## [1] "800000 times: (γ=0.49199, x0=-0.84239, x1=-0.06064, x2=-0.21790, x3=-0.26417)"
## [1] "850000 times: (γ=0.62522, x0=-1.28418, x1=-0.06271, x2=-0.22794, x3=-0.27659)"
## [1] "900000 times: (γ=0.77297, x0=-1.79411, x1=-0.06446, x2=-0.23582, x3=-0.28641)"
## [1] "950000 times: (γ=0.92220, x0=-2.32246, x1=-0.06588, x2=-0.24182, x3=-0.29387)"
## [1] "1000000 times: (γ=1.05767, x0=-2.80993, x1=-0.06697, x2=-0.24623, x3=-0.29928)"
# print(B)
#             [,1]
# [1,]  1.05767039 -- γ
# [2,] -2.80992807 -- intercept
# [3,] -0.06696622 -- x1(age)
# [4,] -0.24623321 -- x2(educ)
# [5,] -0.29928444 -- x3(fin)