UPDATE: 2023-12-31 12:21:19.7456

はじめに

このノートは「StanとRでベイズ統計モデリング」の内容を写経することで、ベイズ統計への理解を深めていくために作成している。

基本的には気になった部分を写経しながら、ところどころ自分用の補足をメモすることで、「StanとRでベイズ統計モデリング」を読み進めるための自分用の補足資料になることを目指す。私の解釈がおかしく、メモが誤っている場合があるので注意。

今回は少し寄り道をして、緑本でおなじみ「データ解析のための統計モデリング入門」の第10章を写経していく。次回以降、「StanとRでベイズ統計モデリング」では階層ベイズモデルのチャプターに突入するので、「データ解析のための統計モデリング入門」の階層ベイズモデルも合わせて学習しておきたい。

10.1 例題:個体差と生存種子数(個体差あり)

ここで使用するデータは架空の植物の種子の生存確率のデータ。調査個体数は\(n=100\)で、各固体から8個の種子を取り出し、生存している種子数\(y_{i}\)を記録している。つまり、各個体\(i\)のデータは「8個の種子のうち\(y_{i}\)個がパラメタ\(q_{i}\)の確率で生存」しており、各個体は独立に生成されると仮定する。

library(tidyverse)
library(rstan)

d <- read_csv('~/Desktop/data7a.csv')
d
## # A tibble: 100 × 2
##       id     y
##    <dbl> <dbl>
##  1     1     0
##  2     2     2
##  3     3     7
##  4     4     8
##  5     5     1
##  6     6     7
##  7     7     8
##  8     8     8
##  9     9     1
## 10    10     1
## # ℹ 90 more rows

ただ、観測された個体(ブラック)ごとの個体差が大きく、過分散が発生し、二項分布(グレー)で近似できるはずが、そのようになってはいない。つまり、100個体を共通の生存確率\(q=0.504\)では、データを説明できない。

# 二項分布のデータを生成
binom_df <- tibble(x = 0:8, y = dbinom(0:8, 8, 0.504)*length(d$y))

d %>% 
  group_by(y) %>% 
  count() %>% 
  ggplot(., aes(y, n)) + 
  geom_point(size = 3) + 
  geom_path() + 
  geom_point(data = binom_df, aes(x, y), col = 'gray', size = 3, alpha = 1/2) + 
  geom_path(data = binom_df, aes(x, y), col = 'gray', alpha = 1/2) + 
  labs(x = '生存種子数 y_i', y = '観測された個体数', title = '生存種子別の個体数') + 
  theme_bw()

10.2 GLMMの階層ベイズモデル化

種子生存確率\(q\)が全固体で共通していると仮定する統計モデルでは、今回のデータは説明できない。そのため、個体差を取り込んだGLMMなどで分析する必要が出てくる。まずはリンク関数と線形予測子のメカニズムを考える。リンク関数と線形予測子は\(logit(q_{i})= \beta + r_{i}\)で切片\(\beta\)は全固体共通のパラメタで、個体差は\(r_{i}\)で表す。個体差\(r_{i}\)\(r_{i} \sim Normal(0, s)\)に従うとする。尤度は、

\[ p(\boldsymbol{Y}| \beta, r_{i}) = \prod_{i=1}^{n} \binom{8}{y_{i}} q_{i}^{y_{i}}(1 - q_{i})^{8-y_{i}} \]

となり、切片\(\beta\)の事前分布は平均0、標準偏差100の無情報事前分布を仮定する。

\[ p(\beta) = \frac{1}{\sqrt{2\pi \cdot 100^2}} \exp \left[ \frac{-\beta^{2}}{2 \cdot 100^2} \right] \]

そして、個体差\(r_{i}\)のパラメタの事前分布も必要になるので、平均0、標準偏差\(s\)の無情報事前分布を仮定する。標準偏差\(s\)は個体差を表す100個体の\(r_{i}\)がどのくらいばらつくのかをコントロールするパラメタとなる。

\[ p(r_{i}|s) = \frac{1}{\sqrt{2\pi s^2}} \exp \left[ \frac{-r^{2}}{2s^2} \right] \] さらに\(s\)の事前分布には、正であればよいので、無情報事前分布を仮定する。

\[ p(s) \sim unif(0, 10^4) \]

このような形で個体差\(r_{i}\)の事前分布\(p(r_{i}|s)\)の形を決める\(s\)という未知パラメタがあって、さらに\(s\)についても事前分布\(p(s)\)が設定されている時、\(p(r_{i}|s)\)を階層事前分布と呼ぶ。\(s\)は超パラメタ、\(p(s)\)は超事前分布と呼ばれる。

まとめると下記の図の様になる。画像はこちらよりお借りした。

最終的に事後分布は下記となり、これをMCMCでサンプリングすることで事後分布を推定する。ベイズモデルでは、推定したいパラメタは、事前分布とデータに基づいて事後分布が生成される。

\[ \overbrace{p(\beta, s, r_{i}|\boldsymbol{Y})}^{Posterior} \propto \overbrace{p(\boldsymbol{Y}| \beta, r_{i})}^{Likelihood} \overbrace{ p(\beta) \prod_{i=1}^{n}p(r_{i}|s)p(s)}^{Prior} \]

10.3.1 階層ベイズモデルのMCMCサンプリング

ここではパラメタの事後分布を推定する。まずはStanにわたすためにデータをリスト形式に変換する。

data <- list(
  N = nrow(d),
  Y = d$y
)

階層ベイズモデルのモデル式は下記の通り。

data {
  int<lower=0> N;     // sample size
  int<lower=0> Y[N];  // response variable
}
parameters {
  real beta;
  real r[N];
  real<lower=0> sigma;
}
transformed parameters {
  real q[N];

  for (i in 1:N) {
    q[i] <- inv_logit(beta + r[i]);
  }
}
model {
  for (i in 1:N) {
        Y[i] ~ binomial(8, q[i]); // binom
  }
  
  for (n in 1:N){
    r[n] ~ normal(0, sigma);
  }
  
  beta ~ normal(0, 100);      // non-informative prior
  sigma ~ uniform(0, 10000);  // non-informative prior
}

// generated quantities {
//   real y_pred[N];
//   for (i in 1:N)
//     y_pred[i] = binomial_rng(8, q[i]);
// }


//
// ベクトル化させたものでも同じはず
// data {
//   int<lower=0> N;
//   int y[N];
// }
// 
// parameters {
//   real beta;           
//   vector[N] r;      
//   real<lower=0> sigma;  
// }
// 
// model {
//   y ~ binomial(8, inv_logit(beta + r));
//   beta ~ normal(0, 100);
//   r ~ normal(0, sigma);
//   sigma ~ uniform(0, 10000);
// }

ここでは、stan_model()関数で最初にコンパイルしておいてから、

model1031 <- stan_model('note_ahirubayes05.stan')

sampling()関数でサンプリングする。

fit <- sampling(object = model1031, data = data, seed = 1989)

推定結果はこちら。各個体の個体差\(r_{i}\)や生存確率\(q_{i}\)の事後分布が推定されている。

print(fit, prob = c(0.025, 0.5, 0.975), digits_summary = 2)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##           mean se_mean   sd    2.5%     50%   97.5% n_eff Rhat
## beta      0.06    0.01 0.34   -0.59    0.06    0.73   765 1.01
## r[1]     -3.90    0.03 1.77   -7.94   -3.66   -1.14  3241 1.00
## r[2]     -1.24    0.02 0.89   -3.13   -1.20    0.40  3132 1.00
## r[3]      1.97    0.02 1.11   -0.01    1.89    4.43  3702 1.00
## r[4]      3.75    0.03 1.74    0.96    3.52    7.80  3996 1.00
## r[5]     -2.10    0.02 1.10   -4.53   -2.01   -0.16  3307 1.00
## r[6]      1.96    0.02 1.12    0.04    1.86    4.45  4229 1.00
## r[7]      3.77    0.03 1.75    0.95    3.53    7.73  4053 1.00
## r[8]      3.80    0.03 1.78    1.03    3.56    8.10  2972 1.00
## r[9]     -2.11    0.02 1.12   -4.68   -1.99   -0.18  3729 1.00
## r[10]    -2.09    0.02 1.09   -4.42   -2.00   -0.22  3675 1.00
## r[11]    -0.07    0.01 0.80   -1.67   -0.06    1.47  2953 1.00
## r[12]    -3.89    0.03 1.79   -8.10   -3.62   -1.12  3013 1.00
## r[13]    -2.10    0.02 1.09   -4.49   -2.01   -0.21  3389 1.00
## r[14]    -0.07    0.02 0.80   -1.63   -0.08    1.54  2768 1.00
## r[15]     1.96    0.02 1.12    0.06    1.85    4.45  3533 1.00
## r[16]     3.79    0.03 1.82    0.97    3.51    8.11  3851 1.00
## r[17]     1.97    0.02 1.08    0.10    1.88    4.38  4379 1.00
## r[18]    -3.86    0.03 1.76   -7.97   -3.63   -1.04  4206 1.00
## r[19]    -1.22    0.02 0.89   -3.13   -1.18    0.44  3438 1.00
## r[20]    -1.22    0.02 0.89   -3.06   -1.18    0.44  3317 1.00
## r[21]    -2.09    0.02 1.09   -4.46   -2.03   -0.21  3278 1.00
## r[22]    -2.10    0.02 1.13   -4.62   -2.01   -0.19  3479 1.00
## r[23]     0.47    0.01 0.79   -1.04    0.45    2.14  2949 1.00
## r[24]     1.98    0.02 1.10    0.08    1.88    4.33  3807 1.00
## r[25]     3.79    0.03 1.75    1.09    3.55    7.88  3710 1.00
## r[26]    -0.60    0.02 0.81   -2.22   -0.60    0.99  2735 1.00
## r[27]     3.75    0.03 1.72    1.01    3.54    7.82  3991 1.00
## r[28]    -0.06    0.02 0.78   -1.61   -0.06    1.47  2670 1.00
## r[29]     1.11    0.02 0.91   -0.53    1.06    3.04  3262 1.00
## r[30]    -3.86    0.03 1.75   -7.94   -3.63   -0.99  3980 1.00
## r[31]     3.81    0.03 1.77    1.00    3.55    7.94  3471 1.00
## r[32]     1.98    0.02 1.09    0.12    1.89    4.35  3264 1.00
## r[33]     3.79    0.03 1.79    0.95    3.56    8.19  3500 1.00
## r[34]    -3.86    0.03 1.78   -7.97   -3.62   -1.05  3617 1.00
## r[35]    -1.24    0.02 0.91   -3.20   -1.19    0.42  2928 1.00
## r[36]     1.10    0.02 0.91   -0.53    1.05    2.97  3065 1.00
## r[37]     1.98    0.02 1.09    0.10    1.88    4.34  4703 1.00
## r[38]     3.84    0.03 1.80    1.03    3.57    8.21  2962 1.00
## r[39]    -1.22    0.02 0.87   -3.07   -1.19    0.37  3329 1.00
## r[40]    -2.10    0.02 1.09   -4.52   -1.98   -0.25  3442 1.00
## r[41]    -2.13    0.02 1.15   -4.70   -2.01   -0.11  4403 1.00
## r[42]    -3.89    0.03 1.78   -8.01   -3.63   -1.08  3251 1.00
## r[43]    -3.86    0.03 1.76   -8.07   -3.62   -1.12  3279 1.00
## r[44]     1.97    0.02 1.12    0.03    1.88    4.49  4045 1.00
## r[45]     3.80    0.03 1.79    1.04    3.59    7.97  3886 1.00
## r[46]     0.48    0.01 0.82   -1.09    0.46    2.15  3037 1.00
## r[47]     1.99    0.02 1.14    0.08    1.91    4.60  3342 1.00
## r[48]    -1.22    0.02 0.86   -3.04   -1.18    0.40  3076 1.00
## r[49]     3.75    0.03 1.73    1.00    3.53    7.71  3638 1.00
## r[50]    -2.12    0.02 1.14   -4.59   -2.02   -0.17  3650 1.00
## r[51]     0.48    0.02 0.83   -1.11    0.45    2.18  2873 1.00
## r[52]     2.00    0.02 1.14    0.05    1.88    4.64  2895 1.00
## r[53]    -0.62    0.02 0.82   -2.26   -0.60    0.99  2851 1.00
## r[54]     3.77    0.03 1.72    1.04    3.53    7.72  3597 1.00
## r[55]    -3.86    0.03 1.72   -7.89   -3.65   -1.10  3861 1.00
## r[56]     3.78    0.03 1.80    0.99    3.50    8.03  3635 1.00
## r[57]     1.10    0.02 0.87   -0.53    1.05    2.93  3100 1.00
## r[58]    -0.60    0.02 0.83   -2.32   -0.58    0.98  3038 1.00
## r[59]    -1.21    0.02 0.90   -3.11   -1.18    0.47  3285 1.00
## r[60]    -3.85    0.03 1.74   -7.90   -3.59   -1.17  3315 1.00
## r[61]    -3.92    0.03 1.80   -8.06   -3.67   -1.10  4049 1.00
## r[62]    -2.10    0.02 1.09   -4.45   -2.02   -0.22  3235 1.00
## r[63]    -1.22    0.02 0.90   -3.13   -1.18    0.49  3423 1.00
## r[64]     3.72    0.03 1.72    0.99    3.48    7.66  3653 1.00
## r[65]     1.98    0.02 1.08    0.07    1.89    4.33  3637 1.00
## r[66]     1.96    0.02 1.08    0.11    1.87    4.32  3633 1.00
## r[67]     1.99    0.02 1.15    0.04    1.86    4.56  3930 1.00
## r[68]     3.81    0.03 1.78    1.01    3.57    8.02  3773 1.00
## r[69]    -3.90    0.03 1.75   -7.96   -3.64   -1.08  3116 1.00
## r[70]    -3.88    0.03 1.80   -8.07   -3.63   -1.05  3706 1.00
## r[71]    -3.90    0.03 1.78   -8.13   -3.69   -1.12  3568 1.00
## r[72]     0.49    0.01 0.81   -1.08    0.48    2.15  3055 1.00
## r[73]    -2.10    0.02 1.14   -4.65   -1.98   -0.16  3438 1.00
## r[74]    -3.85    0.03 1.77   -8.10   -3.64   -1.01  3819 1.00
## r[75]    -3.87    0.03 1.76   -7.97   -3.64   -1.09  3629 1.00
## r[76]    -3.89    0.03 1.76   -7.94   -3.66   -1.09  3511 1.00
## r[77]     3.77    0.03 1.72    1.13    3.51    7.72  2857 1.00
## r[78]    -2.10    0.02 1.10   -4.49   -2.01   -0.19  3716 1.00
## r[79]     3.75    0.03 1.75    0.97    3.52    7.77  4538 1.00
## r[80]    -0.06    0.01 0.79   -1.63   -0.06    1.49  2801 1.00
## r[81]     1.98    0.02 1.09    0.12    1.88    4.38  3084 1.00
## r[82]    -1.23    0.02 0.91   -3.13   -1.18    0.45  3327 1.00
## r[83]    -2.14    0.02 1.14   -4.76   -2.03   -0.19  2373 1.00
## r[84]    -0.07    0.01 0.79   -1.62   -0.06    1.46  3007 1.00
## r[85]     1.99    0.02 1.11    0.08    1.89    4.47  3039 1.00
## r[86]    -3.89    0.03 1.76   -7.76   -3.67   -1.09  3604 1.00
## r[87]     3.76    0.03 1.76    0.97    3.52    7.72  3437 1.00
## r[88]    -2.10    0.02 1.08   -4.51   -2.02   -0.19  3619 1.00
## r[89]     3.78    0.03 1.75    1.02    3.54    7.90  4066 1.00
## r[90]     1.98    0.02 1.10    0.10    1.90    4.38  3040 1.00
## r[91]     1.10    0.01 0.93   -0.55    1.05    3.05  3907 1.00
## r[92]    -1.23    0.02 0.90   -3.13   -1.18    0.41  3098 1.00
## r[93]     3.78    0.03 1.76    0.94    3.58    7.63  3040 1.00
## r[94]     1.11    0.02 0.87   -0.46    1.07    2.94  3324 1.00
## r[95]     1.10    0.02 0.89   -0.50    1.07    3.01  3119 1.00
## r[96]    -2.10    0.02 1.11   -4.52   -2.01   -0.18  3354 1.00
## r[97]    -3.86    0.03 1.74   -7.82   -3.63   -1.04  3635 1.00
## r[98]    -0.04    0.02 0.80   -1.61   -0.05    1.53  2831 1.00
## r[99]     1.99    0.02 1.13    0.06    1.91    4.50  3782 1.00
## r[100]   -3.88    0.03 1.82   -8.19   -3.62   -0.95  3573 1.00
## sigma     3.04    0.01 0.36    2.42    3.01    3.82  1363 1.00
## q[1]      0.05    0.00 0.07    0.00    0.03    0.25  5337 1.00
## q[2]      0.26    0.00 0.14    0.05    0.24    0.58  8321 1.00
## q[3]      0.85    0.00 0.12    0.55    0.88    0.99  7119 1.00
## q[4]      0.95    0.00 0.07    0.75    0.97    1.00  4930 1.00
## q[5]      0.15    0.00 0.12    0.01    0.12    0.45  6261 1.00
## q[6]      0.85    0.00 0.12    0.55    0.87    0.99  6493 1.00
## q[7]      0.95    0.00 0.07    0.75    0.97    1.00  5409 1.00
## q[8]      0.95    0.00 0.07    0.76    0.97    1.00  4470 1.00
## q[9]      0.15    0.00 0.12    0.01    0.13    0.44  6913 1.00
## q[10]     0.15    0.00 0.11    0.01    0.13    0.44  6952 1.00
## q[11]     0.50    0.00 0.16    0.18    0.50    0.81  8347 1.00
## q[12]     0.05    0.00 0.07    0.00    0.03    0.25  5125 1.00
## q[13]     0.15    0.00 0.11    0.01    0.13    0.44  7565 1.00
## q[14]     0.50    0.00 0.16    0.19    0.50    0.81  8205 1.00
## q[15]     0.84    0.00 0.12    0.56    0.87    0.99  6912 1.00
## q[16]     0.95    0.00 0.07    0.75    0.97    1.00  5389 1.00
## q[17]     0.85    0.00 0.11    0.57    0.88    0.99  7818 1.00
## q[18]     0.05    0.00 0.07    0.00    0.03    0.26  5190 1.00
## q[19]     0.27    0.00 0.14    0.05    0.25    0.59  8981 1.00
## q[20]     0.27    0.00 0.14    0.05    0.25    0.59  7947 1.00
## q[21]     0.15    0.00 0.11    0.01    0.13    0.43  7688 1.00
## q[22]     0.16    0.00 0.12    0.01    0.13    0.45  7033 1.00
## q[23]     0.62    0.00 0.15    0.30    0.62    0.89  8919 1.00
## q[24]     0.85    0.00 0.11    0.57    0.88    0.99  6862 1.00
## q[25]     0.95    0.00 0.06    0.77    0.97    1.00  6761 1.00
## q[26]     0.38    0.00 0.16    0.12    0.37    0.70  7676 1.00
## q[27]     0.95    0.00 0.07    0.76    0.97    1.00  5366 1.00
## q[28]     0.50    0.00 0.16    0.19    0.50    0.81  8178 1.00
## q[29]     0.73    0.00 0.14    0.42    0.75    0.95 10547 1.00
## q[30]     0.05    0.00 0.07    0.00    0.03    0.27  5654 1.00
## q[31]     0.95    0.00 0.07    0.76    0.97    1.00  4828 1.00
## q[32]     0.85    0.00 0.11    0.58    0.87    0.99  7616 1.00
## q[33]     0.95    0.00 0.07    0.75    0.97    1.00  5198 1.00
## q[34]     0.05    0.00 0.07    0.00    0.03    0.27  5287 1.00
## q[35]     0.26    0.00 0.14    0.05    0.25    0.59  8264 1.00
## q[36]     0.73    0.00 0.14    0.41    0.75    0.95  8812 1.00
## q[37]     0.85    0.00 0.11    0.55    0.87    0.99  8265 1.00
## q[38]     0.95    0.00 0.06    0.76    0.97    1.00  5924 1.00
## q[39]     0.27    0.00 0.14    0.05    0.25    0.58  8563 1.00
## q[40]     0.15    0.00 0.11    0.01    0.13    0.43  7726 1.00
## q[41]     0.15    0.00 0.12    0.01    0.13    0.45  7694 1.00
## q[42]     0.05    0.00 0.07    0.00    0.03    0.26  5199 1.00
## q[43]     0.05    0.00 0.07    0.00    0.03    0.26  5115 1.00
## q[44]     0.85    0.00 0.12    0.55    0.87    0.99  8205 1.00
## q[45]     0.95    0.00 0.07    0.77    0.97    1.00  5734 1.00
## q[46]     0.62    0.00 0.16    0.29    0.63    0.89 10277 1.00
## q[47]     0.85    0.00 0.12    0.56    0.88    0.99  6867 1.00
## q[48]     0.27    0.00 0.14    0.05    0.25    0.57  7495 1.00
## q[49]     0.95    0.00 0.07    0.76    0.97    1.00  5214 1.00
## q[50]     0.15    0.00 0.12    0.01    0.13    0.46  8770 1.00
## q[51]     0.62    0.00 0.16    0.29    0.63    0.89  8545 1.00
## q[52]     0.85    0.00 0.12    0.55    0.88    0.99  7223 1.00
## q[53]     0.38    0.00 0.16    0.11    0.37    0.71  8739 1.00
## q[54]     0.95    0.00 0.06    0.77    0.97    1.00  5203 1.00
## q[55]     0.05    0.00 0.07    0.00    0.03    0.25  4180 1.00
## q[56]     0.95    0.00 0.07    0.75    0.97    1.00  6260 1.00
## q[57]     0.73    0.00 0.14    0.43    0.75    0.95  6815 1.00
## q[58]     0.38    0.00 0.16    0.11    0.38    0.71  9805 1.00
## q[59]     0.27    0.00 0.14    0.05    0.25    0.59  7759 1.00
## q[60]     0.05    0.00 0.07    0.00    0.03    0.25  5807 1.00
## q[61]     0.05    0.00 0.07    0.00    0.03    0.25  5887 1.00
## q[62]     0.15    0.00 0.11    0.01    0.13    0.44  6361 1.00
## q[63]     0.27    0.00 0.14    0.05    0.25    0.59  7981 1.00
## q[64]     0.95    0.00 0.07    0.75    0.97    1.00  5420 1.00
## q[65]     0.85    0.00 0.11    0.57    0.87    0.99  6881 1.00
## q[66]     0.85    0.00 0.11    0.56    0.87    0.99  6915 1.00
## q[67]     0.85    0.00 0.12    0.55    0.87    0.99  6607 1.00
## q[68]     0.95    0.00 0.07    0.76    0.97    1.00  6183 1.00
## q[69]     0.05    0.00 0.07    0.00    0.03    0.24  5324 1.00
## q[70]     0.05    0.00 0.07    0.00    0.03    0.26  5354 1.00
## q[71]     0.05    0.00 0.07    0.00    0.03    0.25  5169 1.00
## q[72]     0.62    0.00 0.16    0.29    0.63    0.89  9636 1.00
## q[73]     0.16    0.00 0.12    0.01    0.13    0.46  7601 1.00
## q[74]     0.05    0.00 0.07    0.00    0.03    0.26  4582 1.00
## q[75]     0.05    0.00 0.07    0.00    0.03    0.24  5030 1.00
## q[76]     0.05    0.00 0.07    0.00    0.03    0.24  5817 1.00
## q[77]     0.95    0.00 0.06    0.78    0.97    1.00  5290 1.00
## q[78]     0.15    0.00 0.11    0.01    0.13    0.44  7543 1.00
## q[79]     0.95    0.00 0.07    0.75    0.97    1.00  5212 1.00
## q[80]     0.50    0.00 0.16    0.19    0.50    0.81  9277 1.00
## q[81]     0.85    0.00 0.11    0.58    0.87    0.99  8265 1.00
## q[82]     0.27    0.00 0.15    0.05    0.25    0.60  7845 1.00
## q[83]     0.15    0.00 0.11    0.01    0.13    0.43  7061 1.00
## q[84]     0.50    0.00 0.16    0.20    0.50    0.81  9433 1.00
## q[85]     0.85    0.00 0.11    0.58    0.88    0.99  7561 1.00
## q[86]     0.05    0.00 0.07    0.00    0.03    0.24  5433 1.00
## q[87]     0.95    0.00 0.07    0.75    0.97    1.00  4845 1.00
## q[88]     0.15    0.00 0.11    0.01    0.13    0.45  6701 1.00
## q[89]     0.95    0.00 0.07    0.76    0.97    1.00  5043 1.00
## q[90]     0.85    0.00 0.11    0.56    0.88    0.99  8023 1.00
## q[91]     0.73    0.00 0.15    0.41    0.75    0.95 10062 1.00
## q[92]     0.27    0.00 0.14    0.05    0.25    0.60  9007 1.00
## q[93]     0.95    0.00 0.07    0.74    0.97    1.00  5517 1.00
## q[94]     0.74    0.00 0.14    0.44    0.76    0.95  8029 1.00
## q[95]     0.74    0.00 0.14    0.42    0.76    0.95  8715 1.00
## q[96]     0.15    0.00 0.11    0.01    0.13    0.44  8646 1.00
## q[97]     0.05    0.00 0.07    0.00    0.03    0.25  5789 1.00
## q[98]     0.51    0.00 0.16    0.20    0.51    0.81  8050 1.00
## q[99]     0.85    0.00 0.12    0.55    0.88    0.99  7505 1.00
## q[100]    0.06    0.00 0.07    0.00    0.03    0.27  5580 1.00
## lp__   -443.84    0.30 9.31 -462.67 -443.38 -426.83   951 1.00
## 
## Samples were drawn using NUTS(diag_e) at Sun Dec 31 12:21:46 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

10.3.2 階層ベイズモデルの事後分布推定と予測

得られたMCMCサンプルを利用して、推定された事後分布を組み合わせて、生存種子ごとの個体数の分布を予測する。つまり、生存種子数\(y\)個の確率分布\(p(y|...)\)の計算を行う。

生存種子数\(y\)の確率分布は、二項分布\(p(y|\beta,r)\)と正規分布\(p(r|s)\)の無限混合分布であり、下記の式で表現できると仮定する。\(r\)についての積分は、事後分布\(p(r|s)\)に従うような個体を無限個集めてきて、その平均を計算している、という意味合い。

\[ p(y|\beta,s) = \int p(y|\beta,r)p(r|s) dr \]

生存種子数\(y\)の確率分布を決めるパラメタ\(\beta,s\)は、\(\beta,s\)ペアのすべてのMCMCサンプルごとに\(p(\beta,s)\)を評価し、\(y\)ごとにパーセンタイル点を示す。以下のコードは久保先生のサイトの図 10.4: fig10_04.Rを参考にしている。

n <- nrow(d)
size <- 8 

d_est <- rstan::extract(fit, permuted = TRUE)
beta_vals <- d_est$beta
sd_vals <- d_est$s

logistic <- function(z){1 / (1 + exp(-z))}

gaussian_binom <- function(r, y, size, beta, sd) {
  dbinom(y, size, logistic(beta + r)) * dnorm(r, 0, sd)
}

# integrate_gaussian_binom <- function(y, size, beta, sd) {
#   sapply(y, function(y) {
#     integrate(
#       f = gaussian_binom,
#       lower = -sd * 10,
#       upper = sd * 10,
#       y = y,
#       size = size,
#       beta = beta,
#       sd = sd
#     )$value
#   })
# }
# 
# 
# survive_prob <- sapply(
#   X = 1:nrow(beta_vals), 
#   FUN = function(i) {
#     integrate_gaussian_binom(
#       y = 0:size, 
#       size = size,
#       beta = beta_vals[i],
#       sd = sd_vals[i]
#     )
#   }
# )
# 積分を計算する関数
integrate_gaussian_binom <- function(y, size, beta, sd) {
  map_dbl(
    # yの各値(0:8)を渡して積分する
    .x = y, 
    .f = function(x) {
      integrate(
        f = gaussian_binom,
        # 下記はrの積分範囲
        lower = -sd * 10,
        upper = sd * 10,
        # 下記はgaussian_binom()への引数
        y = x,
        size = size,
        beta = beta,
        sd = sd
      )$value
    })
}

survive_prob <- map2_dfc(
  .x = beta_vals,
  .y = sd_vals,
  .f = function(x, y) {
    integrate_gaussian_binom(
      y = 0:size,
      size = size,
      beta = x,
      sd = y
    )}
)

sample <- apply(
  X = survive_prob, 
  MARGIN = 2, 
  FUN = function(prob) {
    summary(factor(
      sample(0:size, n, replace = TRUE, prob = prob), levels = 0:size
            )
    )
})

median <- tibble(survive = 0:size, number = apply(X = survive_prob * n, MARGIN = 1, FUN = median))
obs <- tibble(survive = 0:size, number = summary(as.factor(d$y)))
pred_val <- apply(sample, 1, quantile, probs = c(0.025, 0.975))
pred_interval <- tibble(survive = c(0:size, size:0), number = c(pred_val[1, ], rev(pred_val[2, ])))

ggplot() +
  geom_line(data = median, aes(x = survive, y = number)) +
  geom_point(data = median, aes(x = survive, y = number), shape = 1, size = 2) +
  geom_point(data = obs,    aes(x = survive, y = number), size = 2) +
  geom_polygon(data = pred_interval, aes(x = survive, y = number), alpha = 0.2) +
  labs(x = '生存種子数y', y = '個体数', title = '生存種子数yの予測分布p(y|β,r)') +
  theme_bw(base_family = 'HiraKakuPro-W3')

作図の補足

作図の部分について補足を書き足しておく。ここではデータサイズを100から10に変更し、MCMCサンプルも4000から3個に減らすことで、計算の過程を追いやすくしておく。こちらのスライドが詳しいので非常に参考になる。

n <- 10
size <- 8 

# パラメタの事後分布を取り出す
d_est <- rstan::extract(fit, permuted = TRUE)
beta_vals <- d_est$beta[1:3]
sd_vals <- d_est$s[1:3]

被積分関数を定義する。被積分関数は二項分布と個体差を表す正規分布の無限混合分布。

\[ p(y|\beta,s) = \int \overbrace{p(y|\beta,r)}^{二項分布} \overbrace{p(r|s)}^{個体差の事後分布} dr \]

# ロジスティック関数
logistic <- function(z){1 / (1 + exp(-z))}

# 被積分関数は二項分布と正規分布の無限混合分布
gaussian_binom <- function(r, y, size, beta, sd) {
  dbinom(x = y, size = size, prob = logistic(beta + r)) * 
    dnorm(x = r, mean = 0, sd = sd)
}

さきほど定義したgaussian_binom()関数を使って、\(y\)の値ごとに、MCMCサンプリングされたパラメタの事後分布の値を使って、integrate_gaussian_binom()関数で積分する。

# 積分を計算する関数
integrate_gaussian_binom <- function(y, size, beta, sd) {
  map_dbl(
    # yの各値(0:size)を渡して積分する
    .x = y, 
    .f = function(x) {
      integrate(
        f = gaussian_binom,
        # 下記はrの積分範囲
        lower = -sd * 10,
        upper = sd * 10,
        # 下記はgaussian_binom()への引数
        y = x,
        size = size,
        beta = beta,
        sd = sd
      )$value
    })
}

本来、MCMCサンプルは4000個あるので4000列できるが、ここでは3つに絞っているの3列しかない。

# MCMCサンプリングされたパラメタの事後分布で積分
survive_prob <- map2_dfc(
  .x = beta_vals,
  .y = sd_vals,
  .f = function(x, y) {
    integrate_gaussian_binom(
      y = 0:size,
      size = size,
      beta = x,
      sd = y
    )}
)
# apply(survive_prob, 2, sum)
# [1] 1 1 1
survive_prob
## # A tibble: 9 × 3
##     ...1   ...2   ...3
##    <dbl>  <dbl>  <dbl>
## 1 0.183  0.128  0.219 
## 2 0.0945 0.0939 0.0939
## 3 0.0711 0.0798 0.0669
## 4 0.0629 0.0749 0.0576
## 5 0.0616 0.0758 0.0555
## 6 0.0662 0.0824 0.0592
## 7 0.0796 0.0979 0.0711
## 8 0.115  0.133  0.105 
## 9 0.266  0.234  0.272

計算している内容としては、まず\(\beta,s\)の事後分布から値を1つずつ取り出して、

beta_vals[1];sd_vals[1]
## [1] 0.4666783
## [1] 3.112843

定義済みの無限混合分布であるgaussian_binom\(r\)で積分することで。\(y\)\([0,8]\)に対する確率が返される。

integrate_gaussian_binom(
  y = 0:size, 
  size = size, 
  beta = beta_vals[1], 
  sd = sd_vals[1]
  )
## [1] 0.18283746 0.09445306 0.07109860 0.06286374 0.06163024 0.06624750 0.07959305
## [8] 0.11514952 0.26612682

積分して得た確率分布survive_probを使って生存種子数をサンプリングする。1列ごとにパラメタの事後分布がセットになっている。再三にはなるが、本来、MCMCサンプルは4000個あるので4000列できるが、ここでは3つに絞っているの3列しかない。

sample <- apply(
  X = survive_prob, 
  MARGIN = 2, 
  FUN = function(prob) {
    summary(factor(sample(0:size, n, replace = TRUE, prob = prob), levels = 0:size))
})
sample
##   ...1 ...2 ...3
## 0    1    3    2
## 1    1    0    0
## 2    0    1    0
## 3    0    0    1
## 4    0    0    0
## 5    2    0    3
## 6    0    1    1
## 7    2    0    0
## 8    4    5    3

あとは可視化に必要なデータを計算していく。obsは観測データを生存種子ごとにカウントしたもので、medianは各事後分布のパラメタの組から生成した無限混合分布の生存種子\(y\)に対する確率を、生存種子ごとにパラメタの組から生成した確率を横断して、中央値を計算している。また、カウント数に対応させるよに\(n\)倍している。

obs <- tibble(survive = 0:size, number = summary(as.factor(d$y)))
median <- tibble(survive = 0:size, number = apply(X = survive_prob * n, MARGIN = 1, FUN = median))
list(
  obs = obs,
  median = median
)
## $obs
## # A tibble: 9 × 2
##   survive number
##     <int>  <int>
## 1       0     19
## 2       1     15
## 3       2     10
## 4       3      3
## 5       4      6
## 6       5      4
## 7       6      6
## 8       7     17
## 9       8     20
## 
## $median
## # A tibble: 9 × 2
##   survive number
##     <int>  <dbl>
## 1       0  1.83 
## 2       1  0.939
## 3       2  0.711
## 4       3  0.629
## 5       4  0.616
## 6       5  0.662
## 7       6  0.796
## 8       7  1.15 
## 9       8  2.66

予測区間を計算するために、各パラメタの組から生成された生存確率を使ってサンプリングされた生存種子の分布を横断して、生存種子ごとにパーセンタイルを計算している。

pred_val <- apply(X = sample, MARGIN = 1, FUN = quantile, probs = c(0.025, 0.975))
pred_interval <- tibble(
  survive = c(0:size, size:0), 
  number = c(pred_val[1, ], rev(pred_val[2, ])))

あとはこれを使って可視化すれば、さきほどの図が出来上がる。