UPDATE: 2023-12-18 22:36:54.800651
このノートは「StanとRでベイズ統計モデリング」の内容を写経することで、ベイズ統計への理解を深めていくために作成している。
基本的には気になった部分を写経しながら、ところどころ自分用の補足をメモすることで、「StanとRでベイズ統計モデリング」を読み進めるための自分用の補足資料になることを目指す。私の解釈がおかしく、メモが誤っている場合があるので注意。
今回は第4章「StanとRStanをはじめよう」のチャプターを写経していく。
コードが出てくるのは第4章からではあるが、第2章の予測分布で気になる部分があったのでメモを追加しておく。予測分布は下記の通り、確率モデルと事後分布を積分することで計算される。
\[ \begin{eqnarray} p_{pred}(y|Y) = \int p(y|\theta) p(\theta|Y) d\theta \\ \end{eqnarray} \]
下記の書籍に書かれている通り、予測分布を得る場合は数千個得られるパラメタのMCMCサンプルを変えながら、確率モデルからの乱数を得ることになる。
より簡単に、事後分布\(p(\theta|Y)\)からのMCMCサンプルから値を1つ選び、それを\(\theta^{+}\)として確率モデル\(p(y|\theta^{+})\)に従う乱数\(y^{+}\)を生成することを繰り返し、\(y^{+}\)をたくさん生成することで予測分布\(p_{pred}(y|Y)\)からのMCMCサンプルとみなすことができる。
どちらが良いのかは分からないが、他の書籍やブログなどには、MCMCサンプルの値を変えず、パラメタの事後分布の点推定量を計算し、それを利用して確率モデルから乱数を得る「条件付き予測分布」という方法が記載されている場合もある。このあたりの理解が最初は追いつかず、Stanでgenerated quantitiesの内容と頭の理解が一致せず、困った経験が過去にあったので、再勉強するにあたりメモを残しておく。基本的には事後予測分布を使用し、条件付き予測分布は使わなければ問題はなさそう。
書籍に書かれている下記のモデル式4-1への理解を深めておく。
\[ \begin{eqnarray} Y[n] &\sim& Normal(\mu, 1) \tag{4.1}\\ \mu &\sim& Normal(0, 100) \tag{4.2} \end{eqnarray} \]
これは書籍にも書かれているが、「データ1つごとに平均\(\mu\)、標準偏差1の正規分布から独立に確率的に生成された」ことを意味する。書き下すと下記の通りで、\(\mu\)は何らかのスカラであり、その\(\mu\)をもつ正規分布から生成されていることを意味する。スカラを強調しているのは、階層ベイズモデルなどを筆頭に複雑なモデルを扱う時、グループごとに\(\mu\)が変化することもあるためである。
Y[1] ~ Normal(mu, 1) かつ
Y[2] ~ Normal(mu, 1) かつ
Y[3] ~ Normal(mu, 1) かつ
...
Y[n] ~ Normal(mu, 1)
これを略した記述がY[n] ~ normal(mu, 1)
Stanの文法で記述すると、下記に対応する。
model{
for (n in 1:N){
Y[n] ~ normal(mu, 1);
}
mu ~ normal(0, 100) # 事前分布
}
モデルが複雑になってくると、どういうこと?これはどう記述すればよいのか?など、疑問は湧き出てくるので、しっかり基礎の基礎を理解しておきたい・・・(戒め)。
下記4つのモデルは別のモデルではなく等価なので解釈を誤らないように注意が必要。モデル4-5式はデータ1人ごとに平均\(a+bX[n]\)、標準偏差\(\sigma\)の正規分布から独立に生成されたことを意味している。
どのモデルも平均値が\(X\)の値によって変わる条件つき正規分布で、すべての\(X\)において等分散を仮定しているモデルである。最終的にパラメタ\(a,b\)の事後分布を得たいものの、各モデル式をそのままStanのmodelブロックで定義すると異なる表記になるので注意が必要。
\[ \begin{eqnarray} Y[n] &=& y_{base}[n] + \epsilon[n] \\ y_{base}[n] &=& a + b X[n]\\ \epsilon[n] &\sim& Normal(0, \sigma) \tag{4.4} \end{eqnarray} \]
\[ \begin{eqnarray} Y[n] &=& a + b X[n] + \epsilon[n] \\ \epsilon[n] &\sim& Normal(0, \sigma) \end{eqnarray} \]
\[ \begin{eqnarray} y_{base}[n] &=& a + b X[n] \\ Y[n] &\sim& Normal(y_{base}[n], \sigma) \end{eqnarray} \]
\[ \begin{eqnarray} Y[n] &\sim& Normal(a + b X[n], \sigma) \end{eqnarray} \]
書籍内で使用されているデータは下記の通り。
library(ggplot2)
library(rstan)
d <- read.csv('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap04/input/data-salary.txt')
data <- list(N = nrow(d), X = d$X, Y = d$Y)
head(d)
## X Y
## 1 24 472
## 2 24 403
## 3 26 454
## 4 32 575
## 5 33 546
## 6 35 781
まずはStanコードを作成する。別ファイルに記載しているコードを下記に転記している。
data {
int N;
real X[N];
real Y[N];
}
parameters {
real a;
real b;
real<lower=0> sigma;
}
model {
for (n in 1:N){
Y[n] ~ normal(a + b*X[n], sigma);
}
}
modelブロックで記載しているモデルは、さきほどからしつこく書いている通り「データ1人ごとに平均\(a+bX[n]\)、標準偏差\(\sigma\)の正規分布から独立に生成された」ことを意味しているモデルを想定している。つまり、書き下すと下記のようになる。
Y[1] ~ normal(a + b*X[1], sigma);
Y[2] ~ normal(a + b*X[2], sigma);
...
Y[N] ~ normal(a + b*X[N], sigma);
ここでは、stan_model()
関数で最初にコンパイルしておいてから、
sampling()
関数でサンプリングする。
推定結果は下記の通り。指定しない限りparameters
ブロック、transformed parameters
ブロック、generated quantities
ブロックで指定されたパラメタのMCMCサンプルをサンプリングする。つまり、どのようなパラメタがサンプリングされるかは、モデル式をどのように記述するのかに依存する。
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## a -119.18 2.26 75.35 -267.06 -167.20 -118.41 -71.11 31.43 1110 1.00
## b 21.89 0.05 1.68 18.61 20.80 21.88 22.99 25.27 1155 1.00
## sigma 85.37 0.42 15.98 60.54 74.41 83.32 93.59 125.17 1418 1.01
## lp__ -93.68 0.04 1.35 -97.19 -94.27 -93.34 -92.70 -92.14 1074 1.00
##
## Samples were drawn using NUTS(diag_e) at Mon Dec 18 22:37:18 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
分析目的に適う作図や分析結果を得るためには想定しているモデル構造、それをもとに記述されるモデル式への理解を深める必要がある。
ベイズ信頼区間、予測区間を算出するためにすることは、まず各パラメタのMCMCサンプルをextract()
関数で取り出すこと。MCMCサンプルの長さはchains * (iter - warmup)/thin
という関係にあるので、今回は4 * (2000 - 1000)/1 = 4000
となる。つまり、先程表示されていたパラメタのサマリーは、4000個のMCMCサンプルから計算されていることになる。
## List of 4
## $ a : num [1:4000(1d)] -126.1 -12.8 -229.8 -21.3 -96.5 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ b : num [1:4000(1d)] 22.1 19.3 24.6 20 21.4 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ sigma: num [1:4000(1d)] 89.8 87.6 89.1 94.4 70.8 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ lp__ : num [1:4000(1d)] -92.4 -93.6 -93.6 -93.6 -92.2 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
下記は、各パラメタの1-10行目を取り出しているが、
## set a b sigma
## 1 set1 -126.09782 22.09596 89.79984
## 2 set2 -12.79003 19.29683 87.55702
## 3 set3 -229.82267 24.56804 89.05753
## 4 set4 -21.32934 20.00322 94.39494
## 5 set5 -96.54672 21.42506 70.78023
## 6 set6 -145.84620 22.16302 67.00061
## 7 set7 -81.37651 21.17869 73.67658
## 8 set8 -270.45381 25.08404 90.17393
## 9 set9 -73.37723 20.10361 122.57991
## 10 set10 -61.34379 20.86757 75.74603
これは\(p(a,b,\sigma|X,Y)\)という同時分布からサンプリングした1つの組み(=1行)であり、この組みが4000組得られている。この4000組を使って信用区間、予測区間を算出する。条件付き予測分布の例があったので、少し混乱するが、書籍では23-60歳の範囲での予測区間を算出する前に、丁寧に1時点での解説がされているので、大変イメージしやすい。下記は50歳時点での予測分布を構築するコードである。
ms <- rstan::extract(fit)
N_mcmc <- length(ms$lp__)
y50_base <- ms$a + ms$b * 50
y50 <- rnorm(n = N_mcmc, mean = y50_base, sd = ms$sigma)
d_mcmc <- data.frame(a = ms$a, b = ms$b, sigma = ms$sigma, y50_base, y50)
ms$a + ms$b * 50
の部分がまさに下記を表している。
より簡単に、事後分布\(p(\theta|Y)\)からのMCMCサンプルから値を1つ選び、それを\(\theta^{+}\)として確率モデル\(p(y|\theta^{+})\)に従う乱数\(y^{+}\)を生成することを繰り返し、\(y^{+}\)をたくさん生成することで予測分布\(p_{pred}(y|Y)\)からのMCMCサンプルとみなすことができる。
ms$a
もms$b
も4000個のベクトルであって、ベクトルの計算に年齢50
を乗じることで、4000個の平均ベクトルを得ている。1番目のパラメタの1組を取り出して平均を計算し、2番目のパラメタの1組を取り出して平均を計算し、これを4000回繰り返す。そして、この平均を使って正規分布からの乱数rnorm(n = N_mcmc, mean = y50_base, sd = ms$sigma)
を得ている。
rnorm()
関数の平均にベクトルを渡すというイメージは分かりにくいかもしれないが、変化する平均をもとに、乱数を生成する。
## [1] -0.2070657 10.2774292 101.0844412 997.6543023 10000.4291247
これが50歳時点での予測分布となるので、あとはこれを23-60歳の範囲で同じように繰り返すことで予測区間が算出される。
実際はStan側からデータを受け取って予測区間となるMCMCサンプルを計算するではなく、Stan側で計算させることができる。先程説明した、計算したいパラメタとStanのモデル式が対応しているとはこのことである。事後予測分布については、ノート末尾でも補足する。
まずはStanコードを作成する。別ファイルに記載しているコードを下記に転記している。
data {
int N;
real X[N];
real Y[N];
int N_new;
real X_new[N_new];
}
parameters {
real a;
real b;
real<lower=0> sigma;
}
transformed parameters {
real y_base[N];
for (n in 1:N)
y_base[n] = a + b*X[n];
}
model {
for (n in 1:N)
Y[n] ~ normal(y_base[n], sigma);
}
generated quantities {
real y_base_new[N_new];
real y_new[N_new];
for (n in 1:N_new) {
y_base_new[n] = a + b*X_new[n];
y_new[n] = normal_rng(y_base_new[n], sigma);
}
}
generated quantities
ブロックを追加して、予測分布を算出するための予測値のMCMCサンプルを計算させている。あとはデータを用意して、
モデルをコンパイルし、
サンプリングする。
サンプリング結果を確認すると、parameters
ブロック、transformed parameters
ブロック、generated quantities
ブロックで指定されたパラメタの事後分布が得られている。
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
## a -118.45 2.29 74.36 -265.51 -118.82 32.61 1054 1
## b 21.87 0.05 1.65 18.51 21.87 25.05 1060 1
## sigma 84.65 0.38 15.35 60.18 82.63 120.37 1605 1
## y_base[1] 406.43 1.09 37.50 332.27 406.18 481.53 1184 1
## y_base[2] 406.43 1.09 37.50 332.27 406.18 481.53 1184 1
## y_base[3] 450.16 0.99 34.70 381.40 449.98 520.46 1224 1
## y_base[4] 581.38 0.70 27.00 527.67 580.81 635.08 1468 1
## y_base[5] 603.25 0.66 25.87 551.42 602.62 654.21 1542 1
## y_base[6] 646.99 0.57 23.77 600.19 646.62 694.27 1748 1
## y_base[7] 712.60 0.45 21.22 670.25 712.56 754.61 2242 1
## y_base[8] 756.34 0.38 20.02 716.44 756.39 795.93 2815 1
## y_base[9] 756.34 0.38 20.02 716.44 756.39 795.93 2815 1
## y_base[10] 821.95 0.29 19.17 784.86 822.32 859.61 4273 1
## y_base[11] 821.95 0.29 19.17 784.86 822.32 859.61 4273 1
## y_base[12] 843.82 0.29 19.16 806.79 843.84 881.44 4399 1
## y_base[13] 931.30 0.32 20.52 890.96 931.38 972.68 4150 1
## y_base[14] 1018.78 0.46 23.70 972.24 1018.73 1066.09 2709 1
## y_base[15] 1106.26 0.63 28.10 1051.87 1106.30 1162.11 2014 1
## y_base[16] 1106.26 0.63 28.10 1051.87 1106.30 1162.11 2014 1
## y_base[17] 1128.13 0.67 29.33 1070.69 1127.99 1186.57 1906 1
## y_base[18] 1149.99 0.72 30.60 1090.55 1149.87 1211.04 1815 1
## y_base[19] 1171.86 0.77 31.91 1109.85 1171.58 1235.93 1738 1
## y_base[20] 1171.86 0.77 31.91 1109.85 1171.58 1235.93 1738 1
## y_base_new[1] 384.56 1.14 38.93 307.33 384.32 462.65 1168 1
## y_base_new[2] 406.43 1.09 37.50 332.27 406.18 481.53 1184 1
## y_base_new[3] 428.30 1.04 36.09 356.87 428.15 500.90 1203 1
## y_base_new[4] 450.16 0.99 34.70 381.40 449.98 520.46 1224 1
## y_base_new[5] 472.03 0.94 33.34 405.60 471.87 539.82 1249 1
## y_base_new[6] 493.90 0.89 32.00 429.96 493.75 559.08 1279 1
## y_base_new[7] 515.77 0.85 30.69 454.70 515.54 578.18 1314 1
## y_base_new[8] 537.64 0.80 29.42 479.41 537.20 597.18 1356 1
## y_base_new[9] 559.51 0.75 28.19 503.82 558.99 615.93 1407 1
## y_base_new[10] 581.38 0.70 27.00 527.67 580.81 635.08 1468 1
## y_base_new[11] 603.25 0.66 25.87 551.42 602.62 654.21 1542 1
## y_base_new[12] 625.12 0.61 24.79 575.88 624.56 674.32 1634 1
## y_base_new[13] 646.99 0.57 23.77 600.19 646.62 694.27 1748 1
## y_base_new[14] 668.86 0.53 22.83 623.64 668.48 713.89 1887 1
## y_base_new[15] 690.73 0.49 21.98 646.50 690.57 734.50 2046 1
## y_base_new[16] 712.60 0.45 21.22 670.25 712.56 754.61 2242 1
## y_base_new[17] 734.47 0.41 20.56 693.14 734.38 774.77 2484 1
## y_base_new[18] 756.34 0.38 20.02 716.44 756.39 795.93 2815 1
## y_base_new[19] 778.21 0.34 19.60 740.02 778.52 816.72 3272 1
## y_base_new[20] 800.08 0.31 19.32 762.43 800.62 837.92 3806 1
## y_base_new[21] 821.95 0.29 19.17 784.86 822.32 859.61 4273 1
## y_base_new[22] 843.82 0.29 19.16 806.79 843.84 881.44 4399 1
## y_base_new[23] 865.69 0.29 19.30 828.15 865.81 904.16 4447 1
## y_base_new[24] 887.56 0.29 19.58 849.83 887.69 926.93 4415 1
## y_base_new[25] 909.43 0.30 19.98 870.80 909.66 949.34 4312 1
## y_base_new[26] 931.30 0.32 20.52 890.96 931.38 972.68 4150 1
## y_base_new[27] 953.17 0.35 21.17 911.16 953.38 995.46 3717 1
## y_base_new[28] 975.04 0.38 21.92 931.58 974.98 1019.16 3316 1
## y_base_new[29] 996.91 0.42 22.77 951.95 996.86 1042.68 2970 1
## y_base_new[30] 1018.78 0.46 23.70 972.24 1018.73 1066.09 2709 1
## y_base_new[31] 1040.65 0.50 24.71 992.46 1040.68 1089.88 2490 1
## y_base_new[32] 1062.52 0.54 25.79 1012.34 1062.62 1113.69 2301 1
## y_base_new[33] 1084.39 0.58 26.92 1032.26 1084.45 1138.06 2144 1
## y_base_new[34] 1106.26 0.63 28.10 1051.87 1106.30 1162.11 2014 1
## y_base_new[35] 1128.13 0.67 29.33 1070.69 1127.99 1186.57 1906 1
## y_base_new[36] 1149.99 0.72 30.60 1090.55 1149.87 1211.04 1815 1
## y_base_new[37] 1171.86 0.77 31.91 1109.85 1171.58 1235.93 1738 1
## y_base_new[38] 1193.73 0.81 33.24 1129.18 1193.44 1260.49 1672 1
## y_new[1] 383.23 1.72 96.02 196.11 384.61 573.26 3108 1
## y_new[2] 406.39 1.73 96.45 209.35 406.79 595.22 3123 1
## y_new[3] 428.54 1.75 92.02 249.03 429.50 610.62 2763 1
## y_new[4] 450.39 1.70 92.20 267.29 449.58 632.77 2954 1
## y_new[5] 472.25 1.62 90.76 290.45 472.86 649.38 3146 1
## y_new[6] 493.06 1.57 90.16 310.94 493.36 669.36 3280 1
## y_new[7] 514.11 1.62 90.98 337.08 515.52 696.24 3164 1
## y_new[8] 538.95 1.59 91.77 356.81 539.24 726.62 3338 1
## y_new[9] 561.11 1.47 90.88 386.75 559.14 739.11 3835 1
## y_new[10] 579.53 1.53 90.36 394.52 579.87 756.01 3482 1
## y_new[11] 605.20 1.55 88.31 428.94 605.44 781.14 3245 1
## y_new[12] 626.40 1.47 89.35 450.62 624.63 802.71 3674 1
## y_new[13] 645.41 1.42 89.85 464.76 645.63 817.62 3997 1
## y_new[14] 669.25 1.57 90.03 487.89 668.37 850.64 3272 1
## y_new[15] 689.81 1.59 89.70 510.17 690.11 866.96 3175 1
## y_new[16] 710.00 1.36 86.73 538.25 710.83 875.50 4039 1
## y_new[17] 735.53 1.41 90.55 555.77 736.74 908.69 4137 1
## y_new[18] 756.10 1.51 87.69 585.51 756.21 937.33 3375 1
## y_new[19] 778.56 1.46 87.42 603.69 777.35 947.16 3577 1
## y_new[20] 799.39 1.52 89.61 620.71 798.55 980.57 3472 1
## y_new[21] 820.37 1.33 86.72 648.73 820.34 990.79 4229 1
## y_new[22] 844.38 1.40 88.92 665.73 844.98 1019.82 4024 1
## y_new[23] 865.18 1.36 88.05 690.51 865.73 1042.00 4162 1
## y_new[24] 887.60 1.38 87.29 713.48 887.13 1063.13 4017 1
## y_new[25] 908.24 1.39 87.96 739.58 906.31 1088.30 3992 1
## y_new[26] 929.76 1.39 87.58 757.61 929.34 1101.97 3954 1
## y_new[27] 954.01 1.45 89.35 771.42 954.88 1129.81 3807 1
## y_new[28] 974.16 1.45 88.38 802.99 973.70 1149.67 3731 1
## y_new[29] 997.35 1.44 90.18 821.26 994.45 1175.92 3910 1
## y_new[30] 1017.07 1.38 87.34 838.10 1017.22 1188.33 4024 1
## y_new[31] 1041.04 1.55 87.55 864.90 1040.77 1215.09 3194 1
## y_new[32] 1063.09 1.46 89.44 883.66 1061.74 1241.29 3731 1
## y_new[33] 1083.72 1.58 91.51 892.11 1083.34 1265.46 3375 1
## y_new[34] 1106.51 1.51 89.52 925.88 1107.77 1278.90 3513 1
## y_new[35] 1128.90 1.45 90.51 945.31 1131.62 1301.03 3883 1
## y_new[36] 1149.61 1.50 89.46 974.13 1150.51 1325.19 3567 1
## y_new[37] 1170.58 1.59 91.77 987.67 1171.64 1347.54 3327 1
## y_new[38] 1195.27 1.47 91.73 1010.95 1196.52 1374.86 3917 1
## lp__ -93.65 0.03 1.26 -96.88 -93.34 -92.16 1452 1
##
## Samples were drawn using NUTS(diag_e) at Mon Dec 18 22:37:38 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
ここから予測分布を作成していくが、まずはデータがどのように保存されているのか確認する。信用区間を算出したければms$y_new
をms$y_base_new
に読み替えればOK。
## num [1:4000, 1:38] 461 576 366 599 372 ...
## - attr(*, "dimnames")=List of 2
## ..$ iterations: NULL
## ..$ : NULL
予測区間を算出するために必要なのはms$y_new
であるが、matrix
クラスで保存されている。先程のa,b
はベクトルだったが、今回はmatrix
クラスで保存されている。これは23−60歳までのMCMCサンプルを計算しているためである。つまりmcmcサンプルの長さ × forloopのループ回数
という形式でmatrix
が保存されていることを意味する。
##
## iterations [,1] [,2] [,3] [,4] [,5] [,6] [,7]
## [1,] 460.8634 437.9357 518.3424 481.0554 477.2357 534.8487 390.8750
## [2,] 575.9002 443.6138 462.6731 480.1184 385.1748 508.4724 693.3259
## [3,] 366.1943 458.1170 385.2195 339.1653 341.4010 549.2674 658.2847
## [4,] 599.4339 496.1026 464.5705 285.4239 515.7201 469.4772 473.4473
## [5,] 372.0720 372.9459 428.0292 375.9517 544.5555 416.6049 417.3597
## [6,] 403.7616 424.9994 434.3494 287.1179 481.5824 359.6440 450.4261
## [7,] 356.3242 376.9793 359.0668 337.1274 486.7946 548.2675 325.5847
## [8,] 146.6523 521.4328 358.1214 267.9913 256.7348 459.1277 477.2827
## [9,] 345.8380 466.1090 530.5797 330.2329 466.2934 421.6227 344.9762
## [10,] 325.6443 621.3337 389.4070 475.1688 325.2862 627.9302 695.1057
## [11,] 517.3295 540.6945 559.2273 449.3653 327.9051 461.0802 417.7049
## [12,] 290.1946 386.1028 317.2166 452.6001 460.1248 486.7724 404.2192
## [13,] 258.0210 342.0837 249.0349 496.2835 336.8860 478.1819 613.6761
## [14,] 334.7587 385.1673 411.1875 458.6136 337.4290 545.6124 591.9120
## [15,] 389.0858 535.0654 486.2826 389.5704 382.5474 474.3361 659.8875
## [16,] 636.7629 530.6810 464.4844 553.0752 658.3440 523.2618 694.3509
## [17,] 375.9961 335.5095 396.1599 407.3054 370.2821 468.1920 657.3317
## [18,] 422.5172 571.5171 465.8221 558.1021 582.3216 362.8164 581.5894
## [19,] 324.4703 318.0017 421.6565 413.1590 393.6573 412.7902 542.7472
## [20,] 311.3647 428.5097 535.2399 576.4163 522.2956 528.2646 501.4678
##
## iterations [,8] [,9] [,10] [,11] [,12] [,13] [,14]
## [1,] 594.9074 455.1513 618.9944 656.6739 447.4682 718.7593 654.7729
## [2,] 589.8882 558.8310 611.4708 628.5700 679.5158 774.6877 593.4986
## [3,] 443.8143 559.2858 337.3706 332.3432 576.0598 666.7836 735.1523
## [4,] 627.8903 528.3965 336.4917 454.7115 433.9457 548.0407 582.7085
## [5,] 612.9106 516.9425 581.3942 626.8271 602.8598 651.1933 573.4231
## [6,] 439.9647 698.7475 571.1804 643.0411 730.6044 770.0660 618.7351
## [7,] 538.7477 589.4282 667.7911 505.3941 707.2915 591.5460 673.9480
## [8,] 472.5433 377.1111 529.3654 513.9780 575.2311 522.0048 539.9877
## [9,] 472.4539 488.6546 571.1610 522.2566 874.2138 538.8533 686.0758
## [10,] 654.8062 553.7593 602.7138 484.8147 777.9759 650.8207 679.4669
## [11,] 515.9074 404.5850 690.9860 529.4769 634.4511 802.4316 651.0564
## [12,] 470.7328 462.4233 507.6061 681.5549 581.1888 757.6221 698.5235
## [13,] 569.3603 547.7978 560.4178 623.9338 648.6549 590.9561 719.5660
## [14,] 563.4991 507.2462 615.6941 704.8850 568.5192 740.8673 670.0796
## [15,] 446.6293 482.9956 490.4006 671.5579 580.4193 599.2723 617.1431
## [16,] 593.7282 598.4804 590.4795 805.9819 532.4109 625.1921 702.4869
## [17,] 456.2713 460.7779 506.5233 507.5420 743.4607 552.3405 632.5317
## [18,] 514.8807 464.3147 643.7972 633.8365 654.2452 676.7209 644.1947
## [19,] 536.2386 475.7015 533.0859 709.4285 559.1847 633.5689 689.1405
## [20,] 357.6691 480.5350 641.7019 671.7305 759.2330 544.3446 695.6023
##
## iterations [,15] [,16] [,17] [,18] [,19] [,20] [,21]
## [1,] 634.6249 648.5844 687.0959 805.2507 642.1236 903.6075 807.4724
## [2,] 669.6236 810.6379 827.8528 664.9283 879.0198 765.4336 913.0720
## [3,] 503.9478 452.6813 634.8068 757.8929 702.5623 715.3880 895.2918
## [4,] 791.9150 790.7317 718.9376 665.9947 717.8696 801.4429 661.9766
## [5,] 628.1306 761.4416 672.5884 813.2941 774.6610 726.3557 761.7498
## [6,] 794.2247 714.5360 691.8072 812.8053 862.2054 874.3026 759.0869
## [7,] 708.0689 714.9211 746.5398 746.4316 806.8987 937.6288 805.2882
## [8,] 701.9692 655.7353 743.8688 702.4851 905.9119 811.4945 790.9721
## [9,] 612.5157 702.8137 580.3694 833.6936 870.1549 820.1008 798.0600
## [10,] 851.5035 678.2560 694.0315 775.3822 713.3565 1035.4219 1102.2633
## [11,] 786.4316 923.4823 862.1359 839.6637 607.0796 930.9741 783.0660
## [12,] 653.4352 732.6550 794.0398 740.6847 813.2789 687.9116 701.8360
## [13,] 690.5611 739.8178 786.9390 787.5250 808.5921 859.1639 824.1926
## [14,] 787.1457 719.5702 645.8272 732.9679 750.9665 690.1316 928.4002
## [15,] 658.9345 781.7945 811.0267 677.0498 809.9895 842.0706 713.2065
## [16,] 648.7334 773.8109 649.8857 821.6310 841.6855 876.3340 766.0753
## [17,] 649.9147 827.7006 845.1821 696.5604 797.2429 896.2624 830.7148
## [18,] 792.3552 750.7065 713.4160 702.4390 814.9879 738.1243 764.5461
## [19,] 877.1512 708.7832 794.2238 767.9896 820.4353 888.3139 793.9683
## [20,] 628.2958 765.5055 819.2150 666.4196 809.0813 893.6092 999.0721
##
## iterations [,22] [,23] [,24] [,25] [,26] [,27] [,28]
## [1,] 800.6874 886.0923 849.6370 917.4460 866.7207 840.9107 846.2594
## [2,] 718.6075 852.3272 969.4953 1022.4456 939.7496 987.8303 1049.1033
## [3,] 991.6738 1014.9260 729.0990 999.3453 909.5472 905.7015 965.9020
## [4,] 940.5345 605.9063 910.9429 959.5390 716.6045 1125.1973 948.5425
## [5,] 825.6660 965.4303 860.6505 913.6277 1025.7633 1019.5499 944.0361
## [6,] 918.2497 820.3603 853.6841 873.9950 742.7963 995.9853 851.9542
## [7,] 814.7578 883.2623 914.4702 936.6025 779.7538 984.0825 1012.4141
## [8,] 776.0786 826.6984 888.8652 833.2019 926.9280 950.0061 1091.4938
## [9,] 843.8269 960.3566 810.9016 801.1478 804.5326 1031.6536 1037.0339
## [10,] 753.9680 751.7611 922.0987 1009.6970 825.3445 893.8568 1020.4597
## [11,] 699.5343 885.6054 940.8069 722.9219 1068.1142 959.8222 1040.9228
## [12,] 863.0833 811.8399 893.4901 841.8906 881.5025 954.8141 1004.2856
## [13,] 898.3977 816.0073 930.8448 923.1194 870.4447 1004.3028 870.3332
## [14,] 818.1159 996.8225 838.2550 842.6895 905.5167 1018.3691 1083.6420
## [15,] 935.4024 807.3129 991.6398 772.0339 955.5630 898.0924 1080.0350
## [16,] 760.6484 794.6000 875.0895 1022.0439 1004.2797 955.5873 831.1263
## [17,] 724.9403 902.7757 831.0406 974.2944 1114.6682 990.2097 1042.6495
## [18,] 880.0326 756.5383 916.2503 786.4850 710.6524 914.7172 1009.2806
## [19,] 841.2719 833.1613 909.7609 850.8728 883.9677 1005.0176 977.4049
## [20,] 953.8126 776.7893 862.6911 830.0702 843.2086 950.0937 999.4058
##
## iterations [,29] [,30] [,31] [,32] [,33] [,34]
## [1,] 912.6560 1201.7413 958.3079 1163.4697 1020.4617 949.6320
## [2,] 894.8990 969.2003 903.3676 1023.1478 1037.9467 1136.3692
## [3,] 1083.7358 982.5188 1041.1794 748.0982 1007.1747 981.2476
## [4,] 1002.9557 949.2275 994.6867 1146.5347 1081.1884 1060.7675
## [5,] 1050.5162 964.5730 1014.3345 1062.2732 1167.1582 1150.3578
## [6,] 971.3147 1025.1822 939.6040 1118.9930 1179.1470 1171.9391
## [7,] 1034.5410 1063.9443 1141.4158 1187.8655 1128.4380 1111.7770
## [8,] 1063.0011 1003.7846 969.3867 941.9058 1096.8417 1014.9422
## [9,] 922.6610 968.8916 910.9271 1041.9371 955.0561 1001.0377
## [10,] 870.8395 1011.8294 939.2495 1101.5663 1244.1355 1100.6073
## [11,] 1058.6339 931.9229 930.9247 1165.4494 1020.8268 1091.8252
## [12,] 999.9197 973.6207 1004.8199 1004.2641 1106.5297 1099.3658
## [13,] 930.5405 982.4316 1022.2602 1139.5438 1152.2025 1008.7925
## [14,] 974.9086 975.1867 1022.0142 1088.5308 1222.0376 1048.9858
## [15,] 981.7409 878.2633 1023.7982 920.0489 927.5012 998.6094
## [16,] 997.2168 1078.0214 920.6720 1093.9006 923.3353 1233.5923
## [17,] 948.9044 1163.0407 1029.8391 1041.2351 1045.0430 1188.6281
## [18,] 1028.1171 947.5524 1072.2639 948.5490 1106.4383 1149.0600
## [19,] 1092.6323 939.1729 1005.8183 1115.0054 1057.8143 1182.0997
## [20,] 1007.2728 1071.3887 1003.9543 977.8216 1088.5528 1211.9842
##
## iterations [,35] [,36] [,37] [,38]
## [1,] 1174.4669 1210.2661 1063.009 1266.0082
## [2,] 1035.0309 1199.3672 1224.920 1085.3430
## [3,] 958.5278 1080.5149 1140.509 1251.8435
## [4,] 1205.5529 1378.3699 1245.887 1399.2663
## [5,] 1300.0677 1028.6940 1212.456 1323.7375
## [6,] 1185.9207 953.2947 1015.985 1145.6515
## [7,] 1178.2535 1260.9886 1256.103 1149.3105
## [8,] 1352.3564 1121.3138 1184.295 1099.1386
## [9,] 1054.9287 1052.3269 1090.964 1297.6090
## [10,] 1151.5935 1164.1439 1140.220 977.0392
## [11,] 1006.3409 1206.1782 1129.598 1297.3465
## [12,] 1049.2263 1118.0162 1136.372 1097.9544
## [13,] 1272.0999 1102.8730 1152.075 1260.3388
## [14,] 1170.4072 1193.7097 1266.998 1184.7448
## [15,] 1070.6106 1120.3130 1213.733 1253.5098
## [16,] 1091.5215 1148.6841 1050.904 1034.6988
## [17,] 1125.0222 1092.6740 1300.564 1266.2335
## [18,] 1017.0880 1135.8451 1183.040 1086.5221
## [19,] 1266.0713 1080.8330 1177.136 1180.9905
## [20,] 1203.8639 969.3366 1066.955 1183.2624
上記の通り、各年齢ごと(列として)にMCMCサンプルが得られている列ごとにパーセンタイルを計算すれば、信用区間を算出できる。転置まで含めてpurrr::map_dfr()
関数で書き直すことは可能ではあるが、apply()
関数のほうが高速(たぶん)なので、書籍の通りapply()
関数を使用する。
# purrr::map_dfr(.x = ms$y_new, .f = function(x){quantile(x, probs = c(0.025, 0.25, 0.50, 0.75, 0.975))})
qua <- apply(ms$y_new, 2, quantile, probs=c(0.025, 0.25, 0.50, 0.75, 0.975))
d_est <- data.frame(X = X_new, t(qua), check.names = FALSE)
d_est
## X 2.5% 25% 50% 75% 97.5%
## 1 23 196.1084 323.1401 384.6072 445.1679 573.2650
## 2 24 209.3494 344.6997 406.7865 466.8781 595.2174
## 3 25 249.0284 368.2070 429.4956 487.5158 610.6214
## 4 26 267.2860 391.1700 449.5822 509.6548 632.7679
## 5 27 290.4473 415.0021 472.8636 529.8990 649.3803
## 6 28 310.9410 435.1684 493.3632 551.4630 669.3591
## 7 29 337.0813 455.0606 515.5190 573.3125 696.2371
## 8 30 356.8058 480.3806 539.2433 596.6382 726.6229
## 9 31 386.7507 500.3823 559.1365 621.7323 739.1078
## 10 32 394.5249 520.7448 579.8658 638.7421 756.0088
## 11 33 428.9380 549.8678 605.4362 660.7895 781.1420
## 12 34 450.6157 568.5087 624.6334 683.0308 802.7096
## 13 35 464.7578 587.1189 645.6336 705.6502 817.6153
## 14 36 487.8904 611.4111 668.3740 726.9089 850.6422
## 15 37 510.1696 631.2036 690.1135 748.4581 866.9594
## 16 38 538.2495 654.1524 710.8309 767.6460 875.4960
## 17 39 555.7743 676.9001 736.7438 794.9146 908.6932
## 18 40 585.5104 697.7168 756.2129 812.3078 937.3261
## 19 41 603.6863 721.5498 777.3452 836.3449 947.1644
## 20 42 620.7144 743.6247 798.5522 856.9957 980.5720
## 21 43 648.7306 764.6278 820.3401 876.5976 990.7906
## 22 44 665.7280 787.8738 844.9750 900.9996 1019.8217
## 23 45 690.5057 807.4399 865.7273 920.3477 1041.9983
## 24 46 713.4799 831.1948 887.1316 945.1339 1063.1338
## 25 47 739.5763 850.5364 906.3070 964.1591 1088.2991
## 26 48 757.6098 871.7799 929.3356 987.4714 1101.9738
## 27 49 771.4245 895.3892 954.8782 1012.8636 1129.8147
## 28 50 802.9857 915.2792 973.6994 1031.5326 1149.6666
## 29 51 821.2559 939.2511 994.4451 1056.6653 1175.9194
## 30 52 838.0974 961.0024 1017.2175 1074.3876 1188.3268
## 31 53 864.9028 985.3825 1040.7734 1096.5392 1215.0917
## 32 54 883.6646 1006.6777 1061.7447 1121.2343 1241.2938
## 33 55 892.1116 1027.0763 1083.3382 1142.7502 1265.4592
## 34 56 925.8846 1047.7235 1107.7652 1165.5273 1278.8964
## 35 57 945.3126 1069.6463 1131.6150 1189.7430 1301.0336
## 36 58 974.1271 1089.4398 1150.5051 1208.2641 1325.1875
## 37 59 987.6673 1110.8871 1171.6379 1232.5634 1347.5397
## 38 60 1010.9454 1135.8988 1196.5153 1256.0369 1374.8607
あとはこのデータをもとに作図すれば予測区間つきの散布図を作ることができる。
ggplot() +
theme_bw(base_size = 15) +
geom_ribbon(data = d_est, aes(x = X, ymin = `2.5%`, ymax = `97.5%`), fill = 'black', alpha = 1/6) +
geom_ribbon(data = d_est, aes(x = X, ymin = `25%` , ymax = `75%`), fill = 'black', alpha = 2/6) +
geom_line(data = d_est, aes(x = X, y = `50%`), size = 1) +
geom_point(data = d, aes(x = X, y = Y), shape = 1, size = 3) +
coord_cartesian(xlim = c(22, 61), ylim = c(200, 1400)) +
scale_y_continuous(breaks = seq(from = 200, to = 1400, by = 400)) +
labs(y = 'Y', title = 'Prediction intervals')
事前予測分布(PRIOR predictive distribution)と事後予測分布(POSTERIOR predictive distribution)への理解を深める。事後予測分布は下記の通り定義される。
\[ \begin{eqnarray} p_{pred}(y|Y) = \int p(y|\theta) p(\theta|Y) d\theta \\ \end{eqnarray} \]
事後予測分布は以下のようにサンプリングできる。
n_samples <- 10000
# binom
n <- 10
x <- 6
# beta
a <- 2
b <- 2
# Draw sample from the posterior
posterior <- rbeta(n = n_samples, shape1 = x + a, shape2 = n - x + b)
# Generate data based on the prior samples
posterior_predictive <- rbinom(n_samples, size = n, prob = posterior)
ggplot(data.frame(posterior_predictive), aes(x = posterior_predictive)) +
geom_histogram(binwidth = 1) +
scale_x_continuous(breaks = c(0:n)) +
labs(x = 'y', title = 'Posterior predictive distribution') +
theme_bw()
事前予測分布はデータの周辺分布(marginal distribution of data)とも呼ばれ、下記の通り定義される。
\[ \begin{eqnarray} p(y) = \int p(\theta)p(y | \theta) d\theta \end{eqnarray} \]
\(p(y)\)は、事前分布\(p(\theta)\)からサンプルを抽出し、このサンプルを用いて\(p(y|\theta)\)を生成することでシュミレートできる。
a <- 2
b <- 2
# Draw from the prior
prior_samples <- rbeta(n = n_samples, shape1 = a, shape2 = b)
# Generate data based on the prior samples
y <- rbinom(n_samples, size = 100, prob = prior_samples)
ggplot() +
geom_histogram(data = data.frame(y),
aes(x = y),
binwidth = 1)