UPDATE: 2024-01-27 14:51:54.075614

はじめに

このノートは「ベイズ統計」に関する何らかの内容をまとめ、ベイズ統計への理解を深めていくために作成している。k下記の資料を参考にさせていただき、今回は「WAIC」のおさらいをしておく。私の解釈がおかしく、メモが誤っている場合があるので注意。

WAICの定義

WAIC(Widely applicable information criterion)は広く使える情報量規準であり、モデル選択において活用できる。WAICの詳細に関しては、書籍やリンク先などを参照いただくとして、ここでは簡単におさらいしておく。

WAICは汎化損失\(G\)を近似する量となっている。\(q(x)\)は真の分布、\(\log p^{*}(x)\)は対数尤度である。

\[ \begin{eqnarray} G = - \int q(x) \log p^{*}(x) dx \end{eqnarray} \]

\(p^{*}(x)\)は事後予測分布である。事後予測分布は、事後分布でモデルを重みづけたもの。

\[ \begin{eqnarray} p^{*}(x) = p(x|X) = \int p(x|\theta) p(\theta|X) d\theta \end{eqnarray} \] 汎化損失はと\(q(x), p^{*}(x)\)の違いを測定し、この汎化損失を近似しているものがWAIC。ただ、真の分布\(q(x)\)が含まれているので直接は計算できない。

ここで、真の分布\(q(x)\)から新しく無限にサンプリングができるという状況を考える。つまり、無限のサンプリングができるのであれば、モンテカルロ近似で積分を近似できる。

\[ \begin{eqnarray} G = - \int q(x) \log p^{*}(x) dx \approx - \frac{1}{N} \sum_{i}^{N} \log p^{*}(x_i) \end{eqnarray} \]

この結果、汎化損失は真の分布から無限にサンプリングしたとき、その対数尤度を平均したもので近似される。対数尤度は、事後予測分布への当てはまり具合を表している。

ここでWAICの定義を見ておく。\(T\)は経験損失、\(V\)は汎関数分散を表す。なぜ、これが汎化損失の近似になっているのかは、渡辺先生が証明された論文を参考にしてもらうとして(証明を私は理解できない)、一般的な条件のもとであれば汎化損失が経験損失と汎関数分散で近似的に計算できる。

\[ \begin{eqnarray} WAIC &=& -\frac{1}{n} \sum_{i}^{N} \log p^{*}(X_i) + \frac{1}{n} \sum_{i}^{N} \left\{ E_{\theta} \left[ (\log p(X_i|\theta))^{2} \right] - E_{\theta} \left[ (\log p(X_i|\theta) \right]^{2} \right\} \\ &=& T + \frac{V}{n} \\ \end{eqnarray} \] WAICに関して調べると、下記の通り定義しているものもある。\(lppd\)は対数各点予測密度(log point-wise predictive density)で、\(P_{waic}\)は有効パラメタ数(effective number of parameters)である。対数各点予測密度=経験損失、有効パラメタ数=汎関数分散に対応ている。\(\theta_1, \theta_2, …, \theta_S\)はパラメタの事後分布からのサンプルで あり、見慣れない\(V^S_{s}(\cdot)\)はサンプル\(\{\theta_s\}^S_{s}\)による分散を表している。

\[ \begin{eqnarray} WAIC &=& -2 lppd + 2 P_{waic} \\ &=& -2 \cdot \left[ \sum_{i}^{N} \log \left( \frac{1}{S} \sum_{i}^{S} p(x_i | \theta_s )\right) \right] + 2 \cdot \left[ \sum_{i}^{N} V^{S}_{s} (\log p (x_i|\theta_{s})) \right] \end{eqnarray} \]

WAICの計算

まずは必要なライブラリや設定を行っておく。

library(tidyverse)
library(rstan)
library(loo)

options(max.print = 999999)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

データを用意する。

d <- data.frame(
  school = c('A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'),
  y = c(28, 8, -3, 7, -1, 1, 18, 12),
  s = c(15, 10, 16, 11, 9, 11, 10, 18)
  )
j <- nrow(d)
data01 <- list(Y = d$y, S = d$s, J = j)

map(.x = data01, .f = function(x){head(x, 50)})
## $Y
## [1] 28  8 -3  7 -1  1 18 12
## 
## $S
## [1] 15 10 16 11  9 11 10 18
## 
## $J
## [1] 8

モデルはこちら。WAICを計算するためにnormal_lpdf()関数からのサンプリングを行っておく。

data {
  int<lower=1> J;
  real Y[J];
  real<lower=0> S[J];
}
parameters {
  real theta[J];
  real mu;
  real<lower=0> sigma;
}
model {
  for (j in 1:J) {
    theta[j] ~ normal(mu, sigma);
  }

  for (j in 1:J) {
    Y[j] ~ normal(theta[j], S[j]);
  }
}

generated quantities{
  vector[J] log_lik;
  
  for (j in 1:J) {
    // The log of the normal density of y
    log_lik[j] = normal_lpdf(Y[j]| theta[j], S[j]);
  }
} 

先にコンパイルしてから、sampling()関数でサンプリングする。

model01 <- stan_model('model01.stan')
fit01 <- sampling(object = model01, data = data01, seed = 1989)

推定結果を確認する。log_lik[]が推定されている。これを使ってWAICを計算する。

print(fit01, prob = c(0.025, 0.5, 0.975), digits = 2)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##              mean se_mean   sd   2.5%    50% 97.5% n_eff Rhat
## theta[1]    11.24    0.37 8.55  -2.51  10.05 31.49   536 1.01
## theta[2]     7.83    0.21 6.23  -3.83   7.55 20.73   852 1.01
## theta[3]     5.99    0.21 7.78 -10.91   6.47 20.55  1393 1.01
## theta[4]     7.30    0.20 6.54  -5.84   7.18 20.65  1068 1.01
## theta[5]     4.76    0.24 6.35  -9.37   5.18 16.31   678 1.01
## theta[6]     5.78    0.22 6.75  -9.45   5.93 18.57   937 1.01
## theta[7]    10.47    0.26 6.86  -1.13   9.61 25.18   716 1.01
## theta[8]     8.12    0.20 7.90  -7.61   7.66 24.78  1627 1.00
## mu           7.64    0.20 5.08  -1.73   7.50 17.71   667 1.01
## sigma        6.89    0.34 5.42   1.31   5.52 20.59   251 1.02
## log_lik[1]  -4.41    0.03 0.57  -5.70  -4.35 -3.64   512 1.01
## log_lik[2]  -3.42    0.01 0.33  -4.33  -3.30 -3.22  1302 1.00
## log_lik[3]  -3.97    0.01 0.32  -4.79  -3.88 -3.69  1596 1.00
## log_lik[4]  -3.49    0.01 0.28  -4.30  -3.39 -3.32  1481 1.00
## log_lik[5]  -3.57    0.02 0.52  -4.98  -3.41 -3.12   997 1.01
## log_lik[6]  -3.60    0.01 0.40  -4.67  -3.46 -3.32  1135 1.00
## log_lik[7]  -3.74    0.02 0.54  -5.07  -3.59 -3.22   634 1.01
## log_lik[8]  -3.93    0.01 0.19  -4.44  -3.86 -3.81  1089 1.00
## lp__       -17.76    0.49 5.38 -27.62 -18.19 -7.21   119 1.03
## 
## Samples were drawn using NUTS(diag_e) at Sat Jan 27 14:51:56 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

いずれの定義にしても同じものを表しているので、Rで下記の通り実装できる。

# loo package
# https://github.com/stan-dev/loo/blob/master/R/waic.R
get_waic <- function(loglik){
  lppd <- sum(log(colMeans(exp(loglik)))) # loglikは対数なので指数変換してもとに戻す
  p_waic <- sum(apply(loglik, 2, var))
  elpd_waic <- lppd - p_waic
  waic <- -2*lppd + 2*p_waic
  
  return(
    list(
      lppd = lppd,
      elpd_waic = elpd_waic,
      p_waic = p_waic,
      waic = waic
    )
  )
}

ll <- rstan::extract(fit01, pars = 'log_lik')$log_lik
dim(ll)
## [1] 4000    8

このような計算イメージ。

WAICを計算すると、61.9となった。WAICが小さいモデルを選択すると良いので、他のモデルがあれば同様にWAICを計算することで比較し、よりよいモデルを選択できる。つまり、「真の分布\(q(x)\)と予測分布\(p^{*}(x)\)の誤差が小さい」モデルを選択できる。

map(.x = get_waic(loglik = ll), .f = function(x){round(x ,1)})
## $lppd
## [1] -29.6
## 
## $elpd_waic
## [1] -31
## 
## $p_waic
## [1] 1.4
## 
## $waic
## [1] 61.9

looパッケージのwaic()関数を使用すれば、同様にWAICを計算できる。

waic(ll)
## 
## Computed from 4000 by 8 log-likelihood matrix
## 
##           Estimate  SE
## elpd_waic    -31.0 1.0
## p_waic         1.4 0.3
## waic          61.9 2.0

looパッケージのloo()関数もあり、WAICと似たような値を得られる。この関数はEfficient approximate leave-one-out cross-validation(LOO)を行っている関数であり、要するに、1個抜きクロスバリデーションである。

loo(ll)
## 
## Computed from 4000 by 8 log-likelihood matrix
## 
##          Estimate  SE
## elpd_loo    -31.2 1.0
## p_loo         1.6 0.3
## looic        62.3 2.0
## ------
## Monte Carlo SE of elpd_loo is NA.
## 
## Pareto k diagnostic values:
##                          Count Pct.    Min. n_eff
## (-Inf, 0.5]   (good)     1     12.5%   3645      
##  (0.5, 0.7]   (ok)       6     75.0%   1121      
##    (0.7, 1]   (bad)      1     12.5%   1036      
##    (1, Inf)   (very bad) 0      0.0%   <NA>      
## See help('pareto-k-diagnostic') for details.