UPDATE: 2024-02-22 23:44:39.556861
このノートは「ベイズ統計」に関する何らかの内容をまとめ、ベイズ統計への理解を深めていくために作成している。今回は「社会科学のためのベイズ統計モデリング」の第8章を写経していく。基本的には気になった部分を写経しながら、ところどころ自分用の補足をメモすることで、自分用の補足資料になることを目指す。私の解釈がおかしく、メモが誤っている場合があるので注意。
今回はWAICのシュミレーションの数式とコードの対応を理解することを目的にする。ここでの真の分布は\(\lambda^*=3\)のポアソン分布とする。確率モデルにポアソン分布、事前分布に\(a=3,b=1\)のガンマ分布を考える。
library(tidyverse)
library(rstan)
library(loo)
options(max.print = 999999)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())
lambda_prime <- 3
n <- 50
a <- 3
b <- 1
set.seed(1989)
x <- rpois(n, lambda_prime) #q(x)
hist(x, main = 'Poisson(λ=3)')
ここでは、カウントデータに対して、ガンマ-ポアソン分布を用いてモデル化を行い、WAICを計算する。ガンマ-ポアソン分布は共役関係から下記の通り事後分布が計算できる。これは\(a_n = \sum x_i + a, b_n = n + b\)のガンマ分布と一致する。
\[ \begin{eqnarray} X &\sim& Poisson(\lambda) \\ \lambda &\sim& Gamma(a,b) \\ p(\lambda|x^n) &=& \frac{b_{n}^{a_n}}{\Gamma(a_n)} \lambda^{a_{n} -1} e^{-b_{n} \lambda} \end{eqnarray} \]
事後分布を使った事後予測分布は、
\[ \begin{eqnarray} p^*(x) &=& \int_0^{\infty} P(x|\lambda) p(\lambda|x^n) d\lambda \\ &=& \int_0^{\infty} \frac{1}{x!}\lambda^x e^{-\lambda} \frac{b_{n}^{a_n}}{\Gamma(a_n)} \lambda^{a_{n} -1} e^{-b_{n} \lambda} d\lambda \\ &=& \frac{\Gamma(x + a_n)}{\Gamma(a_n)x!}\frac{b_{n}^{a_n}}{(1+b_n)^{x+a_n}} \end{eqnarray} \]
である。予測分布の情報量は、下記となる。
\[ \begin{eqnarray} -\log p^*(x) &=& -log \left( \frac{\Gamma(x + a_n)}{\Gamma(a_n)x!}\frac{b_{n}^{a_n}}{(1+b_n)^{x+a_n}}\right) \\ &=& -\log \Gamma(x+a_n) + \log\Gamma(a_n) + logx! - a_n \log b_n + (x+a_n) \log(1+b_n) \end{eqnarray} \]
予測分布の情報量に対応する関数は下記である。
# a_n = \sum x_i + a,
# b_n = n + b
# Γ(n+1) = n!
info_pred_dist <- function(x_d, x){
a_n <- sum(x)+a
b_n <- length(x)+b
return(-lgamma(x_d+a_n) + lgamma(a_n) + lgamma(x_d+1) - a_n*log(b_n) + (x_d+a_n)*log(1+b_n))
}
汎化損失\(G_n\)は真の分布として、\(\lambda^*=3\)のポアソン分布を仮定しているため、
\[ \begin{eqnarray} G_n &=& \int - \log p^*(x) q(x) dx \\ &=& \sum - \log p^*(x_i) Poisson(x_i|\lambda) \\ &=& \sum \frac{(\lambda^*)^{x_i} e^{-\lambda^*}}{x_i!} [-\log \Gamma(x_i + a_n) + \log (a_n) + log x_i! - a_n \log b_n + (x_i+a_n) \log(1+b_n)] \\ \end{eqnarray} \]
となる。
G_n <- function(lambda_prime,x){
temp <- 0
for(i in 0:20){
temp = temp + dpois(i, lambda_prime)*info_pred_dist(i,x)
}
return(temp)
}
次にWAICを計算するための数式と関数の対応を確認する。WAICの計算には、経験損失と汎関数分散が必要。経験損失は予測分布の情報量を使えば計算できる。
\[ \begin{eqnarray} T_n &=& \frac{1}{n} \sum -\log p^*(x_i) \\ &=& \frac{1}{n} \sum [-\log \Gamma(x_i + a_n) + \log (a_n) + log x_i! - a_n \log b_n + (x_i+a_n) \log(1+b_n)] \\ \end{eqnarray} \] 経験損失に対応する関数が下記である。
# info_pred_dist <- function(x_d, x){
# a_n <- sum(x)+a
# b_n <- length(x)+b
# return(-lgamma(x_d+a_n) + lgamma(a_n) + lgamma(x_d+1) - a_n*log(b_n) + (x_d+a_n)*log(1+b_n))
# }
T_n <- function(x){
temp <- c()
for(i in 1:length(x)){
temp[i] = info_pred_dist(x[i],x)
}
return(mean(temp))
}
汎関数分散については、
\[ \begin{eqnarray} V_n &=& \sum \left[ \int \log (Poisson(x_i|\lambda))^2 \underbrace{p(\lambda|x^n)}_{事後予測分布}d\lambda - \left( \int \log Poisson(x_i|\lambda) \underbrace{p(\lambda|x^n)}_{事後予測分布} d\lambda \right)^{2} \right] \end{eqnarray} \]
下記の数式と関数が対応する。
V_n <- function(x){
a_n <- sum(x)+a
b_n <- length(x)+b
V_comp <- function(x){
poisgamma2 <- function(lambda){
dpois(x,lambda,log=T)^2*dgamma(lambda,a_n,b_n)
}
poisgamma <- function(lambda){
dpois(x,lambda,log=T)*dgamma(lambda,a_n,b_n)
}
temp1 <- integrate(poisgamma2,0,Inf)$value
temp2 <- integrate(poisgamma,0,Inf)$value
return(temp1 - temp2^2)
}
temp <- c()
for(i in 1:length(x)){
temp[i] <- V_comp(x[i])
}
return(sum(temp))
}
経験損失と汎関数分散を組み合わせることでWAICを得る。
\[ \begin{eqnarray} WAIC &=& T_n + \frac{V_n}{n} \\ \end{eqnarray} \]
対応する関数は下記である。
WAIC <- function(x){
return(T_n(x) + V_n(x)/length(x))
}
ただ、解析的には計算できないので、MCMCのサンプルを利用して計算する。データ点ごとにMCMCサンプルから計算した対数尤度を用いる。MCMCサンプルが数が\(S\)個あり、\(s\)番目のパラメタ\(\theta\)の事後分布のMCMCサンプルを\(\theta_s\)とする。経験損失は下記である。\(()\)内は、尤度関数を事後分布で平均した値、つまり予測分布を近似している。
\[ \begin{eqnarray} T_n &\approx& \frac{1}{n} \sum_{i}^{N}\left[ -\log \left(\frac{1}{S} \sum^S p(x_i|\theta_s)\right) \right] \\ \end{eqnarray} \]
もう少し細かい部分をみておく。log_lik
はMCMCサンプル×\(x\)の長さのサイズ分のマトリックスで返される。
# dim(log_lik)
# [1] 4000 50
# Inference for Stan model: anon_model.
# 4 chains, each with iter=2000; warmup=1000; thin=1;
# post-warmup draws per chain=1000, total post-warmup draws=4000.
#
# mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
# lambda 2.98 0.01 0.24 2.51 2.81 2.97 3.13 3.48 1513 1
# log_lik[1] -1.80 0.00 0.09 -2.01 -1.85 -1.79 -1.74 -1.67 1456 1
# log_lik[2] -1.89 0.00 0.16 -2.24 -1.99 -1.88 -1.78 -1.59 1518 1
# log_lik[3] -2.98 0.01 0.24 -3.48 -3.13 -2.97 -2.81 -2.51 1513 1
# log_lik[4] -1.51 0.00 0.01 -1.55 -1.51 -1.50 -1.50 -1.50 1567 1
# log_lik[5] -2.98 0.01 0.24 -3.48 -3.13 -2.97 -2.81 -2.51 1513 1
# ///
# log_lik[46] -1.51 0.00 0.01 -1.55 -1.51 -1.50 -1.50 -1.50 1567 1
# log_lik[47] -2.98 0.01 0.24 -3.48 -3.13 -2.97 -2.81 -2.51 1513 1
# log_lik[48] -1.89 0.00 0.16 -2.24 -1.99 -1.88 -1.78 -1.59 1518 1
# log_lik[49] -1.51 0.00 0.01 -1.55 -1.51 -1.50 -1.50 -1.50 1567 1
# log_lik[50] -3.03 0.01 0.25 -3.57 -3.19 -3.02 -2.86 -2.57 1482 1
# lp__ 13.48 0.02 0.75 11.34 13.33 13.77 13.94 13.99 1600 1
対数尤度を指数変換して尤度に戻してから、列ごとに格納されている\(x\)の各対数尤度を使って、MCMCMサンプル分で割ることで、平均を計算している。後は対数を取ってから平均を計算する。
# colMeans(exp(log_lik))
# [1] 0.16509634 0.15311517 0.05246389 0.22178503 0.05246389 0.22494175 0.22178503 0.09896860 0.09896860
# [10] 0.22178503 0.22494175 0.22178503 0.22494175 0.22494175 0.16509634 0.22178503 0.15311517 0.04976527
# [19] 0.15311517 0.15311517 0.15311517 0.16509634 0.22494175 0.16509634 0.22494175 0.15311517 0.09896860
# [28] 0.22494175 0.22178503 0.22494175 0.09896860 0.22178503 0.09896860 0.22178503 0.04976527 0.09896860
# [37] 0.22178503 0.02158969 0.09896860 0.09896860 0.15311517 0.22494175 0.22494175 0.16509634 0.22178503
# [46] 0.22178503 0.05246389 0.15311517 0.22178503 0.04976527
#
# -log(colMeans(exp(log_lik)))
# [1] 1.801226 1.876565 2.947630 1.506047 2.947630 1.491914 1.506047 2.312953 2.312953 1.506047 1.491914 1.506047
# [13] 1.491914 1.491914 1.801226 1.506047 1.876565 3.000438 1.876565 1.876565 1.876565 1.801226 1.491914 1.801226
# [25] 1.491914 1.876565 2.312953 1.491914 1.506047 1.491914 2.312953 1.506047 2.312953 1.506047 3.000438 2.312953
# [37] 1.506047 3.835540 2.312953 2.312953 1.876565 1.491914 1.491914 1.801226 1.506047 1.506047 2.947630 1.876565
# [49] 1.506047 3.000438
#
# mean(-log(colMeans(exp(log_lik))))
# [1] 1.943874
また、汎関数分散の\([]\)内は、事後分布で平均した尤度尤度の分散を近似している。
\[ \begin{eqnarray} V_n &\approx& \sum_{i}^{N} \left[ \ \frac{1}{S-1} \sum_{s}^{S} \log p(x_i|\theta_s)^2 - \left( \frac{1}{S-1} \sum_{s}^{S} \log p(x_i|\theta_s) \right)^{2} \right] \end{eqnarray} \]
汎関数分散は、\(()\)内はゴツく見えるが分散の定義なので、素直に対数尤度を使って分散を計算した後に平均を計算している。
# V_n_divide_n
# apply(log_lik,2, var)
# [1] 0.007437887 0.026469827 0.059934105 0.000221718 0.059934105 0.006565698
# [7] 0.000221718 0.028214205 0.028214205 0.000221718 0.006565698 0.000221718
# [13] 0.006565698 0.006565698 0.007437887 0.000221718 0.026469827 0.062550673
# [19] 0.026469827 0.026469827 0.026469827 0.007437887 0.006565698 0.007437887
# [25] 0.006565698 0.026469827 0.028214205 0.006565698 0.000221718 0.006565698
# [31] 0.028214205 0.000221718 0.028214205 0.000221718 0.062550673 0.028214205
# [37] 0.000221718 0.110447289 0.028214205 0.028214205 0.026469827 0.006565698
# [43] 0.006565698 0.007437887 0.000221718 0.000221718 0.059934105 0.026469827
# [49] 0.000221718 0.062550673
# mean(apply(log_lik,2, var))
# [1] 0.02041762
これらをまとめた関数はこちら。
waic_mcmc <- function(log_lik){
T_n <- mean(-log(colMeans(exp(log_lik))))
V_n_divide_n <- mean(apply(log_lik,2,var))
waic <- T_n + V_n_divide_n
return(waic)
}
これをRで実装していく。
モデルはこちら。WAICを計算するためにpoisson_lpmf()
関数からのサンプリングを行っておく。
data{
int N;
int X[N];
real a;
real b;
}
parameters{
real<lower=0> lambda;
}
model{
for(i in 1:N){
X[i] ~ poisson(lambda);
}
lambda ~ gamma(a,b);
}
generated quantities{
real log_lik[N];
for(i in 1:N){
log_lik[i] = poisson_lpmf(X[i]|lambda);
}
}
先にコンパイルしてから、
model.waic <- stan_model('poisson_WAIC.stan')
sampling()
関数でサンプリングする。複数回のシュミレーションを行うことで、WAICが汎化損失のよい近似となるかを確認する。
sim_n <- 50
G_n_e <- WAIC_mcmc <- WAIC_e <- vector(mode = 'numeric', length = sim_n)
set.seed(1989)
for(i in 1:sim_n){
x <- rpois(n, lambda_prime)
WAIC_e[i] <- WAIC(x)
fit.waic <- sampling(model.waic, data = list(N = n, X = x, a = a, b = b))
log_lik <- rstan::extract(fit.waic)$log_lik
WAIC_mcmc[i] <- waic_mcmc(log_lik)
G_n_e[i] <- G_n(lambda_prime,x)
}
シュミレーションされた値の平均はこちら。
list(
WAIC_e = mean(WAIC_e),
WAIC_mcmc = mean(WAIC_mcmc),
G_n_e = mean(G_n_e)
)
## $WAIC_e
## [1] 1.941089
##
## $WAIC_mcmc
## [1] 1.940835
##
## $G_n_e
## [1] 1.937293
可視化しておく。今回のシュミレーションでは良い近似となっていることがわかる。
tibble(
WAIC = WAIC_e,
WAIC_mcmc = WAIC_mcmc,
G_n = G_n_e
) %>%
gather(key = "x", value = "y") %>%
mutate(x = forcats::fct_inorder(x)) %>%
ggplot() +
theme_bw(base_size = 18) +
geom_boxplot(aes(x, y, fill = x)) +
xlab('') + ylab('') + ylim(1.5, 2.5)
WAICが最小だからといいて、最良のモデル化はわからない点は注意が必要。あくまでも比較したモデルの中での相対的な良さである。
なんらかの確率モデルとデータから得た予測分布\(p^{*}(x)\)を考える。真の分布\(q(x)\)と予測分布\(p^{*}(x)\)の交差エントロピー\(H_q(p^{*})\)を汎化損失\(G_n\)と呼ぶ。
\[ \begin{eqnarray} G_n &=& - \mathbb{E}_{q(X)}[\log p^{*}(X)] \\ &=& - \int \log p^{*}(x) q(x) dx \\ &=& \underbrace{H(q)}_{真の分布のエントロピー} + \underbrace{D(q||p^{*})}_{真の分布と予測分布のKL情報量} \end{eqnarray} \] と定義する。
汎化損失\(G_n\)は\(q(x)\)のエントロピーにおいて、\(-logp^{*}(x)\)を仮定した情報量といえる。また、予測分布\(p^{*}(x)\)の導出に用いた\(x^{n}=(x_1,...,x_n)\)を入れて、相加平均を計算したものを、経験損失\(T_n\)と呼ぶ。
\[ \begin{eqnarray} T_n &=& - \frac{1}{n} \sum_{i}^{n} logp^{*}(x_i) \end{eqnarray} \]
経験損失\(T_n\)を使って、汎化損失\(G_n\)を推定する方法として、最尤推定予測分布についてのAIC、WAICを導入する。
ベイズモデルにおける汎化損失について考える。下記を仮定する。
ベイズ推定をもとにした予測分布は、確率モデルを事後分布によって期待値を取ったものとして、
\[ \begin{eqnarray} p^{*}(x) &=& \mathbb{E}_{p(\theta|x^n)}[p(x|\theta)] \\ &=& \int p(x|\theta) p(\theta|x^n) d\theta \end{eqnarray} \]
定義される。ベイズ予測分布についての汎化損失\(G_n\)は、
\[ \begin{eqnarray} G_n &=& -\mathbb{E}_q(X)[ \log p^*(X)] \\ &=& -\mathbb{E}_q(X)[ \log \mathbb{E}_{p(\theta|x^n)}[p(X|\theta)]] \\ &=& -\int q(x) \log \left( \int p(x|\theta) p(\theta|x^n) d\theta\right)dx \end{eqnarray} \]
で、ベイズ予測分布の経験損失\(T_n\)は、
\[ \begin{eqnarray} T_n &=& -\frac{1}{n} \sum \log p^*(x_i) \\ &=& -\frac{1}{n} \sum \log \mathbb{E}_{p(\theta|x^n)}[p(x_i|\theta)] \\ \end{eqnarray} \]
経験損失は何らかのバイアス\(b(X^n)\)が生じている可能性がある。このバイアスの期待値は、
\[ \begin{eqnarray} \mathbb{E}_{q(X^n)}[b(X^n)] &=& \mathbb{E}_{q(X^n)}[G_n - T_n] \\ \end{eqnarray} \]
であり、このバイアスの期待値は汎関数分散\(V_n\)を\(n\)で割った\(\frac{V_n}{n}\)と漸近的に一致する。
\[ \begin{eqnarray} V_n &=& \sum_{i}^{N} \left\{ \mathbb{E}_{p(\theta|x^n)} \left[ (\log p(x_i|\theta))^{2} \right] - \mathbb{E}_{p(\theta|x^n)} \left[ (\log p(x_i|\theta) \right]^{2} \right\} \\ \end{eqnarray} \]
以上よりWAICは、経験損失と汎関数分散を\(n\)で割った値の和として定義される。
\[ \begin{eqnarray} WAIC &=& T_n + \frac{V_n}{n} \\ &=& -\frac{1}{n} \sum \log \mathbb{E}_{p(\theta|x^n)}[p(x_i|\theta)] + \sum \left\{ \mathbb{E}_{p(\theta|x^n)} \left[ (\log p(x_i|\theta))^{2} \right] - \mathbb{E}_{p(\theta|x^n)} \left[ (\log p(x_i|\theta) \right]^{2} \right\} \\ \end{eqnarray} \]