UPDATE: 2024-02-22 23:44:39.556861

はじめに

このノートは「ベイズ統計」に関する何らかの内容をまとめ、ベイズ統計への理解を深めていくために作成している。今回は「社会科学のためのベイズ統計モデリング」の第8章を写経していく。基本的には気になった部分を写経しながら、ところどころ自分用の補足をメモすることで、自分用の補足資料になることを目指す。私の解釈がおかしく、メモが誤っている場合があるので注意。

今回はWAICのシュミレーションの数式とコードの対応を理解することを目的にする。ここでの真の分布は\(\lambda^*=3\)のポアソン分布とする。確率モデルにポアソン分布、事前分布に\(a=3,b=1\)のガンマ分布を考える。

library(tidyverse)
library(rstan)
library(loo)

options(max.print = 999999)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

lambda_prime <- 3
n <- 50
a <- 3
b <- 1
set.seed(1989)
x <- rpois(n, lambda_prime) #q(x)
hist(x, main = 'Poisson(λ=3)')

WAICを計算する

ここでは、カウントデータに対して、ガンマ-ポアソン分布を用いてモデル化を行い、WAICを計算する。ガンマ-ポアソン分布は共役関係から下記の通り事後分布が計算できる。これは\(a_n = \sum x_i + a, b_n = n + b\)のガンマ分布と一致する。

\[ \begin{eqnarray} X &\sim& Poisson(\lambda) \\ \lambda &\sim& Gamma(a,b) \\ p(\lambda|x^n) &=& \frac{b_{n}^{a_n}}{\Gamma(a_n)} \lambda^{a_{n} -1} e^{-b_{n} \lambda} \end{eqnarray} \]

事後分布を使った事後予測分布は、

\[ \begin{eqnarray} p^*(x) &=& \int_0^{\infty} P(x|\lambda) p(\lambda|x^n) d\lambda \\ &=& \int_0^{\infty} \frac{1}{x!}\lambda^x e^{-\lambda} \frac{b_{n}^{a_n}}{\Gamma(a_n)} \lambda^{a_{n} -1} e^{-b_{n} \lambda} d\lambda \\ &=& \frac{\Gamma(x + a_n)}{\Gamma(a_n)x!}\frac{b_{n}^{a_n}}{(1+b_n)^{x+a_n}} \end{eqnarray} \]

である。予測分布の情報量は、下記となる。

\[ \begin{eqnarray} -\log p^*(x) &=& -log \left( \frac{\Gamma(x + a_n)}{\Gamma(a_n)x!}\frac{b_{n}^{a_n}}{(1+b_n)^{x+a_n}}\right) \\ &=& -\log \Gamma(x+a_n) + \log\Gamma(a_n) + logx! - a_n \log b_n + (x+a_n) \log(1+b_n) \end{eqnarray} \]

予測分布の情報量に対応する関数は下記である。

# a_n = \sum x_i + a, 
# b_n = n + b
# Γ(n+1) = n!
info_pred_dist <- function(x_d, x){
  a_n <- sum(x)+a
  b_n <- length(x)+b
  return(-lgamma(x_d+a_n) + lgamma(a_n) + lgamma(x_d+1) - a_n*log(b_n) + (x_d+a_n)*log(1+b_n))
}

汎化損失\(G_n\)は真の分布として、\(\lambda^*=3\)のポアソン分布を仮定しているため、

\[ \begin{eqnarray} G_n &=& \int - \log p^*(x) q(x) dx \\ &=& \sum - \log p^*(x_i) Poisson(x_i|\lambda) \\ &=& \sum \frac{(\lambda^*)^{x_i} e^{-\lambda^*}}{x_i!} [-\log \Gamma(x_i + a_n) + \log (a_n) + log x_i! - a_n \log b_n + (x_i+a_n) \log(1+b_n)] \\ \end{eqnarray} \]

となる。

G_n <- function(lambda_prime,x){
  temp <- 0
  for(i in 0:20){
    temp = temp + dpois(i, lambda_prime)*info_pred_dist(i,x)
  }
  return(temp)
}

次にWAICを計算するための数式と関数の対応を確認する。WAICの計算には、経験損失と汎関数分散が必要。経験損失は予測分布の情報量を使えば計算できる。

\[ \begin{eqnarray} T_n &=& \frac{1}{n} \sum -\log p^*(x_i) \\ &=& \frac{1}{n} \sum [-\log \Gamma(x_i + a_n) + \log (a_n) + log x_i! - a_n \log b_n + (x_i+a_n) \log(1+b_n)] \\ \end{eqnarray} \] 経験損失に対応する関数が下記である。

# info_pred_dist <- function(x_d, x){
#   a_n <- sum(x)+a
#   b_n <- length(x)+b
#   return(-lgamma(x_d+a_n) + lgamma(a_n) + lgamma(x_d+1) - a_n*log(b_n) + (x_d+a_n)*log(1+b_n))
# }
T_n <- function(x){
  temp <- c()
  for(i in 1:length(x)){
    temp[i] = info_pred_dist(x[i],x)
  }
  return(mean(temp))
}

汎関数分散については、

\[ \begin{eqnarray} V_n &=& \sum \left[ \int \log (Poisson(x_i|\lambda))^2 \underbrace{p(\lambda|x^n)}_{事後予測分布}d\lambda - \left( \int \log Poisson(x_i|\lambda) \underbrace{p(\lambda|x^n)}_{事後予測分布} d\lambda \right)^{2} \right] \end{eqnarray} \]

下記の数式と関数が対応する。

V_n <- function(x){
  a_n <- sum(x)+a
  b_n <- length(x)+b
  
  V_comp <- function(x){
    poisgamma2 <- function(lambda){
      dpois(x,lambda,log=T)^2*dgamma(lambda,a_n,b_n)
    }
    
    poisgamma <- function(lambda){
      dpois(x,lambda,log=T)*dgamma(lambda,a_n,b_n)
    }
    
    temp1 <- integrate(poisgamma2,0,Inf)$value
    temp2 <- integrate(poisgamma,0,Inf)$value
    
    return(temp1 - temp2^2)
  }
  
  temp <- c()
  for(i in 1:length(x)){
    temp[i] <- V_comp(x[i])
  }
  return(sum(temp))
}

経験損失と汎関数分散を組み合わせることでWAICを得る。

\[ \begin{eqnarray} WAIC &=& T_n + \frac{V_n}{n} \\ \end{eqnarray} \]

対応する関数は下記である。

WAIC <- function(x){
  return(T_n(x) + V_n(x)/length(x))
}

ただ、解析的には計算できないので、MCMCのサンプルを利用して計算する。データ点ごとにMCMCサンプルから計算した対数尤度を用いる。MCMCサンプルが数が\(S\)個あり、\(s\)番目のパラメタ\(\theta\)の事後分布のMCMCサンプルを\(\theta_s\)とする。経験損失は下記である。\(()\)内は、尤度関数を事後分布で平均した値、つまり予測分布を近似している。

\[ \begin{eqnarray} T_n &\approx& \frac{1}{n} \sum_{i}^{N}\left[ -\log \left(\frac{1}{S} \sum^S p(x_i|\theta_s)\right) \right] \\ \end{eqnarray} \]

もう少し細かい部分をみておく。log_likはMCMCサンプル×\(x\)の長さのサイズ分のマトリックスで返される。

# dim(log_lik)
# [1] 4000   50
# Inference for Stan model: anon_model.
# 4 chains, each with iter=2000; warmup=1000; thin=1; 
# post-warmup draws per chain=1000, total post-warmup draws=4000.
# 
#              mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
# lambda       2.98    0.01 0.24  2.51  2.81  2.97  3.13  3.48  1513    1
# log_lik[1]  -1.80    0.00 0.09 -2.01 -1.85 -1.79 -1.74 -1.67  1456    1
# log_lik[2]  -1.89    0.00 0.16 -2.24 -1.99 -1.88 -1.78 -1.59  1518    1
# log_lik[3]  -2.98    0.01 0.24 -3.48 -3.13 -2.97 -2.81 -2.51  1513    1
# log_lik[4]  -1.51    0.00 0.01 -1.55 -1.51 -1.50 -1.50 -1.50  1567    1
# log_lik[5]  -2.98    0.01 0.24 -3.48 -3.13 -2.97 -2.81 -2.51  1513    1
# ///
# log_lik[46] -1.51    0.00 0.01 -1.55 -1.51 -1.50 -1.50 -1.50  1567    1
# log_lik[47] -2.98    0.01 0.24 -3.48 -3.13 -2.97 -2.81 -2.51  1513    1
# log_lik[48] -1.89    0.00 0.16 -2.24 -1.99 -1.88 -1.78 -1.59  1518    1
# log_lik[49] -1.51    0.00 0.01 -1.55 -1.51 -1.50 -1.50 -1.50  1567    1
# log_lik[50] -3.03    0.01 0.25 -3.57 -3.19 -3.02 -2.86 -2.57  1482    1
# lp__        13.48    0.02 0.75 11.34 13.33 13.77 13.94 13.99  1600    1

対数尤度を指数変換して尤度に戻してから、列ごとに格納されている\(x\)の各対数尤度を使って、MCMCMサンプル分で割ることで、平均を計算している。後は対数を取ってから平均を計算する。

# colMeans(exp(log_lik))
#  [1] 0.16509634 0.15311517 0.05246389 0.22178503 0.05246389 0.22494175 0.22178503 0.09896860 0.09896860
# [10] 0.22178503 0.22494175 0.22178503 0.22494175 0.22494175 0.16509634 0.22178503 0.15311517 0.04976527
# [19] 0.15311517 0.15311517 0.15311517 0.16509634 0.22494175 0.16509634 0.22494175 0.15311517 0.09896860
# [28] 0.22494175 0.22178503 0.22494175 0.09896860 0.22178503 0.09896860 0.22178503 0.04976527 0.09896860
# [37] 0.22178503 0.02158969 0.09896860 0.09896860 0.15311517 0.22494175 0.22494175 0.16509634 0.22178503
# [46] 0.22178503 0.05246389 0.15311517 0.22178503 0.04976527
# 
# -log(colMeans(exp(log_lik)))
#  [1] 1.801226 1.876565 2.947630 1.506047 2.947630 1.491914 1.506047 2.312953 2.312953 1.506047 1.491914 1.506047
# [13] 1.491914 1.491914 1.801226 1.506047 1.876565 3.000438 1.876565 1.876565 1.876565 1.801226 1.491914 1.801226
# [25] 1.491914 1.876565 2.312953 1.491914 1.506047 1.491914 2.312953 1.506047 2.312953 1.506047 3.000438 2.312953
# [37] 1.506047 3.835540 2.312953 2.312953 1.876565 1.491914 1.491914 1.801226 1.506047 1.506047 2.947630 1.876565
# [49] 1.506047 3.000438
#
# mean(-log(colMeans(exp(log_lik))))
# [1] 1.943874

また、汎関数分散の\([]\)内は、事後分布で平均した尤度尤度の分散を近似している。

\[ \begin{eqnarray} V_n &\approx& \sum_{i}^{N} \left[ \ \frac{1}{S-1} \sum_{s}^{S} \log p(x_i|\theta_s)^2 - \left( \frac{1}{S-1} \sum_{s}^{S} \log p(x_i|\theta_s) \right)^{2} \right] \end{eqnarray} \]

汎関数分散は、\(()\)内はゴツく見えるが分散の定義なので、素直に対数尤度を使って分散を計算した後に平均を計算している。

# V_n_divide_n 
# apply(log_lik,2, var)
#  [1] 0.007437887 0.026469827 0.059934105 0.000221718 0.059934105 0.006565698
#  [7] 0.000221718 0.028214205 0.028214205 0.000221718 0.006565698 0.000221718
# [13] 0.006565698 0.006565698 0.007437887 0.000221718 0.026469827 0.062550673
# [19] 0.026469827 0.026469827 0.026469827 0.007437887 0.006565698 0.007437887
# [25] 0.006565698 0.026469827 0.028214205 0.006565698 0.000221718 0.006565698
# [31] 0.028214205 0.000221718 0.028214205 0.000221718 0.062550673 0.028214205
# [37] 0.000221718 0.110447289 0.028214205 0.028214205 0.026469827 0.006565698
# [43] 0.006565698 0.007437887 0.000221718 0.000221718 0.059934105 0.026469827
# [49] 0.000221718 0.062550673

# mean(apply(log_lik,2, var))
# [1] 0.02041762

これらをまとめた関数はこちら。

waic_mcmc <- function(log_lik){
  T_n <- mean(-log(colMeans(exp(log_lik))))
  V_n_divide_n <- mean(apply(log_lik,2,var))
  waic <- T_n + V_n_divide_n
  return(waic)
}

これをRで実装していく。

モデルはこちら。WAICを計算するためにpoisson_lpmf()関数からのサンプリングを行っておく。

data{
  int N;
  int X[N];
  real a;
  real b;
}

parameters{
  real<lower=0> lambda;
}

model{
  for(i in 1:N){
    X[i] ~ poisson(lambda);
  }
  lambda ~ gamma(a,b);
}

generated quantities{
  real log_lik[N];
  for(i in 1:N){
    log_lik[i] = poisson_lpmf(X[i]|lambda);
  }
}

先にコンパイルしてから、

model.waic <- stan_model('poisson_WAIC.stan')

sampling()関数でサンプリングする。複数回のシュミレーションを行うことで、WAICが汎化損失のよい近似となるかを確認する。

sim_n <- 50
G_n_e <- WAIC_mcmc <- WAIC_e <- vector(mode = 'numeric', length = sim_n)
set.seed(1989)

for(i in 1:sim_n){
  x <- rpois(n, lambda_prime)
  WAIC_e[i] <- WAIC(x)
  fit.waic <- sampling(model.waic, data = list(N = n, X = x, a = a, b = b))
  log_lik <- rstan::extract(fit.waic)$log_lik
  WAIC_mcmc[i] <- waic_mcmc(log_lik)
  G_n_e[i] <- G_n(lambda_prime,x)
}

シュミレーションされた値の平均はこちら。

list(
  WAIC_e = mean(WAIC_e),
  WAIC_mcmc = mean(WAIC_mcmc),
  G_n_e = mean(G_n_e)
)
## $WAIC_e
## [1] 1.941089
## 
## $WAIC_mcmc
## [1] 1.940835
## 
## $G_n_e
## [1] 1.937293

可視化しておく。今回のシュミレーションでは良い近似となっていることがわかる。

tibble(
  WAIC = WAIC_e,
  WAIC_mcmc = WAIC_mcmc,
  G_n = G_n_e
  ) %>%
  gather(key = "x", value = "y") %>%
  mutate(x = forcats::fct_inorder(x)) %>%
  ggplot() +
  theme_bw(base_size = 18) +
  geom_boxplot(aes(x, y, fill = x)) +
  xlab('') + ylab('') + ylim(1.5, 2.5)

WAICが最小だからといいて、最良のモデル化はわからない点は注意が必要。あくまでも比較したモデルの中での相対的な良さである。

汎化損失とWAIC

なんらかの確率モデルとデータから得た予測分布\(p^{*}(x)\)を考える。真の分布\(q(x)\)と予測分布\(p^{*}(x)\)の交差エントロピー\(H_q(p^{*})\)を汎化損失\(G_n\)と呼ぶ。

汎化損失

\[ \begin{eqnarray} G_n &=& - \mathbb{E}_{q(X)}[\log p^{*}(X)] \\ &=& - \int \log p^{*}(x) q(x) dx \\ &=& \underbrace{H(q)}_{真の分布のエントロピー} + \underbrace{D(q||p^{*})}_{真の分布と予測分布のKL情報量} \end{eqnarray} \] と定義する。

汎化損失\(G_n\)\(q(x)\)のエントロピーにおいて、\(-logp^{*}(x)\)を仮定した情報量といえる。また、予測分布\(p^{*}(x)\)の導出に用いた\(x^{n}=(x_1,...,x_n)\)を入れて、相加平均を計算したものを、経験損失\(T_n\)と呼ぶ。

\[ \begin{eqnarray} T_n &=& - \frac{1}{n} \sum_{i}^{n} logp^{*}(x_i) \end{eqnarray} \]

経験損失\(T_n\)を使って、汎化損失\(G_n\)を推定する方法として、最尤推定予測分布についてのAIC、WAICを導入する。

6.9 WAIC

ベイズモデルにおける汎化損失について考える。下記を仮定する。

  • 真の確率分布\(q(x)\)は想定される確率モデルの中に真の分布が含まれてなくてもよい。
  • サンプル\(X^n \sim q(x^n) = \prod q(x)\)は、独立同一の分布を仮定する。
  • 確率モデル\(p(x^n|\theta)\)
  • パラメタの事前分布\(\varphi(\theta)\)
  • \(q(x)\)は確率について正則でなくてもよい

ベイズ推定をもとにした予測分布は、確率モデルを事後分布によって期待値を取ったものとして、

\[ \begin{eqnarray} p^{*}(x) &=& \mathbb{E}_{p(\theta|x^n)}[p(x|\theta)] \\ &=& \int p(x|\theta) p(\theta|x^n) d\theta \end{eqnarray} \]

定義される。ベイズ予測分布についての汎化損失\(G_n\)は、

\[ \begin{eqnarray} G_n &=& -\mathbb{E}_q(X)[ \log p^*(X)] \\ &=& -\mathbb{E}_q(X)[ \log \mathbb{E}_{p(\theta|x^n)}[p(X|\theta)]] \\ &=& -\int q(x) \log \left( \int p(x|\theta) p(\theta|x^n) d\theta\right)dx \end{eqnarray} \]

で、ベイズ予測分布の経験損失\(T_n\)は、

\[ \begin{eqnarray} T_n &=& -\frac{1}{n} \sum \log p^*(x_i) \\ &=& -\frac{1}{n} \sum \log \mathbb{E}_{p(\theta|x^n)}[p(x_i|\theta)] \\ \end{eqnarray} \]

経験損失は何らかのバイアス\(b(X^n)\)が生じている可能性がある。このバイアスの期待値は、

\[ \begin{eqnarray} \mathbb{E}_{q(X^n)}[b(X^n)] &=& \mathbb{E}_{q(X^n)}[G_n - T_n] \\ \end{eqnarray} \]

であり、このバイアスの期待値は汎関数分散\(V_n\)\(n\)で割った\(\frac{V_n}{n}\)と漸近的に一致する。

\[ \begin{eqnarray} V_n &=& \sum_{i}^{N} \left\{ \mathbb{E}_{p(\theta|x^n)} \left[ (\log p(x_i|\theta))^{2} \right] - \mathbb{E}_{p(\theta|x^n)} \left[ (\log p(x_i|\theta) \right]^{2} \right\} \\ \end{eqnarray} \]

以上よりWAICは、経験損失と汎関数分散を\(n\)で割った値の和として定義される。

\[ \begin{eqnarray} WAIC &=& T_n + \frac{V_n}{n} \\ &=& -\frac{1}{n} \sum \log \mathbb{E}_{p(\theta|x^n)}[p(x_i|\theta)] + \sum \left\{ \mathbb{E}_{p(\theta|x^n)} \left[ (\log p(x_i|\theta))^{2} \right] - \mathbb{E}_{p(\theta|x^n)} \left[ (\log p(x_i|\theta) \right]^{2} \right\} \\ \end{eqnarray} \]