UPDATE: 2023-12-18 22:36:54.800651

はじめに

このノートは「StanとRでベイズ統計モデリング」の内容を写経することで、ベイズ統計への理解を深めていくために作成している。

基本的には気になった部分を写経しながら、ところどころ自分用の補足をメモすることで、「StanとRでベイズ統計モデリング」を読み進めるための自分用の補足資料になることを目指す。私の解釈がおかしく、メモが誤っている場合があるので注意。

今回は第4章「StanとRStanをはじめよう」のチャプターを写経していく。

2.5 ベイズ信頼区間とベイズ予測区間

コードが出てくるのは第4章からではあるが、第2章の予測分布で気になる部分があったのでメモを追加しておく。予測分布は下記の通り、確率モデルと事後分布を積分することで計算される。

事後予測分布

\[ \begin{eqnarray} p_{pred}(y|Y) = \int p(y|\theta) p(\theta|Y) d\theta \\ \end{eqnarray} \]

下記の書籍に書かれている通り、予測分布を得る場合は数千個得られるパラメタのMCMCサンプルを変えながら、確率モデルからの乱数を得ることになる。

より簡単に、事後分布\(p(\theta|Y)\)からのMCMCサンプルから値を1つ選び、それを\(\theta^{+}\)として確率モデル\(p(y|\theta^{+})\)に従う乱数\(y^{+}\)を生成することを繰り返し、\(y^{+}\)をたくさん生成することで予測分布\(p_{pred}(y|Y)\)からのMCMCサンプルとみなすことができる。

どちらが良いのかは分からないが、他の書籍やブログなどには、MCMCサンプルの値を変えず、パラメタの事後分布の点推定量を計算し、それを利用して確率モデルから乱数を得る「条件付き予測分布」という方法が記載されている場合もある。このあたりの理解が最初は追いつかず、Stanでgenerated quantitiesの内容と頭の理解が一致せず、困った経験が過去にあったので、再勉強するにあたりメモを残しておく。基本的には事後予測分布を使用し、条件付き予測分布は使わなければ問題はなさそう。

4.2.2 文法の基礎

書籍に書かれている下記のモデル式4-1への理解を深めておく。

モデル4-1

\[ \begin{eqnarray} Y[n] &\sim& Normal(\mu, 1) \tag{4.1}\\ \mu &\sim& Normal(0, 100) \tag{4.2} \end{eqnarray} \]

これは書籍にも書かれているが、「データ1つごとに平均\(\mu\)、標準偏差1の正規分布から独立に確率的に生成された」ことを意味する。書き下すと下記の通りで、\(\mu\)は何らかのスカラであり、その\(\mu\)をもつ正規分布から生成されていることを意味する。スカラを強調しているのは、階層ベイズモデルなどを筆頭に複雑なモデルを扱う時、グループごとに\(\mu\)が変化することもあるためである。

Y[1] ~ Normal(mu, 1) かつ
Y[2] ~ Normal(mu, 1) かつ
Y[3] ~ Normal(mu, 1) かつ
...
Y[n] ~ Normal(mu, 1)

これを略した記述がY[n] ~ normal(mu, 1)

Stanの文法で記述すると、下記に対応する。

model{
  for (n in 1:N){
    Y[n] ~ normal(mu, 1);
  }
  mu ~ normal(0, 100) # 事前分布
}

モデルが複雑になってくると、どういうこと?これはどう記述すればよいのか?など、疑問は湧き出てくるので、しっかり基礎の基礎を理解しておきたい・・・(戒め)。

4.4.3 モデル式の記述

下記4つのモデルは別のモデルではなく等価なので解釈を誤らないように注意が必要。モデル4-5式はデータ1人ごとに平均\(a+bX[n]\)、標準偏差\(\sigma\)の正規分布から独立に生成されたことを意味している。

どのモデルも平均値が\(X\)の値によって変わる条件つき正規分布で、すべての\(X\)において等分散を仮定しているモデルである。最終的にパラメタ\(a,b\)の事後分布を得たいものの、各モデル式をそのままStanのmodelブロックで定義すると異なる表記になるので注意が必要。

モデル4-2

\[ \begin{eqnarray} Y[n] &=& y_{base}[n] + \epsilon[n] \\ y_{base}[n] &=& a + b X[n]\\ \epsilon[n] &\sim& Normal(0, \sigma) \tag{4.4} \end{eqnarray} \]

モデル4-3

\[ \begin{eqnarray} Y[n] &=& a + b X[n] + \epsilon[n] \\ \epsilon[n] &\sim& Normal(0, \sigma) \end{eqnarray} \]

モデル4-4

\[ \begin{eqnarray} y_{base}[n] &=& a + b X[n] \\ Y[n] &\sim& Normal(y_{base}[n], \sigma) \end{eqnarray} \]

モデル4-5

\[ \begin{eqnarray} Y[n] &\sim& Normal(a + b X[n], \sigma) \end{eqnarray} \]

4.4.5 Stanで実装

書籍内で使用されているデータは下記の通り。

library(ggplot2)
library(rstan)

d <- read.csv('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap04/input/data-salary.txt')
data <- list(N = nrow(d), X = d$X, Y = d$Y)

head(d)
##    X   Y
## 1 24 472
## 2 24 403
## 3 26 454
## 4 32 575
## 5 33 546
## 6 35 781

まずはStanコードを作成する。別ファイルに記載しているコードを下記に転記している。

data {
  int N;
  real X[N];
  real Y[N];
}
parameters {
  real a;
  real b;
  real<lower=0> sigma;
}
model {
  for (n in 1:N){
    Y[n] ~ normal(a + b*X[n], sigma);
  }
}

modelブロックで記載しているモデルは、さきほどからしつこく書いている通り「データ1人ごとに平均\(a+bX[n]\)、標準偏差\(\sigma\)の正規分布から独立に生成された」ことを意味しているモデルを想定している。つまり、書き下すと下記のようになる。

Y[1] ~ normal(a + b*X[1], sigma);
Y[2] ~ normal(a + b*X[2], sigma);
...
Y[N] ~ normal(a + b*X[N], sigma);

ここでは、stan_model()関数で最初にコンパイルしておいてから、

model443 <- stan_model('note_ahirubayes01.stan')

sampling()関数でサンプリングする。

fit <- sampling(object = model443, data = data, seed = 1234)

推定結果は下記の通り。指定しない限りparametersブロック、transformed parametersブロック、generated quantitiesブロックで指定されたパラメタのMCMCサンプルをサンプリングする。つまり、どのようなパラメタがサンプリングされるかは、モデル式をどのように記述するのかに依存する。

fit
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##          mean se_mean    sd    2.5%     25%     50%    75%  97.5% n_eff Rhat
## a     -119.18    2.26 75.35 -267.06 -167.20 -118.41 -71.11  31.43  1110 1.00
## b       21.89    0.05  1.68   18.61   20.80   21.88  22.99  25.27  1155 1.00
## sigma   85.37    0.42 15.98   60.54   74.41   83.32  93.59 125.17  1418 1.01
## lp__   -93.68    0.04  1.35  -97.19  -94.27  -93.34 -92.70 -92.14  1074 1.00
## 
## Samples were drawn using NUTS(diag_e) at Mon Dec 18 22:37:18 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

分析目的に適う作図や分析結果を得るためには想定しているモデル構造、それをもとに記述されるモデル式への理解を深める必要がある。

4.4.11 ベイズ信頼区間とベイズ予測区間の算出

ベイズ信頼区間、予測区間を算出するためにすることは、まず各パラメタのMCMCサンプルをextract()関数で取り出すこと。MCMCサンプルの長さはchains * (iter - warmup)/thinという関係にあるので、今回は4 * (2000 - 1000)/1 = 4000となる。つまり、先程表示されていたパラメタのサマリーは、4000個のMCMCサンプルから計算されていることになる。

ms <- rstan::extract(fit, permuted = TRUE)
str(ms)
## List of 4
##  $ a    : num [1:4000(1d)] -126.1 -12.8 -229.8 -21.3 -96.5 ...
##   ..- attr(*, "dimnames")=List of 1
##   .. ..$ iterations: NULL
##  $ b    : num [1:4000(1d)] 22.1 19.3 24.6 20 21.4 ...
##   ..- attr(*, "dimnames")=List of 1
##   .. ..$ iterations: NULL
##  $ sigma: num [1:4000(1d)] 89.8 87.6 89.1 94.4 70.8 ...
##   ..- attr(*, "dimnames")=List of 1
##   .. ..$ iterations: NULL
##  $ lp__ : num [1:4000(1d)] -92.4 -93.6 -93.6 -93.6 -92.2 ...
##   ..- attr(*, "dimnames")=List of 1
##   .. ..$ iterations: NULL

下記は、各パラメタの1-10行目を取り出しているが、

data.frame(
  set = paste0('set', 1:10),
  a = ms$a[1:10],
  b = ms$b[1:10],
  sigma = ms$sigma[1:10]
)
##      set          a        b     sigma
## 1   set1 -126.09782 22.09596  89.79984
## 2   set2  -12.79003 19.29683  87.55702
## 3   set3 -229.82267 24.56804  89.05753
## 4   set4  -21.32934 20.00322  94.39494
## 5   set5  -96.54672 21.42506  70.78023
## 6   set6 -145.84620 22.16302  67.00061
## 7   set7  -81.37651 21.17869  73.67658
## 8   set8 -270.45381 25.08404  90.17393
## 9   set9  -73.37723 20.10361 122.57991
## 10 set10  -61.34379 20.86757  75.74603

これは\(p(a,b,\sigma|X,Y)\)という同時分布からサンプリングした1つの組み(=1行)であり、この組みが4000組得られている。この4000組を使って信用区間、予測区間を算出する。条件付き予測分布の例があったので、少し混乱するが、書籍では23-60歳の範囲での予測区間を算出する前に、丁寧に1時点での解説がされているので、大変イメージしやすい。下記は50歳時点での予測分布を構築するコードである。

ms <- rstan::extract(fit) 
N_mcmc <- length(ms$lp__)
y50_base <- ms$a + ms$b * 50
y50 <- rnorm(n = N_mcmc, mean = y50_base, sd = ms$sigma)
d_mcmc <- data.frame(a = ms$a, b = ms$b, sigma = ms$sigma, y50_base, y50)

ms$a + ms$b * 50の部分がまさに下記を表している。

より簡単に、事後分布\(p(\theta|Y)\)からのMCMCサンプルから値を1つ選び、それを\(\theta^{+}\)として確率モデル\(p(y|\theta^{+})\)に従う乱数\(y^{+}\)を生成することを繰り返し、\(y^{+}\)をたくさん生成することで予測分布\(p_{pred}(y|Y)\)からのMCMCサンプルとみなすことができる。

ms$ams$bも4000個のベクトルであって、ベクトルの計算に年齢50を乗じることで、4000個の平均ベクトルを得ている。1番目のパラメタの1組を取り出して平均を計算し、2番目のパラメタの1組を取り出して平均を計算し、これを4000回繰り返す。そして、この平均を使って正規分布からの乱数rnorm(n = N_mcmc, mean = y50_base, sd = ms$sigma)を得ている。

rnorm()関数の平均にベクトルを渡すというイメージは分かりにくいかもしれないが、変化する平均をもとに、乱数を生成する。

set.seed(1234)
rnorm(5, c(1,10,100,1000,10000), 1)
## [1]    -0.2070657    10.2774292   101.0844412   997.6543023 10000.4291247

これが50歳時点での予測分布となるので、あとはこれを23-60歳の範囲で同じように繰り返すことで予測区間が算出される。

hist(d_mcmc$y50, breaks = 50, main = 'Posterior Predicted Distribution of Income at Age 50')

実際はStan側からデータを受け取って予測区間となるMCMCサンプルを計算するではなく、Stan側で計算させることができる。先程説明した、計算したいパラメタとStanのモデル式が対応しているとはこのことである。事後予測分布については、ノート末尾でも補足する。

まずはStanコードを作成する。別ファイルに記載しているコードを下記に転記している。

data {
  int N;
  real X[N];
  real Y[N];
  int N_new;
  real X_new[N_new];
}

parameters {
  real a;
  real b;
  real<lower=0> sigma;
}

transformed parameters {
  real y_base[N];
  for (n in 1:N)
    y_base[n] = a + b*X[n];
}

model {
  for (n in 1:N)
    Y[n] ~ normal(y_base[n], sigma);
}

generated quantities {
  real y_base_new[N_new];
  real y_new[N_new];
  for (n in 1:N_new) {
    y_base_new[n] = a + b*X_new[n];
    y_new[n] = normal_rng(y_base_new[n], sigma);
  }
}

generated quantitiesブロックを追加して、予測分布を算出するための予測値のMCMCサンプルを計算させている。あとはデータを用意して、

X_new <- 23:60
data <- list(N = nrow(d), X = d$X, Y = d$Y, N_new = length(X_new), X_new = X_new)

モデルをコンパイルし、

model4412 <- stan_model('note_ahirubayes01-2.stan')

サンプリングする。

fit <- sampling(object = model4412, data = data, seed = 1234)

サンプリング結果を確認すると、parametersブロック、transformed parametersブロック、generated quantitiesブロックで指定されたパラメタの事後分布が得られている。

print(fit, probs = c(0.025, 0.5, 0.975))
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##                   mean se_mean    sd    2.5%     50%   97.5% n_eff Rhat
## a              -118.45    2.29 74.36 -265.51 -118.82   32.61  1054    1
## b                21.87    0.05  1.65   18.51   21.87   25.05  1060    1
## sigma            84.65    0.38 15.35   60.18   82.63  120.37  1605    1
## y_base[1]       406.43    1.09 37.50  332.27  406.18  481.53  1184    1
## y_base[2]       406.43    1.09 37.50  332.27  406.18  481.53  1184    1
## y_base[3]       450.16    0.99 34.70  381.40  449.98  520.46  1224    1
## y_base[4]       581.38    0.70 27.00  527.67  580.81  635.08  1468    1
## y_base[5]       603.25    0.66 25.87  551.42  602.62  654.21  1542    1
## y_base[6]       646.99    0.57 23.77  600.19  646.62  694.27  1748    1
## y_base[7]       712.60    0.45 21.22  670.25  712.56  754.61  2242    1
## y_base[8]       756.34    0.38 20.02  716.44  756.39  795.93  2815    1
## y_base[9]       756.34    0.38 20.02  716.44  756.39  795.93  2815    1
## y_base[10]      821.95    0.29 19.17  784.86  822.32  859.61  4273    1
## y_base[11]      821.95    0.29 19.17  784.86  822.32  859.61  4273    1
## y_base[12]      843.82    0.29 19.16  806.79  843.84  881.44  4399    1
## y_base[13]      931.30    0.32 20.52  890.96  931.38  972.68  4150    1
## y_base[14]     1018.78    0.46 23.70  972.24 1018.73 1066.09  2709    1
## y_base[15]     1106.26    0.63 28.10 1051.87 1106.30 1162.11  2014    1
## y_base[16]     1106.26    0.63 28.10 1051.87 1106.30 1162.11  2014    1
## y_base[17]     1128.13    0.67 29.33 1070.69 1127.99 1186.57  1906    1
## y_base[18]     1149.99    0.72 30.60 1090.55 1149.87 1211.04  1815    1
## y_base[19]     1171.86    0.77 31.91 1109.85 1171.58 1235.93  1738    1
## y_base[20]     1171.86    0.77 31.91 1109.85 1171.58 1235.93  1738    1
## y_base_new[1]   384.56    1.14 38.93  307.33  384.32  462.65  1168    1
## y_base_new[2]   406.43    1.09 37.50  332.27  406.18  481.53  1184    1
## y_base_new[3]   428.30    1.04 36.09  356.87  428.15  500.90  1203    1
## y_base_new[4]   450.16    0.99 34.70  381.40  449.98  520.46  1224    1
## y_base_new[5]   472.03    0.94 33.34  405.60  471.87  539.82  1249    1
## y_base_new[6]   493.90    0.89 32.00  429.96  493.75  559.08  1279    1
## y_base_new[7]   515.77    0.85 30.69  454.70  515.54  578.18  1314    1
## y_base_new[8]   537.64    0.80 29.42  479.41  537.20  597.18  1356    1
## y_base_new[9]   559.51    0.75 28.19  503.82  558.99  615.93  1407    1
## y_base_new[10]  581.38    0.70 27.00  527.67  580.81  635.08  1468    1
## y_base_new[11]  603.25    0.66 25.87  551.42  602.62  654.21  1542    1
## y_base_new[12]  625.12    0.61 24.79  575.88  624.56  674.32  1634    1
## y_base_new[13]  646.99    0.57 23.77  600.19  646.62  694.27  1748    1
## y_base_new[14]  668.86    0.53 22.83  623.64  668.48  713.89  1887    1
## y_base_new[15]  690.73    0.49 21.98  646.50  690.57  734.50  2046    1
## y_base_new[16]  712.60    0.45 21.22  670.25  712.56  754.61  2242    1
## y_base_new[17]  734.47    0.41 20.56  693.14  734.38  774.77  2484    1
## y_base_new[18]  756.34    0.38 20.02  716.44  756.39  795.93  2815    1
## y_base_new[19]  778.21    0.34 19.60  740.02  778.52  816.72  3272    1
## y_base_new[20]  800.08    0.31 19.32  762.43  800.62  837.92  3806    1
## y_base_new[21]  821.95    0.29 19.17  784.86  822.32  859.61  4273    1
## y_base_new[22]  843.82    0.29 19.16  806.79  843.84  881.44  4399    1
## y_base_new[23]  865.69    0.29 19.30  828.15  865.81  904.16  4447    1
## y_base_new[24]  887.56    0.29 19.58  849.83  887.69  926.93  4415    1
## y_base_new[25]  909.43    0.30 19.98  870.80  909.66  949.34  4312    1
## y_base_new[26]  931.30    0.32 20.52  890.96  931.38  972.68  4150    1
## y_base_new[27]  953.17    0.35 21.17  911.16  953.38  995.46  3717    1
## y_base_new[28]  975.04    0.38 21.92  931.58  974.98 1019.16  3316    1
## y_base_new[29]  996.91    0.42 22.77  951.95  996.86 1042.68  2970    1
## y_base_new[30] 1018.78    0.46 23.70  972.24 1018.73 1066.09  2709    1
## y_base_new[31] 1040.65    0.50 24.71  992.46 1040.68 1089.88  2490    1
## y_base_new[32] 1062.52    0.54 25.79 1012.34 1062.62 1113.69  2301    1
## y_base_new[33] 1084.39    0.58 26.92 1032.26 1084.45 1138.06  2144    1
## y_base_new[34] 1106.26    0.63 28.10 1051.87 1106.30 1162.11  2014    1
## y_base_new[35] 1128.13    0.67 29.33 1070.69 1127.99 1186.57  1906    1
## y_base_new[36] 1149.99    0.72 30.60 1090.55 1149.87 1211.04  1815    1
## y_base_new[37] 1171.86    0.77 31.91 1109.85 1171.58 1235.93  1738    1
## y_base_new[38] 1193.73    0.81 33.24 1129.18 1193.44 1260.49  1672    1
## y_new[1]        383.23    1.72 96.02  196.11  384.61  573.26  3108    1
## y_new[2]        406.39    1.73 96.45  209.35  406.79  595.22  3123    1
## y_new[3]        428.54    1.75 92.02  249.03  429.50  610.62  2763    1
## y_new[4]        450.39    1.70 92.20  267.29  449.58  632.77  2954    1
## y_new[5]        472.25    1.62 90.76  290.45  472.86  649.38  3146    1
## y_new[6]        493.06    1.57 90.16  310.94  493.36  669.36  3280    1
## y_new[7]        514.11    1.62 90.98  337.08  515.52  696.24  3164    1
## y_new[8]        538.95    1.59 91.77  356.81  539.24  726.62  3338    1
## y_new[9]        561.11    1.47 90.88  386.75  559.14  739.11  3835    1
## y_new[10]       579.53    1.53 90.36  394.52  579.87  756.01  3482    1
## y_new[11]       605.20    1.55 88.31  428.94  605.44  781.14  3245    1
## y_new[12]       626.40    1.47 89.35  450.62  624.63  802.71  3674    1
## y_new[13]       645.41    1.42 89.85  464.76  645.63  817.62  3997    1
## y_new[14]       669.25    1.57 90.03  487.89  668.37  850.64  3272    1
## y_new[15]       689.81    1.59 89.70  510.17  690.11  866.96  3175    1
## y_new[16]       710.00    1.36 86.73  538.25  710.83  875.50  4039    1
## y_new[17]       735.53    1.41 90.55  555.77  736.74  908.69  4137    1
## y_new[18]       756.10    1.51 87.69  585.51  756.21  937.33  3375    1
## y_new[19]       778.56    1.46 87.42  603.69  777.35  947.16  3577    1
## y_new[20]       799.39    1.52 89.61  620.71  798.55  980.57  3472    1
## y_new[21]       820.37    1.33 86.72  648.73  820.34  990.79  4229    1
## y_new[22]       844.38    1.40 88.92  665.73  844.98 1019.82  4024    1
## y_new[23]       865.18    1.36 88.05  690.51  865.73 1042.00  4162    1
## y_new[24]       887.60    1.38 87.29  713.48  887.13 1063.13  4017    1
## y_new[25]       908.24    1.39 87.96  739.58  906.31 1088.30  3992    1
## y_new[26]       929.76    1.39 87.58  757.61  929.34 1101.97  3954    1
## y_new[27]       954.01    1.45 89.35  771.42  954.88 1129.81  3807    1
## y_new[28]       974.16    1.45 88.38  802.99  973.70 1149.67  3731    1
## y_new[29]       997.35    1.44 90.18  821.26  994.45 1175.92  3910    1
## y_new[30]      1017.07    1.38 87.34  838.10 1017.22 1188.33  4024    1
## y_new[31]      1041.04    1.55 87.55  864.90 1040.77 1215.09  3194    1
## y_new[32]      1063.09    1.46 89.44  883.66 1061.74 1241.29  3731    1
## y_new[33]      1083.72    1.58 91.51  892.11 1083.34 1265.46  3375    1
## y_new[34]      1106.51    1.51 89.52  925.88 1107.77 1278.90  3513    1
## y_new[35]      1128.90    1.45 90.51  945.31 1131.62 1301.03  3883    1
## y_new[36]      1149.61    1.50 89.46  974.13 1150.51 1325.19  3567    1
## y_new[37]      1170.58    1.59 91.77  987.67 1171.64 1347.54  3327    1
## y_new[38]      1195.27    1.47 91.73 1010.95 1196.52 1374.86  3917    1
## lp__            -93.65    0.03  1.26  -96.88  -93.34  -92.16  1452    1
## 
## Samples were drawn using NUTS(diag_e) at Mon Dec 18 22:37:38 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

ここから予測分布を作成していくが、まずはデータがどのように保存されているのか確認する。信用区間を算出したければms$y_newms$y_base_newに読み替えればOK。

ms <- rstan::extract(fit)
str(ms$y_new)
##  num [1:4000, 1:38] 461 576 366 599 372 ...
##  - attr(*, "dimnames")=List of 2
##   ..$ iterations: NULL
##   ..$           : NULL

予測区間を算出するために必要なのはms$y_newであるが、matrixクラスで保存されている。先程のa,bはベクトルだったが、今回はmatrixクラスで保存されている。これは23−60歳までのMCMCサンプルを計算しているためである。つまりmcmcサンプルの長さ × forloopのループ回数という形式でmatrixが保存されていることを意味する。

# mcmcサンプルの長さは20まで表示。全部の長さは4000
ms$y_new[1:20,]
##           
## iterations     [,1]     [,2]     [,3]     [,4]     [,5]     [,6]     [,7]
##       [1,] 460.8634 437.9357 518.3424 481.0554 477.2357 534.8487 390.8750
##       [2,] 575.9002 443.6138 462.6731 480.1184 385.1748 508.4724 693.3259
##       [3,] 366.1943 458.1170 385.2195 339.1653 341.4010 549.2674 658.2847
##       [4,] 599.4339 496.1026 464.5705 285.4239 515.7201 469.4772 473.4473
##       [5,] 372.0720 372.9459 428.0292 375.9517 544.5555 416.6049 417.3597
##       [6,] 403.7616 424.9994 434.3494 287.1179 481.5824 359.6440 450.4261
##       [7,] 356.3242 376.9793 359.0668 337.1274 486.7946 548.2675 325.5847
##       [8,] 146.6523 521.4328 358.1214 267.9913 256.7348 459.1277 477.2827
##       [9,] 345.8380 466.1090 530.5797 330.2329 466.2934 421.6227 344.9762
##      [10,] 325.6443 621.3337 389.4070 475.1688 325.2862 627.9302 695.1057
##      [11,] 517.3295 540.6945 559.2273 449.3653 327.9051 461.0802 417.7049
##      [12,] 290.1946 386.1028 317.2166 452.6001 460.1248 486.7724 404.2192
##      [13,] 258.0210 342.0837 249.0349 496.2835 336.8860 478.1819 613.6761
##      [14,] 334.7587 385.1673 411.1875 458.6136 337.4290 545.6124 591.9120
##      [15,] 389.0858 535.0654 486.2826 389.5704 382.5474 474.3361 659.8875
##      [16,] 636.7629 530.6810 464.4844 553.0752 658.3440 523.2618 694.3509
##      [17,] 375.9961 335.5095 396.1599 407.3054 370.2821 468.1920 657.3317
##      [18,] 422.5172 571.5171 465.8221 558.1021 582.3216 362.8164 581.5894
##      [19,] 324.4703 318.0017 421.6565 413.1590 393.6573 412.7902 542.7472
##      [20,] 311.3647 428.5097 535.2399 576.4163 522.2956 528.2646 501.4678
##           
## iterations     [,8]     [,9]    [,10]    [,11]    [,12]    [,13]    [,14]
##       [1,] 594.9074 455.1513 618.9944 656.6739 447.4682 718.7593 654.7729
##       [2,] 589.8882 558.8310 611.4708 628.5700 679.5158 774.6877 593.4986
##       [3,] 443.8143 559.2858 337.3706 332.3432 576.0598 666.7836 735.1523
##       [4,] 627.8903 528.3965 336.4917 454.7115 433.9457 548.0407 582.7085
##       [5,] 612.9106 516.9425 581.3942 626.8271 602.8598 651.1933 573.4231
##       [6,] 439.9647 698.7475 571.1804 643.0411 730.6044 770.0660 618.7351
##       [7,] 538.7477 589.4282 667.7911 505.3941 707.2915 591.5460 673.9480
##       [8,] 472.5433 377.1111 529.3654 513.9780 575.2311 522.0048 539.9877
##       [9,] 472.4539 488.6546 571.1610 522.2566 874.2138 538.8533 686.0758
##      [10,] 654.8062 553.7593 602.7138 484.8147 777.9759 650.8207 679.4669
##      [11,] 515.9074 404.5850 690.9860 529.4769 634.4511 802.4316 651.0564
##      [12,] 470.7328 462.4233 507.6061 681.5549 581.1888 757.6221 698.5235
##      [13,] 569.3603 547.7978 560.4178 623.9338 648.6549 590.9561 719.5660
##      [14,] 563.4991 507.2462 615.6941 704.8850 568.5192 740.8673 670.0796
##      [15,] 446.6293 482.9956 490.4006 671.5579 580.4193 599.2723 617.1431
##      [16,] 593.7282 598.4804 590.4795 805.9819 532.4109 625.1921 702.4869
##      [17,] 456.2713 460.7779 506.5233 507.5420 743.4607 552.3405 632.5317
##      [18,] 514.8807 464.3147 643.7972 633.8365 654.2452 676.7209 644.1947
##      [19,] 536.2386 475.7015 533.0859 709.4285 559.1847 633.5689 689.1405
##      [20,] 357.6691 480.5350 641.7019 671.7305 759.2330 544.3446 695.6023
##           
## iterations    [,15]    [,16]    [,17]    [,18]    [,19]     [,20]     [,21]
##       [1,] 634.6249 648.5844 687.0959 805.2507 642.1236  903.6075  807.4724
##       [2,] 669.6236 810.6379 827.8528 664.9283 879.0198  765.4336  913.0720
##       [3,] 503.9478 452.6813 634.8068 757.8929 702.5623  715.3880  895.2918
##       [4,] 791.9150 790.7317 718.9376 665.9947 717.8696  801.4429  661.9766
##       [5,] 628.1306 761.4416 672.5884 813.2941 774.6610  726.3557  761.7498
##       [6,] 794.2247 714.5360 691.8072 812.8053 862.2054  874.3026  759.0869
##       [7,] 708.0689 714.9211 746.5398 746.4316 806.8987  937.6288  805.2882
##       [8,] 701.9692 655.7353 743.8688 702.4851 905.9119  811.4945  790.9721
##       [9,] 612.5157 702.8137 580.3694 833.6936 870.1549  820.1008  798.0600
##      [10,] 851.5035 678.2560 694.0315 775.3822 713.3565 1035.4219 1102.2633
##      [11,] 786.4316 923.4823 862.1359 839.6637 607.0796  930.9741  783.0660
##      [12,] 653.4352 732.6550 794.0398 740.6847 813.2789  687.9116  701.8360
##      [13,] 690.5611 739.8178 786.9390 787.5250 808.5921  859.1639  824.1926
##      [14,] 787.1457 719.5702 645.8272 732.9679 750.9665  690.1316  928.4002
##      [15,] 658.9345 781.7945 811.0267 677.0498 809.9895  842.0706  713.2065
##      [16,] 648.7334 773.8109 649.8857 821.6310 841.6855  876.3340  766.0753
##      [17,] 649.9147 827.7006 845.1821 696.5604 797.2429  896.2624  830.7148
##      [18,] 792.3552 750.7065 713.4160 702.4390 814.9879  738.1243  764.5461
##      [19,] 877.1512 708.7832 794.2238 767.9896 820.4353  888.3139  793.9683
##      [20,] 628.2958 765.5055 819.2150 666.4196 809.0813  893.6092  999.0721
##           
## iterations    [,22]     [,23]    [,24]     [,25]     [,26]     [,27]     [,28]
##       [1,] 800.6874  886.0923 849.6370  917.4460  866.7207  840.9107  846.2594
##       [2,] 718.6075  852.3272 969.4953 1022.4456  939.7496  987.8303 1049.1033
##       [3,] 991.6738 1014.9260 729.0990  999.3453  909.5472  905.7015  965.9020
##       [4,] 940.5345  605.9063 910.9429  959.5390  716.6045 1125.1973  948.5425
##       [5,] 825.6660  965.4303 860.6505  913.6277 1025.7633 1019.5499  944.0361
##       [6,] 918.2497  820.3603 853.6841  873.9950  742.7963  995.9853  851.9542
##       [7,] 814.7578  883.2623 914.4702  936.6025  779.7538  984.0825 1012.4141
##       [8,] 776.0786  826.6984 888.8652  833.2019  926.9280  950.0061 1091.4938
##       [9,] 843.8269  960.3566 810.9016  801.1478  804.5326 1031.6536 1037.0339
##      [10,] 753.9680  751.7611 922.0987 1009.6970  825.3445  893.8568 1020.4597
##      [11,] 699.5343  885.6054 940.8069  722.9219 1068.1142  959.8222 1040.9228
##      [12,] 863.0833  811.8399 893.4901  841.8906  881.5025  954.8141 1004.2856
##      [13,] 898.3977  816.0073 930.8448  923.1194  870.4447 1004.3028  870.3332
##      [14,] 818.1159  996.8225 838.2550  842.6895  905.5167 1018.3691 1083.6420
##      [15,] 935.4024  807.3129 991.6398  772.0339  955.5630  898.0924 1080.0350
##      [16,] 760.6484  794.6000 875.0895 1022.0439 1004.2797  955.5873  831.1263
##      [17,] 724.9403  902.7757 831.0406  974.2944 1114.6682  990.2097 1042.6495
##      [18,] 880.0326  756.5383 916.2503  786.4850  710.6524  914.7172 1009.2806
##      [19,] 841.2719  833.1613 909.7609  850.8728  883.9677 1005.0176  977.4049
##      [20,] 953.8126  776.7893 862.6911  830.0702  843.2086  950.0937  999.4058
##           
## iterations     [,29]     [,30]     [,31]     [,32]     [,33]     [,34]
##       [1,]  912.6560 1201.7413  958.3079 1163.4697 1020.4617  949.6320
##       [2,]  894.8990  969.2003  903.3676 1023.1478 1037.9467 1136.3692
##       [3,] 1083.7358  982.5188 1041.1794  748.0982 1007.1747  981.2476
##       [4,] 1002.9557  949.2275  994.6867 1146.5347 1081.1884 1060.7675
##       [5,] 1050.5162  964.5730 1014.3345 1062.2732 1167.1582 1150.3578
##       [6,]  971.3147 1025.1822  939.6040 1118.9930 1179.1470 1171.9391
##       [7,] 1034.5410 1063.9443 1141.4158 1187.8655 1128.4380 1111.7770
##       [8,] 1063.0011 1003.7846  969.3867  941.9058 1096.8417 1014.9422
##       [9,]  922.6610  968.8916  910.9271 1041.9371  955.0561 1001.0377
##      [10,]  870.8395 1011.8294  939.2495 1101.5663 1244.1355 1100.6073
##      [11,] 1058.6339  931.9229  930.9247 1165.4494 1020.8268 1091.8252
##      [12,]  999.9197  973.6207 1004.8199 1004.2641 1106.5297 1099.3658
##      [13,]  930.5405  982.4316 1022.2602 1139.5438 1152.2025 1008.7925
##      [14,]  974.9086  975.1867 1022.0142 1088.5308 1222.0376 1048.9858
##      [15,]  981.7409  878.2633 1023.7982  920.0489  927.5012  998.6094
##      [16,]  997.2168 1078.0214  920.6720 1093.9006  923.3353 1233.5923
##      [17,]  948.9044 1163.0407 1029.8391 1041.2351 1045.0430 1188.6281
##      [18,] 1028.1171  947.5524 1072.2639  948.5490 1106.4383 1149.0600
##      [19,] 1092.6323  939.1729 1005.8183 1115.0054 1057.8143 1182.0997
##      [20,] 1007.2728 1071.3887 1003.9543  977.8216 1088.5528 1211.9842
##           
## iterations     [,35]     [,36]    [,37]     [,38]
##       [1,] 1174.4669 1210.2661 1063.009 1266.0082
##       [2,] 1035.0309 1199.3672 1224.920 1085.3430
##       [3,]  958.5278 1080.5149 1140.509 1251.8435
##       [4,] 1205.5529 1378.3699 1245.887 1399.2663
##       [5,] 1300.0677 1028.6940 1212.456 1323.7375
##       [6,] 1185.9207  953.2947 1015.985 1145.6515
##       [7,] 1178.2535 1260.9886 1256.103 1149.3105
##       [8,] 1352.3564 1121.3138 1184.295 1099.1386
##       [9,] 1054.9287 1052.3269 1090.964 1297.6090
##      [10,] 1151.5935 1164.1439 1140.220  977.0392
##      [11,] 1006.3409 1206.1782 1129.598 1297.3465
##      [12,] 1049.2263 1118.0162 1136.372 1097.9544
##      [13,] 1272.0999 1102.8730 1152.075 1260.3388
##      [14,] 1170.4072 1193.7097 1266.998 1184.7448
##      [15,] 1070.6106 1120.3130 1213.733 1253.5098
##      [16,] 1091.5215 1148.6841 1050.904 1034.6988
##      [17,] 1125.0222 1092.6740 1300.564 1266.2335
##      [18,] 1017.0880 1135.8451 1183.040 1086.5221
##      [19,] 1266.0713 1080.8330 1177.136 1180.9905
##      [20,] 1203.8639  969.3366 1066.955 1183.2624

上記の通り、各年齢ごと(列として)にMCMCサンプルが得られている列ごとにパーセンタイルを計算すれば、信用区間を算出できる。転置まで含めてpurrr::map_dfr()関数で書き直すことは可能ではあるが、apply()関数のほうが高速(たぶん)なので、書籍の通りapply()関数を使用する。

# purrr::map_dfr(.x = ms$y_new, .f = function(x){quantile(x, probs = c(0.025, 0.25, 0.50, 0.75, 0.975))})
qua <- apply(ms$y_new, 2, quantile, probs=c(0.025, 0.25, 0.50, 0.75, 0.975))
d_est <- data.frame(X = X_new, t(qua), check.names = FALSE)
d_est
##     X      2.5%       25%       50%       75%     97.5%
## 1  23  196.1084  323.1401  384.6072  445.1679  573.2650
## 2  24  209.3494  344.6997  406.7865  466.8781  595.2174
## 3  25  249.0284  368.2070  429.4956  487.5158  610.6214
## 4  26  267.2860  391.1700  449.5822  509.6548  632.7679
## 5  27  290.4473  415.0021  472.8636  529.8990  649.3803
## 6  28  310.9410  435.1684  493.3632  551.4630  669.3591
## 7  29  337.0813  455.0606  515.5190  573.3125  696.2371
## 8  30  356.8058  480.3806  539.2433  596.6382  726.6229
## 9  31  386.7507  500.3823  559.1365  621.7323  739.1078
## 10 32  394.5249  520.7448  579.8658  638.7421  756.0088
## 11 33  428.9380  549.8678  605.4362  660.7895  781.1420
## 12 34  450.6157  568.5087  624.6334  683.0308  802.7096
## 13 35  464.7578  587.1189  645.6336  705.6502  817.6153
## 14 36  487.8904  611.4111  668.3740  726.9089  850.6422
## 15 37  510.1696  631.2036  690.1135  748.4581  866.9594
## 16 38  538.2495  654.1524  710.8309  767.6460  875.4960
## 17 39  555.7743  676.9001  736.7438  794.9146  908.6932
## 18 40  585.5104  697.7168  756.2129  812.3078  937.3261
## 19 41  603.6863  721.5498  777.3452  836.3449  947.1644
## 20 42  620.7144  743.6247  798.5522  856.9957  980.5720
## 21 43  648.7306  764.6278  820.3401  876.5976  990.7906
## 22 44  665.7280  787.8738  844.9750  900.9996 1019.8217
## 23 45  690.5057  807.4399  865.7273  920.3477 1041.9983
## 24 46  713.4799  831.1948  887.1316  945.1339 1063.1338
## 25 47  739.5763  850.5364  906.3070  964.1591 1088.2991
## 26 48  757.6098  871.7799  929.3356  987.4714 1101.9738
## 27 49  771.4245  895.3892  954.8782 1012.8636 1129.8147
## 28 50  802.9857  915.2792  973.6994 1031.5326 1149.6666
## 29 51  821.2559  939.2511  994.4451 1056.6653 1175.9194
## 30 52  838.0974  961.0024 1017.2175 1074.3876 1188.3268
## 31 53  864.9028  985.3825 1040.7734 1096.5392 1215.0917
## 32 54  883.6646 1006.6777 1061.7447 1121.2343 1241.2938
## 33 55  892.1116 1027.0763 1083.3382 1142.7502 1265.4592
## 34 56  925.8846 1047.7235 1107.7652 1165.5273 1278.8964
## 35 57  945.3126 1069.6463 1131.6150 1189.7430 1301.0336
## 36 58  974.1271 1089.4398 1150.5051 1208.2641 1325.1875
## 37 59  987.6673 1110.8871 1171.6379 1232.5634 1347.5397
## 38 60 1010.9454 1135.8988 1196.5153 1256.0369 1374.8607

あとはこのデータをもとに作図すれば予測区間つきの散布図を作ることができる。

ggplot() +  
  theme_bw(base_size = 15) +
  geom_ribbon(data = d_est, aes(x = X, ymin = `2.5%`, ymax = `97.5%`), fill = 'black', alpha = 1/6) +
  geom_ribbon(data = d_est, aes(x = X, ymin = `25%` , ymax = `75%`), fill = 'black', alpha = 2/6) +
  geom_line(data = d_est, aes(x = X, y = `50%`), size = 1) +
  geom_point(data = d,     aes(x = X, y = Y), shape = 1, size = 3) +
  coord_cartesian(xlim = c(22, 61), ylim = c(200, 1400)) +
  scale_y_continuous(breaks = seq(from = 200, to = 1400, by = 400)) +
  labs(y = 'Y', title = 'Prediction intervals')

事前予測分布と事後予測分布

事前予測分布(PRIOR predictive distribution)と事後予測分布(POSTERIOR predictive distribution)への理解を深める。事後予測分布は下記の通り定義される。

事後予測分布

\[ \begin{eqnarray} p_{pred}(y|Y) = \int p(y|\theta) p(\theta|Y) d\theta \\ \end{eqnarray} \]

事後予測分布は以下のようにサンプリングできる。

    1. 事後分布\(p(\theta|Y)\)からサンプルを抽出する
    1. \(p(y|\theta)\)を使ってサンプルを生成する
n_samples <- 10000

# binom
n <- 10
x <- 6

# beta
a <- 2
b <- 2

# Draw sample from the posterior
posterior <- rbeta(n = n_samples, shape1 = x + a, shape2 = n - x + b)

# Generate data based on the prior samples
posterior_predictive <- rbinom(n_samples, size = n, prob = posterior)

ggplot(data.frame(posterior_predictive), aes(x = posterior_predictive)) +
  geom_histogram(binwidth = 1) + 
  scale_x_continuous(breaks = c(0:n)) + 
  labs(x = 'y', title = 'Posterior predictive distribution') +
  theme_bw()

事前予測分布はデータの周辺分布(marginal distribution of data)とも呼ばれ、下記の通り定義される。

事前予測分布

\[ \begin{eqnarray} p(y) = \int p(\theta)p(y | \theta) d\theta \end{eqnarray} \]

\(p(y)\)は、事前分布\(p(\theta)\)からサンプルを抽出し、このサンプルを用いて\(p(y|\theta)\)を生成することでシュミレートできる。

a <- 2
b <- 2

# Draw from the prior
prior_samples <- rbeta(n = n_samples, shape1 = a, shape2 = b)

# Generate data based on the prior samples
y <- rbinom(n_samples, size = 100, prob = prior_samples)

ggplot() +
  geom_histogram(data = data.frame(y), 
                 aes(x = y), 
                 binwidth = 1)