UPDATE: 2023-12-14 17:08:08.099969
このノートは「StanとRでベイズ統計モデリング」の内容を写経することで、ベイズ統計への理解を深めていくために作成している。
基本的には気になった部分を写経しながら、ところどころ自分用の補足をメモすることで、「StanとRでベイズ統計モデリング」を読み進めるための自分用の補足資料になることを目指す。私の解釈がおかしく、メモが誤っている場合があるので注意。
今回は第5章「基本的な回帰とモデルのチェック」のチャプターの前半を写経していく。
第5章では重回帰分析から一般化線形モデルを扱う章となっている。まずは重回帰分析から前回同様、気になった点をメモしていく。この章でも複数の等価なモデル式が下記の通り、紹介されている。どのモデルであっても、データ1人ごとに平均\(\mu[n] = b_{1} + b_{2}A[n] + b_{3}Score[n]\)、標準偏差\(\sigma\)の正規分布から独立に生成されていると仮定している。
\[ \begin{eqnarray} Y[n] &=& b_{1} + b_{2}A[n] + b_{3}Score[n] + \epsilon[n] \\ \epsilon[n] &\sim& Normal(0, \sigma) \end{eqnarray} \]
\[ \begin{eqnarray} Y[n] &=& Normal(b_{1} + b_{2}A[n] + b_{3}Score[n], \sigma) \end{eqnarray} \]
以降のページでは下記のモデル式が採用される。下記のモデル式はtransformed parameters
ブロックで線形結合の部分を\(\mu[n]\)に変換する必要がある。つまり、事後分布として\(\mu[n]\)が得られるということでもある。
\[ \begin{eqnarray} \mu[n] &=& b_{1} + b_{2}A[n] + b_{3}Score[n] \\ Y[n] &\sim& Normal(\mu[n], \sigma) \end{eqnarray} \]
モデル5-3をStanで実装した例は下記の通り。
data {
int N;
int<lower=0, upper=1> A[N];
real<lower=0, upper=1> Score[N];
real<lower=0, upper=1> Y[N];
}
parameters {
real b1;
real b2;
real b3;
real<lower=0> sigma;
}
transformed parameters {
real mu[N];
for (n in 1:N)
mu[n] = b1 + b2*A[n] + b3*Score[n];
}
model {
for (n in 1:N)
Y[n] ~ normal(mu[n], sigma);
}
generated quantities {
real y_pred[N];
real noise[N];
for (n in 1:N){
y_pred[n] = normal_rng(mu[n], sigma);
noise[n] = Y[n] - mu[n];
}
}
パラパラと読み進めていて、気になったのは下記の部分。
横軸に「n番目の学生における実測値Y[n]」、縦軸に「その学生の説明変数の値から算出されるYの予測分布の中央値と区間」をとった図を書くことにした(後述の図5.3)。そこで、Yの予測分布(からのMCMCサンプル)を得るために、generated quantitiesブロック平均mu[n]、標準偏差sigmaの正規分布からの乱数を発生させてy_predp[n]に代入している。
機械学習で予測モデルを構築した際に、汎化性能を可視化する方法の1つとして、観測値と予測値の散布図を作ることはある。ただ、この説明を最初読んだときに、ベイズの世界への理解の甘さから「\(Y\)の予測分布の中央値と区間」の部分の理解が曖昧だった。しかし、よくよく考えると、ベイズの世界では、generated quantities
ブロックを使うことで、学生ごとの出席率(\(Y\))の事後分布を得ることができる。つまり、\(Y\)の予測分布の中央値と区間が計算できるため、機械学習で予測モデルの汎化性能を表現する以上の情報(\(Y\)軸において)を持つ作図が可能となる。
(この段落の挙動に関する内容は推測である)generated quantities
ブロックでは、まずmu[n]
が学生数\(N\)の数だけ計算される。例えばmu[1]
であっても数千個のMCMCサンプルとして得られているので、その数千個の値を持つmu[1]
ベクトルがnormal_rng(mu[1], sigma)
として渡される。そして、数千個の値を持つy_pred[1]
が得られる。これが学生の数\(N\)分、ループして処理され、y_pred[n]
を\(N\)個得ることになる。
generated quantities {
real y_pred[N];
real noise[N];
for (n in 1:N){
y_pred[n] = normal_rng(mu[n], sigma);
noise[n] = Y[n] - mu[n];
}
}
Stanモデルの定義も終わっているので、書籍内で使用されているデータを利用して分析を実行する。
# パラメタの事後分布をすべて表示するため
options(max.print = 999999)
library(dplyr)
library(ggplot2)
library(GGally)
library(hexbin)
library(rstan)
d <- read.csv('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap05/input/data-attendance-1.txt')
data <- list(N = nrow(d), A = d$A, Score = d$Score/200, Y = d$Y)
head(d)
## A Score Y
## 1 0 69 0.286
## 2 1 145 0.196
## 3 0 125 0.261
## 4 1 86 0.109
## 5 1 158 0.230
## 6 0 133 0.350
ここでは、stan_model()
関数で最初にコンパイルしておいてから、
sampling()
関数でサンプリングする。
推定結果は下記の通り。
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## b1 0.12 0.00 0.03 0.06 0.10 0.12 0.15 0.19 1716 1
## b2 -0.14 0.00 0.01 -0.17 -0.15 -0.14 -0.13 -0.12 2622 1
## b3 0.32 0.00 0.05 0.22 0.29 0.32 0.36 0.43 1761 1
## sigma 0.05 0.00 0.01 0.04 0.05 0.05 0.06 0.06 2676 1
## mu[1] 0.24 0.00 0.02 0.20 0.22 0.24 0.25 0.27 1817 1
## mu[2] 0.22 0.00 0.01 0.19 0.21 0.22 0.22 0.24 2648 1
## mu[3] 0.33 0.00 0.01 0.31 0.32 0.33 0.33 0.35 3241 1
## mu[4] 0.12 0.00 0.01 0.09 0.11 0.12 0.13 0.15 2454 1
## mu[5] 0.24 0.00 0.01 0.21 0.23 0.24 0.25 0.26 2360 1
## mu[6] 0.34 0.00 0.01 0.32 0.33 0.34 0.35 0.36 3178 1
## mu[7] 0.30 0.00 0.01 0.28 0.30 0.30 0.31 0.32 2821 1
## mu[8] 0.22 0.00 0.01 0.19 0.21 0.22 0.23 0.24 2597 1
## mu[9] 0.36 0.00 0.01 0.34 0.35 0.36 0.37 0.38 2801 1
## mu[10] 0.36 0.00 0.01 0.34 0.35 0.36 0.37 0.38 2832 1
## mu[11] 0.21 0.00 0.01 0.18 0.20 0.21 0.22 0.23 2755 1
## mu[12] 0.35 0.00 0.01 0.33 0.34 0.35 0.35 0.37 3077 1
## mu[13] 0.17 0.00 0.01 0.15 0.16 0.17 0.18 0.19 3123 1
## mu[14] 0.30 0.00 0.01 0.28 0.30 0.30 0.31 0.32 2821 1
## mu[15] 0.30 0.00 0.01 0.28 0.29 0.30 0.31 0.32 2739 1
## mu[16] 0.14 0.00 0.01 0.12 0.13 0.14 0.15 0.17 2746 1
## mu[17] 0.31 0.00 0.01 0.29 0.30 0.31 0.31 0.33 2944 1
## mu[18] 0.26 0.00 0.01 0.23 0.25 0.26 0.27 0.28 1910 1
## mu[19] 0.42 0.00 0.02 0.39 0.41 0.42 0.44 0.46 2140 1
## mu[20] 0.23 0.00 0.01 0.20 0.22 0.23 0.24 0.26 2397 1
## mu[21] 0.12 0.00 0.01 0.09 0.11 0.12 0.13 0.15 2454 1
## mu[22] 0.16 0.00 0.01 0.13 0.15 0.16 0.16 0.18 2969 1
## mu[23] 0.15 0.00 0.01 0.13 0.14 0.15 0.16 0.18 2922 1
## mu[24] 0.21 0.00 0.01 0.19 0.20 0.21 0.22 0.24 2701 1
## mu[25] 0.17 0.00 0.01 0.15 0.16 0.17 0.18 0.19 3116 1
## mu[26] 0.19 0.00 0.01 0.16 0.18 0.19 0.20 0.21 3064 1
## mu[27] 0.32 0.00 0.01 0.30 0.31 0.32 0.32 0.34 3123 1
## mu[28] 0.32 0.00 0.01 0.30 0.31 0.32 0.32 0.34 3123 1
## mu[29] 0.38 0.00 0.01 0.36 0.38 0.38 0.39 0.41 2434 1
## mu[30] 0.31 0.00 0.01 0.29 0.30 0.31 0.31 0.33 2903 1
## mu[31] 0.25 0.00 0.02 0.22 0.24 0.25 0.26 0.28 2236 1
## mu[32] 0.10 0.00 0.02 0.07 0.09 0.10 0.11 0.13 2260 1
## mu[33] 0.20 0.00 0.01 0.18 0.20 0.20 0.21 0.23 2836 1
## mu[34] 0.18 0.00 0.01 0.16 0.17 0.18 0.19 0.20 3115 1
## mu[35] 0.33 0.00 0.01 0.31 0.32 0.33 0.33 0.35 3247 1
## mu[36] 0.34 0.00 0.01 0.32 0.33 0.34 0.34 0.36 3215 1
## mu[37] 0.15 0.00 0.01 0.13 0.14 0.15 0.16 0.17 2898 1
## mu[38] 0.30 0.00 0.01 0.28 0.30 0.30 0.31 0.32 2780 1
## mu[39] 0.27 0.00 0.01 0.24 0.26 0.27 0.28 0.29 2011 1
## mu[40] 0.27 0.00 0.01 0.24 0.26 0.27 0.28 0.29 1990 1
## mu[41] 0.33 0.00 0.01 0.31 0.33 0.33 0.34 0.35 3239 1
## mu[42] 0.34 0.00 0.01 0.32 0.33 0.34 0.35 0.36 3178 1
## mu[43] 0.32 0.00 0.01 0.30 0.32 0.32 0.33 0.34 3199 1
## mu[44] 0.36 0.00 0.01 0.34 0.36 0.36 0.37 0.39 2742 1
## mu[45] 0.42 0.00 0.02 0.38 0.41 0.42 0.43 0.46 2164 1
## mu[46] 0.29 0.00 0.01 0.27 0.29 0.29 0.30 0.31 2583 1
## mu[47] 0.21 0.00 0.02 0.17 0.19 0.21 0.22 0.25 1755 1
## mu[48] 0.37 0.00 0.01 0.34 0.36 0.37 0.38 0.39 2686 1
## mu[49] 0.28 0.00 0.01 0.26 0.28 0.28 0.29 0.31 2302 1
## mu[50] 0.14 0.00 0.01 0.12 0.13 0.14 0.15 0.17 2746 1
## y_pred[1] 0.24 0.00 0.05 0.12 0.20 0.24 0.27 0.34 3519 1
## y_pred[2] 0.22 0.00 0.05 0.11 0.18 0.22 0.25 0.32 4060 1
## y_pred[3] 0.32 0.00 0.05 0.22 0.29 0.33 0.36 0.42 3702 1
## y_pred[4] 0.12 0.00 0.05 0.01 0.08 0.12 0.16 0.22 3725 1
## y_pred[5] 0.24 0.00 0.05 0.13 0.20 0.24 0.27 0.35 3644 1
## y_pred[6] 0.34 0.00 0.05 0.23 0.30 0.34 0.37 0.44 3896 1
## y_pred[7] 0.30 0.00 0.05 0.19 0.27 0.30 0.34 0.41 3664 1
## y_pred[8] 0.22 0.00 0.05 0.11 0.18 0.22 0.25 0.32 3945 1
## y_pred[9] 0.36 0.00 0.05 0.25 0.33 0.36 0.40 0.46 3709 1
## y_pred[10] 0.36 0.00 0.05 0.26 0.32 0.36 0.39 0.46 3689 1
## y_pred[11] 0.21 0.00 0.05 0.11 0.17 0.21 0.24 0.32 3984 1
## y_pred[12] 0.35 0.00 0.05 0.25 0.31 0.35 0.38 0.45 4118 1
## y_pred[13] 0.17 0.00 0.05 0.07 0.14 0.17 0.21 0.28 3800 1
## y_pred[14] 0.30 0.00 0.05 0.20 0.27 0.30 0.34 0.41 3990 1
## y_pred[15] 0.30 0.00 0.05 0.20 0.26 0.30 0.34 0.40 3736 1
## y_pred[16] 0.14 0.00 0.05 0.04 0.10 0.14 0.18 0.24 3674 1
## y_pred[17] 0.31 0.00 0.05 0.21 0.27 0.31 0.34 0.41 3677 1
## y_pred[18] 0.26 0.00 0.05 0.15 0.22 0.26 0.29 0.36 3828 1
## y_pred[19] 0.42 0.00 0.05 0.32 0.39 0.42 0.46 0.53 3476 1
## y_pred[20] 0.23 0.00 0.05 0.13 0.20 0.23 0.27 0.34 3945 1
## y_pred[21] 0.12 0.00 0.05 0.01 0.08 0.12 0.16 0.23 3673 1
## y_pred[22] 0.15 0.00 0.05 0.05 0.12 0.15 0.19 0.26 4021 1
## y_pred[23] 0.15 0.00 0.05 0.05 0.12 0.15 0.19 0.26 3983 1
## y_pred[24] 0.21 0.00 0.05 0.11 0.18 0.21 0.25 0.31 3497 1
## y_pred[25] 0.17 0.00 0.05 0.07 0.14 0.17 0.20 0.27 3985 1
## y_pred[26] 0.19 0.00 0.05 0.08 0.15 0.19 0.22 0.29 4072 1
## y_pred[27] 0.32 0.00 0.05 0.22 0.28 0.32 0.35 0.42 3904 1
## y_pred[28] 0.32 0.00 0.05 0.21 0.28 0.32 0.35 0.42 3882 1
## y_pred[29] 0.38 0.00 0.05 0.28 0.35 0.38 0.42 0.49 3850 1
## y_pred[30] 0.31 0.00 0.05 0.20 0.27 0.31 0.34 0.41 4086 1
## y_pred[31] 0.25 0.00 0.05 0.14 0.21 0.25 0.28 0.35 3539 1
## y_pred[32] 0.10 0.00 0.05 -0.01 0.06 0.10 0.13 0.21 3572 1
## y_pred[33] 0.20 0.00 0.05 0.10 0.17 0.20 0.24 0.31 3933 1
## y_pred[34] 0.18 0.00 0.05 0.08 0.15 0.18 0.22 0.29 3942 1
## y_pred[35] 0.33 0.00 0.05 0.22 0.29 0.33 0.36 0.43 3803 1
## y_pred[36] 0.34 0.00 0.05 0.23 0.30 0.34 0.37 0.44 4080 1
## y_pred[37] 0.15 0.00 0.05 0.04 0.11 0.15 0.19 0.25 3646 1
## y_pred[38] 0.30 0.00 0.05 0.20 0.26 0.30 0.34 0.41 4003 1
## y_pred[39] 0.27 0.00 0.05 0.16 0.23 0.27 0.30 0.37 3719 1
## y_pred[40] 0.27 0.00 0.05 0.16 0.23 0.27 0.30 0.37 3488 1
## y_pred[41] 0.33 0.00 0.05 0.23 0.30 0.33 0.37 0.44 4034 1
## y_pred[42] 0.34 0.00 0.05 0.23 0.30 0.34 0.37 0.44 3893 1
## y_pred[43] 0.32 0.00 0.05 0.22 0.29 0.32 0.36 0.42 4199 1
## y_pred[44] 0.36 0.00 0.05 0.26 0.33 0.36 0.40 0.47 4009 1
## y_pred[45] 0.42 0.00 0.05 0.31 0.39 0.42 0.46 0.53 3440 1
## y_pred[46] 0.29 0.00 0.05 0.19 0.26 0.29 0.33 0.40 3818 1
## y_pred[47] 0.21 0.00 0.05 0.10 0.17 0.21 0.24 0.31 3578 1
## y_pred[48] 0.37 0.00 0.05 0.26 0.33 0.37 0.40 0.47 4039 1
## y_pred[49] 0.28 0.00 0.05 0.18 0.25 0.28 0.32 0.39 3901 1
## y_pred[50] 0.14 0.00 0.05 0.04 0.10 0.14 0.18 0.25 3893 1
## noise[1] 0.05 0.00 0.02 0.02 0.04 0.05 0.06 0.08 1817 1
## noise[2] -0.02 0.00 0.01 -0.04 -0.03 -0.02 -0.01 0.01 2648 1
## noise[3] -0.07 0.00 0.01 -0.08 -0.07 -0.07 -0.06 -0.05 3241 1
## noise[4] -0.01 0.00 0.01 -0.04 -0.02 -0.01 0.00 0.02 2454 1
## noise[5] -0.01 0.00 0.01 -0.03 -0.02 -0.01 0.00 0.02 2360 1
## noise[6] 0.01 0.00 0.01 -0.01 0.00 0.01 0.02 0.03 3178 1
## noise[7] 0.03 0.00 0.01 0.01 0.02 0.03 0.03 0.05 2821 1
## noise[8] -0.02 0.00 0.01 -0.05 -0.03 -0.02 -0.02 0.00 2597 1
## noise[9] 0.05 0.00 0.01 0.03 0.04 0.05 0.06 0.07 2801 1
## noise[10] 0.00 0.00 0.01 -0.02 -0.01 0.00 0.01 0.02 2832 1
## noise[11] 0.02 0.00 0.01 -0.01 0.01 0.02 0.02 0.04 2755 1
## noise[12] 0.08 0.00 0.01 0.06 0.07 0.08 0.08 0.10 3077 1
## noise[13] 0.01 0.00 0.01 -0.01 0.01 0.01 0.02 0.04 3123 1
## noise[14] -0.02 0.00 0.01 -0.04 -0.02 -0.02 -0.01 0.00 2821 1
## noise[15] 0.07 0.00 0.01 0.05 0.06 0.07 0.08 0.09 2739 1
## noise[16] 0.04 0.00 0.01 0.02 0.03 0.04 0.05 0.07 2746 1
## noise[17] 0.04 0.00 0.01 0.02 0.03 0.04 0.04 0.05 2944 1
## noise[18] 0.01 0.00 0.01 -0.02 0.00 0.01 0.02 0.03 1910 1
## noise[19] -0.01 0.00 0.02 -0.05 -0.02 -0.01 0.00 0.03 2140 1
## noise[20] 0.08 0.00 0.01 0.05 0.07 0.08 0.09 0.11 2397 1
## noise[21] -0.06 0.00 0.01 -0.09 -0.07 -0.06 -0.05 -0.03 2454 1
## noise[22] 0.12 0.00 0.01 0.09 0.11 0.12 0.12 0.14 2969 1
## noise[23] 0.08 0.00 0.01 0.05 0.07 0.08 0.09 0.10 2922 1
## noise[24] -0.02 0.00 0.01 -0.04 -0.03 -0.02 -0.01 0.01 2701 1
## noise[25] -0.07 0.00 0.01 -0.10 -0.08 -0.07 -0.07 -0.05 3116 1
## noise[26] -0.05 0.00 0.01 -0.07 -0.06 -0.05 -0.04 -0.03 3064 1
## noise[27] 0.02 0.00 0.01 0.00 0.01 0.02 0.03 0.04 3123 1
## noise[28] -0.01 0.00 0.01 -0.03 -0.02 -0.01 0.00 0.01 3123 1
## noise[29] 0.00 0.00 0.01 -0.03 -0.01 0.00 0.01 0.03 2434 1
## noise[30] -0.05 0.00 0.01 -0.07 -0.06 -0.05 -0.05 -0.04 2903 1
## noise[31] 0.00 0.00 0.02 -0.03 -0.01 0.00 0.01 0.03 2236 1
## noise[32] -0.07 0.00 0.02 -0.11 -0.08 -0.07 -0.06 -0.04 2260 1
## noise[33] -0.10 0.00 0.01 -0.13 -0.11 -0.10 -0.10 -0.08 2836 1
## noise[34] -0.03 0.00 0.01 -0.05 -0.04 -0.03 -0.02 -0.01 3115 1
## noise[35] 0.06 0.00 0.01 0.04 0.06 0.06 0.07 0.08 3247 1
## noise[36] -0.03 0.00 0.01 -0.04 -0.03 -0.03 -0.02 -0.01 3215 1
## noise[37] 0.02 0.00 0.01 0.00 0.01 0.02 0.03 0.04 2898 1
## noise[38] -0.10 0.00 0.01 -0.11 -0.10 -0.10 -0.09 -0.08 2780 1
## noise[39] -0.01 0.00 0.01 -0.03 -0.02 -0.01 0.00 0.02 2011 1
## noise[40] 0.03 0.00 0.01 0.00 0.02 0.03 0.04 0.05 1990 1
## noise[41] 0.03 0.00 0.01 0.01 0.02 0.03 0.03 0.05 3239 1
## noise[42] 0.02 0.00 0.01 0.00 0.01 0.02 0.02 0.04 3178 1
## noise[43] -0.03 0.00 0.01 -0.05 -0.04 -0.03 -0.03 -0.01 3199 1
## noise[44] -0.04 0.00 0.01 -0.06 -0.04 -0.04 -0.03 -0.01 2742 1
## noise[45] -0.04 0.00 0.02 -0.07 -0.05 -0.04 -0.02 0.00 2164 1
## noise[46] 0.01 0.00 0.01 -0.01 0.00 0.01 0.01 0.03 2583 1
## noise[47] -0.07 0.00 0.02 -0.11 -0.09 -0.07 -0.06 -0.03 1755 1
## noise[48] -0.01 0.00 0.01 -0.04 -0.02 -0.01 -0.01 0.01 2686 1
## noise[49] -0.02 0.00 0.01 -0.04 -0.02 -0.02 -0.01 0.00 2302 1
## noise[50] 0.09 0.00 0.01 0.07 0.08 0.09 0.10 0.12 2746 1
## lp__ 120.89 0.03 1.39 117.41 120.19 121.17 121.93 122.70 1603 1
##
## Samples were drawn using NUTS(diag_e) at Thu Dec 14 17:08:34 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
モデルが計算できたので、下記の部分の作図を行う。
横軸に「n番目の学生における実測値Y[n]」、縦軸に「その学生の説明変数の値から算出されるYの予測分布の中央値と区間」をとった図を書くことにした(後述の図5.3)。そこで、Yの予測分布(からのMCMCサンプル)を得るために、generated quantitiesブロック平均mu[n]、標準偏差sigmaの正規分布からの乱数を発生させてy_predp[n]に代入している。
まずはMCMCサンプルを取り出して、学生ごとに信用区間を算出する。ms$y_pred
はMCMCサンプル4000行×学生50人分の行列形式で保存されているので、apply()
関数で集計して元のテーブルに情報を紐付ける。
ms <- rstan::extract(fit)
qua <- apply(ms$y_pred, 2, quantile, prob = c(0.1, 0.5, 0.9))
d_est <- data.frame(d, t(qua), check.names = FALSE)
d_est$A <- as.factor(d_est$A)
dim(ms$y_pred)
## [1] 4000 50
学生ごとに横軸に実測値、縦軸に予測分布の中央値と区間を可視化すれば完成。
ggplot(data = d_est, aes(x = Y, y = `50%`, ymin = `10%`, ymax = `90%`, shape = A, fill = A)) +
theme_bw(base_size = 18) + theme(legend.key.height = grid::unit(2.5, 'line')) +
coord_fixed(ratio = 1, xlim = c(0, 0.5), ylim = c(0, 0.5)) +
geom_pointrange(size = 0.5, color = 'grey5') +
geom_abline(aes(slope = 1, intercept = 0), col = 'black', alpha = 3/5, linetype = '31') +
scale_shape_manual(values = c(21, 24)) +
scale_fill_manual(values = c('white', 'grey70')) +
labs(x = 'Observed', y = 'Predicted') +
scale_x_continuous(breaks = seq(from = 0, to = 0.5, by = 0.1)) +
scale_y_continuous(breaks = seq(from = 0, to = 0.5, by = 0.1))
次は推定されたノイズの分布を可視化する。Stanのモデルの中で、下記の通りノイズは予め計算している。Y[n] - mu[n]
は、ぱっと見違和感があるが、Y[1]
は観測されたスカラ値でmu[1]
は分布から、差分を計算しているだけ。
generated quantities {
real y_pred[N];
real noise[N];
for (n in 1:N){
y_pred[n] = normal_rng(mu[n], sigma);
noise[n] = Y[n] - mu[n];
}
}
まずは学生50ごとに得られているノイズを可視化できるように前処理を行う。
# noiseの計算はgenerated quantitiesで行わない場合は下記でもOK
# t(-t(ms$mu) + d$Y)
# t(replicate(N_mcmc, d$Y)) - ms$mu
N_mcmc <- length(ms$lp__)
noise_mcmc <- ms$noise
d_est <- data.frame(noise_mcmc, check.names=FALSE) %>%
tidyr::pivot_longer(cols = everything(), names_to = 'Parameter') %>%
mutate(PersonID = readr::parse_number(Parameter))
list(
head = d_est %>% arrange(PersonID) %>% head(10),
tail = d_est %>% arrange(PersonID) %>% tail(10)
)
## $head
## # A tibble: 10 × 3
## Parameter value PersonID
## <chr> <dbl> <dbl>
## 1 1 0.0451 1
## 2 1 0.0465 1
## 3 1 0.0569 1
## 4 1 0.0548 1
## 5 1 0.0317 1
## 6 1 0.0417 1
## 7 1 0.0494 1
## 8 1 0.0440 1
## 9 1 0.0928 1
## 10 1 0.00902 1
##
## $tail
## # A tibble: 10 × 3
## Parameter value PersonID
## <chr> <dbl> <dbl>
## 1 50 0.0999 50
## 2 50 0.0890 50
## 3 50 0.109 50
## 4 50 0.104 50
## 5 50 0.0892 50
## 6 50 0.0954 50
## 7 50 0.0857 50
## 8 50 0.0909 50
## 9 50 0.110 50
## 10 50 0.116 50
次は、図に書き足す分布のMAP推定値を計算する。列(=学生の予測分布)ごとにカーネル密度推定を行い、密度が最大となる(x,y)
座標を計算している。
# MAP推定
d_mode <- apply(noise_mcmc, 2, function(x) {
dens <- density(x)
mode_i <- which.max(dens$y)
mode_x <- dens$x[mode_i]
mode_y <- dens$y[mode_i]
c(mode_x, mode_y)
}) %>%
t() %>%
data.frame() %>%
magrittr::set_colnames(c('X', 'Y'))
準備ができたので作図すれば完成。1つの分布が学生ごとの予測分布から計算されるノイズの分布となっている。
ggplot() +
theme_bw(base_size = 18) +
geom_line(data = d_est, aes(x = value, group = PersonID), stat = 'density', col = 'black', alpha = 0.4) +
geom_segment(data = d_mode, aes(x = X, xend = X, y = Y, yend = 0), col = 'black', linetype = 'dashed', alpha = 0.4) +
geom_rug(data = d_mode, aes(x = X), sides = 'b') +
labs(x = 'value', y = 'density')
書籍でも指摘されている通り、これは少し見づらいので、
noise[n]
の分布からMAP推定値を計算し、それらの分布をみることで、\(Normal(0,
\sigma)\)の分布と比較する方法も記載されている。
# ms$sigmaはms$sでも取れるみたい。
s_dens <- density(ms$s)
# s_dens$yの差大値をとる添字に対応するs_dens$xを標準偏差の代表として正規分布を計算
s_MAP <- s_dens$x[which.max(s_dens$y)]
bw <- 0.01
ggplot(data = d_mode, aes(x = X)) +
theme_bw(base_size = 18) +
geom_histogram(binwidth = bw, col = 'black', fill = 'white') +
geom_density(eval(bquote(aes(y = ..count..*.(bw)))), alpha = 0.5, col = 'black', fill = 'gray20') +
geom_rug(sides = 'b') +
stat_function(fun = function(x) nrow(d)*bw*dnorm(x, mean = 0, sd = s_MAP), linetype = 'dashed') +
labs(x = 'value', y = 'density') +
xlim(range(density(d_mode$X)$x))
この流れで、MCMCサンプルの散布図行列を書くことも推奨されている。散布図行列を確認することで、パラメタ間の関係性や各学生ごとのmu[n]
とパラメタの関係なども理解できる。
d <- data.frame(b1 = ms$b1, b2 = ms$b2, b3 = ms$b3, sigma = ms$sigma, `mu[1]` = ms$mu[,1], `mu[50]` = ms$mu[,50], lp__ = ms$lp__, check.names = FALSE)
N_col <- ncol(d)
ggp <- ggpairs(d, upper = 'blank', diag = 'blank', lower = 'blank')
for (i in 1:N_col) {
x <- d[,i]
bw <- (max(x)-min(x))/10
p <- ggplot(data.frame(x), aes(x)) +
theme_bw(base_size = 14) +
theme(axis.text.x = element_text(angle = 60, vjust = 1, hjust = 1)) +
geom_histogram(binwidth = bw, fill = 'white', color = 'grey5') +
geom_line(eval(bquote(aes(y = ..count..*.(bw)))), stat = 'density') +
geom_label(data = data.frame(x = -Inf, y = Inf, label = colnames(d)[i]), aes(x = x, y = y, label = label), hjust = 0, vjust = 1)
ggp <- putPlot(ggp, p, i, i)
}
zcolat <- seq(-1, 1, length = 81)
zcolre <- c(zcolat[1:40]+1, rev(zcolat[41:81]))
for (i in 1:(N_col-1)) {
for (j in (i+1):N_col) {
x <- as.numeric(d[,i])
y <- as.numeric(d[,j])
r <- cor(x, y, method = 'spearman', use = 'pairwise.complete.obs')
zcol <- lattice::level.colors(r, at = zcolat, col.regions = grey(zcolre))
textcol <- ifelse(abs(r) < 0.4, 'grey20', 'white')
ell <- ellipse::ellipse(r, level = 0.95, type = 'l', npoints = 50, scale = c(.2, .2), centre = c(.5, .5))
p <- ggplot(data.frame(ell), aes(x = x, y = y)) + theme_bw() + theme(
plot.background = element_blank(),
panel.grid.major = element_blank(), panel.grid.minor = element_blank(),
panel.border = element_blank(), axis.ticks = element_blank()) +
geom_polygon(fill = zcol, color = zcol) +
geom_text(data = NULL, x = .5, y = .5, label = 100*round(r, 2), size = 6, col = textcol)
ggp <- putPlot(ggp, p, i, j)
}
}
for (j in 1:(N_col-1)) {
for (i in (j+1):N_col) {
x <- d[,j]
y <- d[,i]
p <- ggplot(data.frame(x, y), aes(x = x, y = y)) +
theme_bw(base_size = 14) +
theme(axis.text.x = element_text(angle = 60, vjust = 1, hjust = 1)) +
geom_hex() +
scale_fill_gradientn(colours = gray.colors(7, start = 0.1, end = 0.9))
ggp <- putPlot(ggp, p, i, j)
}
}
print(ggp, left = 0.6, bottom = 0.6)