UPDATE: 2023-12-24 20:19:37.077966
このノートは「StanとRでベイズ統計モデリング」の内容を写経することで、ベイズ統計への理解を深めていくために作成している。
基本的には気になった部分を写経しながら、ところどころ自分用の補足をメモすることで、「StanとRでベイズ統計モデリング」を読み進めるための自分用の補足資料になることを目指す。私の解釈がおかしく、メモが誤っている場合があるので注意。
今回は第8章「階層モデル」のチャプターを写経していく。
この章で使用するデータの説明しておく。4業界\(GID\)の30企業\(KID\)の年齢\(X-23\)、年収\(Y\)が記録されている300人のデータを使用する。\(X\)から23が引かれているのは解釈がしやすくするため。
options(max.print = 999999)
library(dplyr)
library(ggplot2)
library(rstan)
d <- read.csv('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap08/input/data-salary-3.txt')
head(d, 10)
## X Y KID GID
## 1 7 457 1 1
## 2 10 482 1 1
## 3 16 518 1 1
## 4 25 535 1 1
## 5 5 427 1 1
## 6 25 603 1 1
## 7 26 610 1 1
## 8 18 484 1 1
## 9 17 508 1 1
## 10 1 380 1 1
ここでも分析の問題は、年功序列で賃金は上昇する傾向にあるが、業界や会社によって「新卒時点の基本年収」や「年齢の伴う昇給額は異なる」と考えられるため、その会社による業界差、会社差を検討したい。
可視化してみると、どの業界でも年齢に伴う昇給は確認できるが、新卒時点での基本年収や昇給額は業界によって異なることがわかる。
d$GID <- as.factor(d$GID)
res_lm <- lm(Y ~ X, data = d)
coef <- coef(res_lm)
ggplot(d, aes(X, Y, shape = GID)) +
theme_bw(base_size = 15) +
geom_abline(intercept = coef[1], slope = coef[2], linewidth = 2, alpha = 0.3) +
facet_wrap(~ GID) +
geom_line(stat = 'smooth', method = 'lm', se = FALSE) +
geom_point(size = 3, alpha = 0.8) +
scale_shape_manual(values = c(16, 2, 4)) +
labs(x = 'X', y = 'Y')
業界ごと、企業ごとで回帰係数\(a,b\)を推定したものを集めてヒストグラムにしたものが下記のグラフ。\(a\)に関しては業界3のばらつきは小さいが、業界1,2のばらつきが大きい。このグラフから業界3は他の業界1,2よりも平均は高そうで、ばらつきも小さいことが想定される。
KIDGID <- unique(d[,3:4])
N <- nrow(d)
K <- 30
G <- 3
coefs <- as.data.frame(t(sapply(1:K, function(k) {
d_sub <- subset(d, KID == k)
coef(lm(Y ~ X, data = d_sub))
})))
colnames(coefs) <- c('a', 'b')
d_plot <- data.frame(coefs, KIDGID)
d_plot$GID_label <- factor(paste0('GID = ', d_plot$GID), levels = paste0('GID = ', 1:3))
bw <- diff(range(d_plot$a))/20
ggplot(data = d_plot, aes(x = a)) +
theme_bw(base_size = 15) +
facet_wrap(~GID_label, nrow = 3) +
geom_histogram(binwidth = bw, color = 'black', fill = 'white') +
geom_density(aes(y = after_stat(count)*bw), alpha = 0.2, color = 'black', fill = 'gray20') +
geom_rug(sides = 'b') +
labs(x = 'a', y = 'count')
\(b\)に関しても業界3のばらつきは小さいが、業界1,2のばらつきが大きい。このグラフから業界3は他の業界1,2よりも昇給額の平均は低く、ばらつきも小さいことが想定される。
bw <- diff(range(d_plot$b))/20
ggplot(data = d_plot, aes(x = b)) +
theme_bw(base_size = 15) +
facet_wrap(~GID_label, nrow = 3) +
geom_histogram(binwidth = bw, color = 'black', fill = 'white') +
geom_density(aes(y = after_stat(count)*bw), alpha = 0.2, color = 'black', fill = 'gray20') +
geom_rug(sides = 'b') +
labs(x = 'b', y = 'count')
ここで想定しているモデルは「新卒年収」と「年収昇給額」の平均は業界ごとに異なるが、「新卒年収」と「年収昇給額」の会社差のばらつきは共通としている。
\(a\)について考えることで、モデルへの理解を深める。各業界の\(a_{業界平均}[g]\)を「すべての業界で共通の平均 」と「業界差」に分けて考える。つまり\(a_{業界平均}[g] = a_{全体平均} + a_{業界差}[g]\)であり、\(a_{業界差}[g]\)には平均0、標準偏差\(\sigma_{ag}\)の正規分布から生成されると考える。さらに\(\sigma_{ag}\)には無情報事前分布を設定する。
そして、各社の\(a[k]\)は会社の属する業界\(a_{業界平均}[g]\)を平均とする正規分布から生成されると考える。
ここで想定しているモデルは下記の通り。
\[ \begin{eqnarray} Y[n] &\sim& Normal(a[KID[n]] + b[KID[n]] X[n], \sigma_{Y}) \\ a_{業界平均}[g] &\sim& Normal(a_{全体平均}, \sigma_{ag}) \\ b_{業界平均}[g] &\sim& Normal(b_{全体平均}, \sigma_{bg}) \\ a[k] &\sim& Normal(a_{業界平均}[K2G[k]], \sigma_{a}) \\ b[k] &\sim& Normal(b_{業界平均}[K2G[k]], \sigma_{b}) \\ \end{eqnarray} \]
データから\(\sigma_{Y}, a_{業界平均}[g], a_{全体平均}, \sigma_{ag}, b_{業界平均}[g], b_{全体平均}, \sigma_{bg}, a[k], \sigma_{a}, b[k], \sigma_{b}\)を推定する。
N <- nrow(d)
K <- 30
G <- 3
K2G <- unique(d[ , c('KID','GID')])$GID
data <- list(N = N, G = G, K = K, X = d$X, Y = d$Y, KID = d$KID, K2G = as.numeric(K2G))
data
## $N
## [1] 300
##
## $G
## [1] 3
##
## $K
## [1] 30
##
## $X
## [1] 7 10 16 25 5 25 26 18 17 1 5 4 19 10 21 12 17 22 9 18 21 6 15 4 7
## [26] 10 2 15 27 14 18 20 18 11 26 22 25 28 24 22 20 19 12 15 10 16 1 12 10 7
## [51] 17 15 12 9 14 13 15 9 14 11 10 13 26 3 14 23 10 7 0 3 7 4 24 6 34
## [76] 33 24 23 23 22 14 26 21 20 12 17 24 15 26 27 20 7 26 4 22 15 27 16 20 4
## [101] 18 8 10 16 2 8 13 19 7 2 19 4 1 8 7 5 13 16 9 19 5 13 7 26 20
## [126] 16 22 24 26 18 34 19 30 22 21 2 9 27 20 20 16 14 5 13 12 1 8 18 8 10
## [151] 3 4 11 8 13 12 2 6 0 20 8 20 10 9 12 6 10 11 12 9 6 5 13 12 13
## [176] 15 24 28 31 18 25 19 24 13 24 24 26 28 29 8 25 30 27 30 10 29 13 15 15 14
## [201] 12 15 14 12 14 13 15 12 15 20 17 19 12 12 17 14 8 20 13 26 3 32 21 21 24
## [226] 9 7 10 15 5 12 8 13 14 15 14 5 8 26 6 27 32 16 2 21 27 5 32 22 28
## [251] 25 6 20 33 20 14 22 11 11 8 13 32 8 7 8 20 19 25 11 12 8 10 13 12 15
## [276] 16 12 15 2 8 19 14 14 12 14 18 13 8 7 5 14 20 25 16 22 17 23 20 26 20
##
## $Y
## [1] 457 482 518 535 427 603 610 484 508 380 453 391 559 453 517
## [16] 553 653 763 538 708 740 437 646 422 444 504 376 522 623 515
## [31] 542 529 540 411 666 641 592 722 726 728 927 1036 640 851 568
## [46] 832 191 730 742 562 722 849 899 788 766 787 923 675 811 660
## [61] 598 757 1097 358 701 1097 583 560 388 420 493 423 1007 557 1331
## [76] 1398 1003 863 920 886 637 873 860 779 700 683 880 567 863 927
## [91] 681 484 837 303 740 726 849 632 814 434 864 622 568 833 271
## [106] 267 589 741 244 406 714 378 161 561 671 338 625 913 627 988
## [121] 426 621 428 1198 828 809 974 1190 1172 818 1521 960 1299 1081 1016
## [136] 301 639 1228 821 765 781 858 532 810 763 370 466 724 565 358
## [151] 330 448 580 524 622 603 347 392 238 762 463 714 551 515 816
## [166] 451 624 686 637 605 476 353 741 691 830 516 814 852 773 675
## [181] 767 674 824 548 918 769 671 913 735 453 968 1109 1031 1199 644
## [196] 1137 598 736 766 698 662 675 626 739 724 653 756 781 608 798
## [211] 739 609 591 458 472 443 572 1104 764 1176 415 1471 972 985 1044
## [226] 478 438 537 726 446 609 475 624 670 661 688 453 521 1040 555
## [241] 1033 1253 798 519 825 1065 507 1301 989 1070 824 571 756 940 785
## [256] 701 807 663 662 620 687 943 637 583 654 801 751 815 621 650
## [271] 590 614 644 645 673 695 653 695 534 610 708 663 668 643 665
## [286] 704 637 579 577 546 651 718 800 670 746 690 740 720 805 727
##
## $KID
## [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2
## [26] 2 2 3 3 3 3 3 3 3 3 3 3 4 4 4 5 5 5 5 5 5 5 5 5 5
## [51] 5 6 6 6 6 6 6 6 6 6 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7
## [76] 7 7 8 8 8 8 8 8 8 8 8 9 9 9 9 9 9 9 9 9 9 9 10 10 10
## [101] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 11 11
## [126] 11 11 11 12 12 12 12 12 12 12 12 12 12 13 13 13 13 13 13 13 13 13 14 14 14
## [151] 14 14 14 14 14 14 14 14 14 14 14 14 14 14 15 15 15 15 15 15 15 15 15 15 15
## [176] 16 16 16 16 16 16 16 16 16 16 16 16 16 16 17 17 17 17 17 17 17 18 18 18 18
## [201] 18 18 18 18 18 18 18 18 19 19 19 19 19 19 19 19 20 20 20 20 20 20 20 20 20
## [226] 21 21 21 21 21 21 21 21 21 21 21 21 22 22 22 22 22 22 22 22 22 22 22 22 22
## [251] 23 23 23 23 24 24 24 24 24 24 24 24 25 25 25 25 25 25 26 26 26 26 26 26 26
## [276] 27 27 27 27 27 28 28 28 28 28 28 29 29 29 29 30 30 30 30 30 30 30 30 30 30
##
## $K2G
## [1] 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3
K2G
の部分が直感的にわかりにくいかもしれないが、この変数は階層モデルにおける「企業と業界」をつなぐ調整役として機能する。
## KID GID
## 1 1 1
## 16 2 1
## 28 3 1
## 38 4 1
## 41 5 2
## 52 6 2
## 61 7 2
## 78 8 2
## 87 9 2
## 98 10 2
## 124 11 2
## 129 12 2
## 139 13 2
## 148 14 2
## 165 15 2
## 176 16 2
## 190 17 2
## 197 18 2
## 209 19 2
## 217 20 2
## 226 21 2
## 238 22 2
## 251 23 3
## 255 24 3
## 263 25 3
## 269 26 3
## 276 27 3
## 281 28 3
## 287 29 3
## 291 30 3
Stanファイルは下記の通り。
data {
int N;
int G;
int K;
real X[N];
real Y[N];
int<lower=1, upper=K> KID[N];
int<lower=1, upper=G> K2G[K];
}
parameters {
real a0;
real b0;
real a1[G];
real b1[G];
real a[K];
real b[K];
real<lower=0> s_ag;
real<lower=0> s_bg;
real<lower=0> s_a;
real<lower=0> s_b;
real<lower=0> s_Y;
}
model {
s_ag ~ normal(0, 1e5);
s_bg ~ normal(0, 1e5);
s_a ~ normal(0, 1e5);
s_b ~ normal(0, 1e5);
s_Y ~ normal(0, 1e5);
for (g in 1:G) {
a1[g] ~ normal(a0, s_ag);
b1[g] ~ normal(b0, s_bg);
}
for (k in 1:K) {
a[k] ~ normal(a1[K2G[k]], s_a);
b[k] ~ normal(b1[K2G[k]], s_b);
}
for (n in 1:N)
Y[n] ~ normal(a[KID[n]] + b[KID[n]]*X[n], s_Y);
}
ここでは、stan_model()
関数で最初にコンパイルしておいてから、
sampling()
関数でサンプリングする。
推定結果はこちら。a1[g]
とa[k]
についてまとめておく。
a1[g])
:a1[g]
の事後平均は、各業界ごとに異なる年収のベースラインを示す。例えば、業界1(g=1
)の平均切片は360.3で、業界2(g=2
)は299.9、これは業界ごとに異なる年収の平均値が存在し、業界が異なると切片も異なることを示す。
s_ag
:s_ag
は業界ごとの切片の標準偏差を表す。この値が大きいほど、業界間で切片にばらつきがあり、業界ごとに異なる企業があることを示唆する。つまり、業界内でも企業による異なる年収の傾向が存在する可能性がある。
a[k]
:a[k]
は企業ごとに異なる年収のベースラインを示す。企業ごとに切片が異なるため、同じ年齢でも異なる企業では異なる年収が期待される。
s_a
:s_a
は企業ごとの切片の標準偏差す。この値が大きいほど、企業間で切片にばらつきがあり、企業ごとに異なる年収の傾向が存在することを示唆する。
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
## a0 547.3 167.2 929.2 -185.2 389.2 3162.4 31 1.1
## b0 -42.8 62.1 323.6 -1284.2 18.1 94.6 27 1.2
## a1[1] 360.3 0.8 30.5 299.9 361.4 419.2 1349 1.0
## a1[2] 299.9 0.4 14.2 272.8 299.8 328.2 1409 1.0
## a1[3] 497.4 1.1 30.2 437.2 497.4 555.3 783 1.0
## b1[1] 13.0 0.1 2.9 7.3 13.0 18.8 2402 1.0
## b1[2] 28.5 0.0 1.4 25.8 28.5 31.4 2240 1.0
## b1[3] 12.7 0.1 2.6 7.5 12.7 17.6 1107 1.0
## a[1] 366.5 0.7 27.9 311.3 366.9 420.7 1620 1.0
## a[2] 357.3 0.8 30.4 297.7 357.5 417.1 1460 1.0
## a[3] 351.5 1.0 38.3 273.7 353.5 423.2 1558 1.0
## a[4] 365.6 1.0 41.8 284.1 365.6 448.4 1788 1.0
## a[5] 293.5 0.5 25.1 241.4 294.3 341.1 2395 1.0
## a[6] 332.1 1.2 36.0 275.7 327.6 411.8 884 1.0
## a[7] 310.4 0.5 19.9 274.0 309.9 350.8 1798 1.0
## a[8] 302.3 0.6 30.0 245.1 301.1 366.9 2254 1.0
## a[9] 286.4 0.6 26.5 229.6 288.1 335.6 2083 1.0
## a[10] 271.6 0.7 21.4 228.6 272.4 309.9 1029 1.0
## a[11] 300.2 0.6 30.9 238.8 299.6 364.6 2820 1.0
## a[12] 294.7 0.5 27.2 237.7 294.4 349.5 3215 1.0
## a[13] 323.3 0.9 29.1 273.9 320.3 388.2 1132 1.0
## a[14] 288.6 0.4 20.7 246.2 289.3 328.3 2375 1.0
## a[15] 296.1 0.5 26.2 242.6 296.4 349.1 2601 1.0
## a[16] 292.4 0.6 29.6 229.1 294.2 348.1 2629 1.0
## a[17] 296.5 0.6 28.9 236.6 296.9 354.9 2626 1.0
## a[18] 305.6 0.6 29.9 250.2 303.9 369.8 2412 1.0
## a[19] 275.5 0.9 33.8 199.8 279.9 333.0 1535 1.0
## a[20] 306.9 0.6 27.5 254.6 305.0 365.7 2284 1.0
## a[21] 290.7 0.5 25.6 238.4 291.6 341.3 2834 1.0
## a[22] 327.5 0.9 27.9 280.1 324.8 385.5 933 1.0
## a[23] 497.4 1.1 37.6 423.1 497.4 569.6 1105 1.0
## a[24] 502.8 1.0 34.0 438.0 502.9 570.1 1141 1.0
## a[25] 505.1 1.1 35.4 435.7 504.8 578.7 1101 1.0
## a[26] 495.3 1.1 37.2 423.6 495.9 570.4 1123 1.0
## a[27] 499.7 1.1 34.9 431.9 499.4 567.9 1075 1.0
## a[28] 495.9 1.2 39.5 417.4 496.9 573.9 1045 1.0
## a[29] 493.9 1.0 35.1 424.5 493.8 561.2 1274 1.0
## a[30] 493.3 1.1 39.8 413.8 493.7 570.9 1204 1.0
## b[1] 8.8 0.0 1.8 5.3 8.8 12.2 1861 1.0
## b[2] 17.6 0.1 2.4 12.9 17.6 22.2 1974 1.0
## b[3] 10.8 0.1 2.1 7.0 10.7 15.0 1668 1.0
## b[4] 14.3 0.0 2.2 9.9 14.3 18.5 2071 1.0
## b[5] 33.1 0.0 2.1 29.2 33.1 37.2 2812 1.0
## b[6] 35.8 0.1 3.1 29.3 35.9 41.3 1256 1.0
## b[7] 30.8 0.0 1.2 28.4 30.8 33.2 2411 1.0
## b[8] 25.1 0.0 1.7 21.6 25.2 28.6 2924 1.0
## b[9] 22.4 0.0 1.5 19.7 22.4 25.4 2319 1.0
## b[10] 29.2 0.0 1.8 25.7 29.2 32.9 1407 1.0
## b[11] 32.4 0.0 1.8 28.9 32.4 35.9 3470 1.0
## b[12] 34.2 0.0 1.4 31.4 34.2 37.0 3735 1.0
## b[13] 28.7 0.1 2.3 24.0 28.7 33.2 1814 1.0
## b[14] 24.1 0.0 2.0 20.1 24.1 28.2 2836 1.0
## b[15] 33.7 0.1 2.8 28.3 33.7 39.3 2916 1.0
## b[16] 19.5 0.0 1.4 16.8 19.4 22.3 2904 1.0
## b[17] 28.2 0.0 1.5 25.3 28.2 31.1 2699 1.0
## b[18] 28.7 0.0 2.4 23.7 28.7 33.3 2779 1.0
## b[19] 20.9 0.1 2.4 16.5 20.8 26.1 1976 1.0
## b[20] 33.9 0.0 1.6 30.6 33.9 36.9 2731 1.0
## b[21] 26.5 0.0 2.5 21.6 26.5 31.5 3000 1.0
## b[22] 28.1 0.0 1.4 25.3 28.1 30.6 1257 1.0
## b[23] 13.2 0.0 2.0 9.4 13.2 17.0 1594 1.0
## b[24] 13.9 0.1 2.1 9.7 13.9 17.9 1486 1.0
## b[25] 13.4 0.1 2.4 8.7 13.4 18.3 1318 1.0
## b[26] 12.1 0.1 3.4 5.3 12.0 18.6 1340 1.0
## b[27] 12.7 0.1 3.2 6.4 12.7 18.9 1410 1.0
## b[28] 11.9 0.1 2.9 6.0 11.8 17.8 1366 1.0
## b[29] 11.6 0.1 4.2 3.4 11.6 19.7 1633 1.0
## b[30] 11.6 0.1 2.1 7.5 11.6 15.8 1425 1.0
## s_ag 690.6 279.7 2071.9 61.3 217.1 5292.0 55 1.1
## s_bg 168.4 117.0 718.2 5.3 20.3 1978.0 38 1.1
## s_a 28.1 0.7 12.3 7.4 27.3 54.0 309 1.0
## s_b 4.7 0.0 0.8 3.4 4.7 6.5 3084 1.0
## s_Y 65.2 0.0 2.8 60.1 65.1 71.0 4888 1.0
## lp__ -1585.3 0.8 13.4 -1607.6 -1586.8 -1554.0 259 1.0
##
## Samples were drawn using NUTS(diag_e) at Sun Dec 24 20:20:13 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
前回同様、Stanの挙動を確認しておく。前回同様a1[g],a[k]
に焦点を当てる。
まず、この部分では「業界の全体平均」と「業界差」によって、各企業が属する業界の情報が作られる。
// Gは1-3
for (g in 1:G) {
a1[g] ~ normal(a0, s_ag);
b1[g] ~ normal(b0, s_bg);
}
// image
a1[1] -> A1g
a1[2] -> A2g
a1[3] -> A3g
次に、この部分では「各企業が属する業界の平均」と「企業差」によって各企業の情報が作られる。
// Kは1-30
// K2G
// [1] 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3
for (k in 1:K) {
a[k] ~ normal(a1[K2G[k]], s_a);
b[k] ~ normal(b1[K2G[k]], s_b);
}
// image: a[k] ~ normal(a1[K2G[k]], s_a)
index01: K2G[ 1] -> 1 -> a1[1] -> A1g -> a[ 1] ~ normal(A1g, s_a) -> A01
index05: K2G[ 5] -> 2 -> a1[2] -> A2g -> a[ 5] ~ normal(A2g, s_a) -> A05
index15: K2G[15] -> 2 -> a1[2] -> A2g -> a[15] ~ normal(A2g, s_a) -> A15
index23: K2G[23] -> 3 -> a1[3] -> A3g -> a[23] ~ normal(A3g, s_a) -> A23
index30: K2G[30] -> 3 -> a1[3] -> A3g -> a[30] ~ normal(A3g, s_a) -> A30
さらに、さきほどの情報を利用して、下記のイメージで\(Y\)が生成される。
// Nは1-300
// KID
// [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 4 4 4
for (n in 1:N){}
Y[n] ~ normal(a[KID[n]] + b[KID[n]]*X[n], s_Y);
}
X Y KID GID
index1 : 7 457 1 1 | KID[ 1] -> 1 -> a[ 1] -> A01 -> Y[ 1] ~ normal(A01 + B01*X[ 1], s_Y)
index40 : 22 728 4 1 | KID[ 40] -> 4 -> a[ 4] -> A04 -> Y[ 40] ~ normal(A04 + B04*X[040], s_Y)
index41 : 20 927 5 2 | KID[ 41] -> 5 -> a[ 5] -> A05 -> Y[ 41] ~ normal(A05 + B05*X[041], s_Y)
index250: 28 1070 22 2 | KID[250] -> 22 -> a[22] -> A22 -> Y[250] ~ normal(A22 + B22*X[250], s_Y)
index251: 25 824 23 3 | KID[251] -> 23 -> a[23] -> A23 -> Y[251] ~ normal(A23 + B23*X[251], s_Y)
index300: 20 727 30 3 | KID[300] -> 30 -> a[30] -> A30 -> Y[300] ~ normal(A30 + B30*X[300], s_Y)
さきほどのモデルは、「新卒年収」と「年収昇給額」の平均は業界ごとに異なるが、「新卒年収」と「年収昇給額」の会社差のばらつきは共通としていた。ここでは、「新卒年収」と「年収昇給額」の会社差のばらつきは業界ごとに異なると仮定する。
ここで想定しているモデルは下記の通り。
\[ \begin{eqnarray} Y[n] &\sim& Normal(a[KID[n]] + b[KID[n]] X[n], \sigma_{Y}[GID[n]]) \\ a_{業界平均}[g] &\sim& Normal(a_{全体平均}, \sigma_{ag}) \\ b_{業界平均}[g] &\sim& Normal(b_{全体平均}, \sigma_{bg}) \\ a[k] &\sim& Normal(a_{業界平均}[K2G[k]], \sigma_{a}[K2G[n]]) \\ b[k] &\sim& Normal(b_{業界平均}[K2G[k]], \sigma_{b}[K2G[n]]) \\ \end{eqnarray} \]
先程のモデルとの違いは、s_a[K2G[k]], s_b[K2G[k]]
の部分に現れている。
data {
int N;
int G;
int K;
real X[N];
real Y[N];
int<lower=1, upper=K> KID[N];
int<lower=1, upper=G> K2G[K];
int<lower=1, upper=G> GID[N];
}
parameters {
real a0;
real b0;
real a1[G];
real b1[G];
real a[K];
real b[K];
real<lower=0> s_ag;
real<lower=0> s_bg;
real<lower=0> s_a[G];
real<lower=0> s_b[G];
real<lower=0> s_Y[G];
}
model {
s_ag ~ normal(0, 1e5);
s_bg ~ normal(0, 1e5);
for (g in 1:G) {
a1[g] ~ normal(a0, s_ag);
b1[g] ~ normal(b0, s_bg);
}
for (k in 1:K) {
a[k] ~ normal(a1[K2G[k]], s_a[K2G[k]]);
b[k] ~ normal(b1[K2G[k]], s_b[K2G[k]]);
}
for (n in 1:N)
Y[n] ~ normal(a[KID[n]] + b[KID[n]]*X[n], s_Y[GID[n]]);
}
ここでは、stan_model()
関数で最初にコンパイルしておいてから、
sampling()
関数でサンプリングする。
K2G <- unique(d[ , c('KID','GID')])$GID
data <- list(N = N, G = G, K = K, X = d$X, Y = d$Y, KID = d$KID, K2G = as.numeric(K2G), GID = as.numeric(d$GID))
fit <- sampling(object = model86, data = data, seed = 1989)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
## a0 275.4 85.9 1000.1 -1322.1 394.2 1092.1 136 1
## b0 -12.1 34.1 473.6 -63.2 18.2 102.0 192 1
## a1[1] 373.7 2.7 89.8 197.0 368.0 552.1 1074 1
## a1[2] 300.1 0.6 17.1 266.5 299.6 335.0 847 1
## a1[3] 500.5 0.3 10.0 479.2 500.9 519.7 1239 1
## b1[1] 13.1 0.1 6.1 1.1 13.1 25.8 1945 1
## b1[2] 28.5 0.1 1.7 25.1 28.5 31.9 637 1
## b1[3] 12.4 0.0 0.6 10.9 12.4 13.4 1071 1
## a[1] 382.4 0.3 14.9 353.7 382.5 410.9 1826 1
## a[2] 335.3 0.4 17.4 302.4 335.2 369.7 2133 1
## a[3] 324.3 1.4 34.6 247.4 326.1 387.3 603 1
## a[4] 469.9 4.0 124.6 302.7 440.3 764.7 948 1
## a[5] 292.7 0.7 29.3 231.0 293.3 350.4 1567 1
## a[6] 333.3 1.9 42.4 269.5 325.7 431.2 524 1
## a[7] 310.8 0.8 23.8 267.5 309.6 361.4 956 1
## a[8] 302.2 1.2 36.8 227.3 300.6 382.5 975 1
## a[9] 285.3 0.9 33.1 209.8 288.8 344.9 1447 1
## a[10] 273.1 1.3 25.4 220.9 274.4 321.7 379 1
## a[11] 300.4 0.8 39.1 223.8 298.6 385.5 2143 1
## a[12] 294.7 0.7 30.9 229.5 295.3 356.0 2022 1
## a[13] 324.0 1.3 33.4 267.6 319.8 401.5 692 1
## a[14] 289.1 0.7 24.6 236.7 289.9 335.8 1413 1
## a[15] 295.4 0.7 31.1 230.0 296.2 356.0 1964 1
## a[16] 292.9 1.0 37.7 213.5 294.0 370.2 1404 1
## a[17] 295.3 0.8 33.9 222.7 296.4 361.6 1735 1
## a[18] 307.3 1.2 35.5 242.3 304.0 384.8 926 1
## a[19] 274.1 1.3 41.0 176.9 281.6 338.5 1058 1
## a[20] 306.0 0.7 30.3 247.8 304.2 370.4 1777 1
## a[21] 289.8 0.7 29.4 227.1 291.0 346.2 1836 1
## a[22] 328.0 1.4 33.7 272.6 323.9 401.5 572 1
## a[23] 497.7 0.2 10.9 475.4 498.2 518.8 2067 1
## a[24] 515.4 0.2 8.8 498.6 515.4 532.7 1587 1
## a[25] 524.1 0.3 9.6 505.9 523.9 543.2 1263 1
## a[26] 494.1 0.3 11.0 474.6 493.5 516.5 1082 1
## a[27] 506.9 0.2 8.6 491.1 506.5 525.3 1389 1
## a[28] 497.3 0.5 14.4 472.1 496.6 527.9 758 1
## a[29] 488.5 0.4 10.8 468.8 487.8 510.6 795 1
## a[30] 482.5 0.6 15.0 454.6 481.9 511.7 713 1
## b[1] 7.8 0.0 0.9 6.0 7.8 9.6 1989 1
## b[2] 19.4 0.0 1.3 16.8 19.4 21.9 2409 1
## b[3] 11.9 0.1 1.7 8.8 11.9 15.7 761 1
## b[4] 10.2 0.2 5.0 -1.8 11.5 17.2 979 1
## b[5] 33.2 0.1 2.5 28.5 33.1 38.2 1820 1
## b[6] 35.8 0.1 3.5 28.0 36.0 41.9 921 1
## b[7] 30.7 0.1 1.5 27.7 30.7 33.6 710 1
## b[8] 25.1 0.1 2.1 20.9 25.1 29.2 1153 1
## b[9] 22.5 0.1 1.8 19.2 22.4 26.3 1092 1
## b[10] 29.1 0.1 2.2 24.7 29.1 33.6 474 1
## b[11] 32.4 0.0 2.3 27.8 32.4 36.6 2415 1
## b[12] 34.2 0.0 1.6 31.2 34.2 37.4 2421 1
## b[13] 28.7 0.1 2.6 23.0 28.8 33.5 1295 1
## b[14] 24.0 0.1 2.4 19.4 23.9 28.9 2222 1
## b[15] 33.7 0.1 3.4 27.3 33.6 40.4 1491 1
## b[16] 19.5 0.0 1.7 16.1 19.4 23.2 1615 1
## b[17] 28.2 0.0 1.7 25.0 28.2 31.7 1800 1
## b[18] 28.6 0.1 2.9 22.8 28.7 34.2 751 1
## b[19] 21.0 0.1 2.9 15.8 20.7 27.3 1428 1
## b[20] 34.0 0.0 1.8 30.4 34.0 37.3 2224 1
## b[21] 26.6 0.1 2.9 21.0 26.5 32.5 1922 1
## b[22] 28.0 0.0 1.6 24.7 28.2 30.9 1095 1
## b[23] 13.1 0.0 0.5 12.2 13.1 14.1 2167 1
## b[24] 13.3 0.0 0.5 12.3 13.3 14.2 1600 1
## b[25] 12.5 0.0 0.6 11.4 12.5 13.5 1534 1
## b[26] 12.1 0.0 0.9 10.2 12.2 13.7 1094 1
## b[27] 12.2 0.0 0.7 10.7 12.3 13.5 1403 1
## b[28] 11.7 0.0 0.9 9.7 11.8 13.3 860 1
## b[29] 11.8 0.0 1.1 9.4 12.0 13.6 827 1
## b[30] 12.1 0.0 0.7 10.6 12.1 13.4 773 1
## s_ag 663.7 140.9 1928.0 66.4 222.3 4297.3 187 1
## s_bg 133.2 67.6 1299.6 6.0 18.5 428.4 369 1
## s_a[1] 149.2 7.4 197.2 9.9 91.0 664.3 718 1
## s_a[2] 33.5 1.2 18.8 3.1 31.5 74.9 228 1
## s_a[3] 20.5 0.3 9.4 5.7 19.0 42.7 923 1
## s_b[1] 10.6 0.3 9.2 3.1 7.8 33.7 1146 1
## s_b[2] 5.7 0.0 1.3 3.7 5.5 8.8 1691 1
## s_b[3] 1.0 0.0 0.6 0.2 0.9 2.5 632 1
## s_Y[1] 28.6 0.1 3.8 22.1 28.3 37.1 2469 1
## s_Y[2] 76.2 0.1 4.0 68.8 76.1 84.4 1844 1
## s_Y[3] 11.6 0.0 1.4 9.3 11.5 14.8 2705 1
## lp__ -1474.7 1.2 13.7 -1498.0 -1476.3 -1439.1 141 1
##
## Samples were drawn using NUTS(diag_e) at Sun Dec 24 20:20:57 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
書籍のモデルでは、事前分布が省略されていたので、省略しないバージョンもメモしておく。結果は似たような結果と鳴る。
model86_ <- stan_model('note_ahirubayes08-86_.stan')
fit <- sampling(object = model86_, data = data, seed = 1989)
結果はこちら。
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
## a0 413.5 35.1 698.9 -515.7 396.6 1263.7 396 1
## b0 15.5 2.2 47.8 -58.7 18.3 83.1 477 1
## a1[1] 381.3 2.8 91.4 235.2 367.3 580.0 1051 1
## a1[2] 299.7 0.5 16.9 266.5 299.4 334.7 1069 1
## a1[3] 500.6 0.3 10.1 479.3 500.8 519.9 1200 1
## b1[1] 12.9 0.2 6.0 -0.8 13.1 24.1 928 1
## b1[2] 28.5 0.0 1.7 25.2 28.6 31.8 1258 1
## b1[3] 12.4 0.0 0.6 11.0 12.4 13.4 1165 1
## a[1] 383.4 0.5 15.1 353.5 383.6 411.7 961 1
## a[2] 334.3 0.4 17.3 300.3 334.6 367.9 1893 1
## a[3] 322.3 1.0 33.4 255.9 323.2 384.8 1116 1
## a[4] 470.0 4.4 128.0 301.4 437.8 785.9 851 1
## a[5] 294.5 0.7 29.5 231.0 296.0 353.2 2013 1
## a[6] 329.7 1.8 42.5 267.8 321.1 432.1 564 1
## a[7] 309.6 0.6 23.0 267.1 308.1 356.9 1417 1
## a[8] 301.2 0.7 34.0 230.1 300.3 373.8 2244 1
## a[9] 286.7 0.9 32.0 213.2 290.2 343.8 1197 1
## a[10] 274.0 1.6 26.4 217.9 276.0 319.9 282 1
## a[11] 299.9 0.8 36.4 222.7 299.6 376.0 2162 1
## a[12] 295.3 0.7 30.5 230.0 296.6 354.8 2125 1
## a[13] 321.5 1.2 32.9 267.4 316.9 395.7 753 1
## a[14] 289.6 0.7 24.1 239.3 291.4 334.2 1243 1
## a[15] 295.7 0.7 29.6 232.5 296.5 353.9 2074 1
## a[16] 294.1 0.7 35.8 216.1 296.1 364.2 2292 1
## a[17] 296.3 0.8 33.9 222.0 297.7 363.6 1712 1
## a[18] 305.3 0.8 34.6 238.9 303.2 381.0 1755 1
## a[19] 276.9 1.5 40.1 180.6 283.2 342.8 677 1
## a[20] 306.4 0.7 31.2 248.4 304.4 374.5 2086 1
## a[21] 291.2 0.8 29.8 227.8 293.2 347.1 1415 1
## a[22] 325.7 1.5 32.2 273.0 321.1 397.3 459 1
## a[23] 497.6 0.2 10.7 475.2 498.0 518.2 2537 1
## a[24] 516.1 0.2 9.1 498.4 516.3 533.5 1513 1
## a[25] 525.1 0.2 9.5 505.6 525.1 543.7 1472 1
## a[26] 494.1 0.3 11.0 473.6 493.2 516.9 1257 1
## a[27] 507.2 0.2 8.3 491.6 507.3 524.0 2328 1
## a[28] 497.3 0.5 14.3 471.5 496.4 526.9 867 1
## a[29] 488.5 0.4 10.4 469.7 487.8 510.0 850 1
## a[30] 481.6 0.5 14.5 455.1 480.5 511.4 960 1
## b[1] 7.8 0.0 0.9 6.0 7.7 9.7 1593 1
## b[2] 19.4 0.0 1.3 16.9 19.4 22.0 1294 1
## b[3] 12.1 0.0 1.7 9.0 12.1 15.2 1337 1
## b[4] 10.3 0.2 5.1 -2.0 11.5 17.3 857 1
## b[5] 33.1 0.0 2.5 28.3 33.1 38.0 2545 1
## b[6] 36.0 0.1 3.6 28.3 36.3 42.1 647 1
## b[7] 30.8 0.0 1.5 27.9 30.9 33.6 1376 1
## b[8] 25.2 0.0 2.0 21.1 25.2 29.1 2401 1
## b[9] 22.4 0.0 1.8 19.1 22.4 26.2 1534 1
## b[10] 29.0 0.1 2.3 24.7 29.0 33.7 336 1
## b[11] 32.4 0.0 2.2 28.4 32.3 36.7 1999 1
## b[12] 34.1 0.0 1.6 31.0 34.1 37.4 2109 1
## b[13] 28.8 0.1 2.7 23.4 28.9 33.8 920 1
## b[14] 23.9 0.1 2.3 19.5 23.8 28.7 1724 1
## b[15] 33.8 0.1 3.2 27.5 33.6 40.6 1980 1
## b[16] 19.4 0.0 1.7 16.2 19.4 22.9 2306 1
## b[17] 28.2 0.0 1.7 24.8 28.2 31.6 2477 1
## b[18] 28.8 0.1 2.9 22.8 28.9 34.2 1545 1
## b[19] 20.9 0.1 2.8 15.5 20.7 27.1 1065 1
## b[20] 34.0 0.0 1.8 30.4 34.0 37.5 2250 1
## b[21] 26.5 0.1 2.9 20.7 26.5 32.6 1711 1
## b[22] 28.1 0.0 1.6 24.8 28.1 30.9 1008 1
## b[23] 13.1 0.0 0.5 12.3 13.1 14.1 1336 1
## b[24] 13.2 0.0 0.5 12.3 13.2 14.2 1413 1
## b[25] 12.4 0.0 0.6 11.3 12.4 13.5 1449 1
## b[26] 12.1 0.0 0.9 10.2 12.2 13.7 1164 1
## b[27] 12.2 0.0 0.7 10.8 12.3 13.5 2015 1
## b[28] 11.7 0.0 0.9 9.8 11.8 13.4 851 1
## b[29] 11.9 0.0 1.1 9.4 12.1 13.5 795 1
## b[30] 12.1 0.0 0.7 10.7 12.2 13.4 904 1
## s_ag 628.0 77.7 1842.8 62.9 241.7 3408.0 562 1
## s_bg 51.1 6.8 211.8 5.8 19.6 301.1 982 1
## s_a[1] 142.2 6.1 170.7 11.9 90.2 590.2 785 1
## s_a[2] 31.1 1.7 19.1 2.3 29.6 71.4 129 1
## s_a[3] 21.1 0.3 9.7 6.6 19.4 45.0 1046 1
## s_b[1] 10.6 0.3 9.2 3.2 7.8 34.5 1018 1
## s_b[2] 5.7 0.0 1.2 3.7 5.6 8.6 2518 1
## s_b[3] 1.0 0.0 0.6 0.1 0.9 2.4 664 1
## s_Y[1] 28.6 0.1 3.9 22.1 28.3 37.4 1903 1
## s_Y[2] 76.1 0.1 4.0 68.7 75.8 84.5 3336 1
## s_Y[3] 11.6 0.0 1.4 9.3 11.5 14.8 2066 1
## lp__ -1472.9 1.8 15.8 -1496.8 -1475.7 -1431.3 75 1
##
## Samples were drawn using NUTS(diag_e) at Sun Dec 24 20:21:34 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
先程と異なる点は標準偏差の部分だけではあるが、こちらのモデルでもa
に関して同様に挙動を確認しておく。
// 事前分布
a0 ~ normal(0, 1e5);
b0 ~ normal(0, 1e5);
s_ag ~ normal(0, 1e5);
s_bg ~ normal(0, 1e5);
この部分で、a1[1],a1[2],a1[3],s_a[1],s_a[2],s_a[3]
を生成する。
// G:1-3
for (g in 1:G) {
a1[g] ~ normal(a0, s_ag);
b1[g] ~ normal(b0, s_bg);
s_a[g] ~ normal(0, 1e5);
s_b[g] ~ normal(0, 1e5);
s_Y[g] ~ normal(0, 1e5);
}
// image
a1[1] -> A1g
a1[2] -> A2g
a1[3] -> A3g
s_a[1] -> SA1g
s_a[2] -> SA2g
s_a[3] -> SA3g
s_Y[1] -> SY1g
s_Y[2] -> SY2g
s_Y[3] -> SY3g
次に、この部分では「各企業が属する業界の平均」と「企業差」によって各企業の情報が作られる。このとき、業界ごとに標準偏差も異なるので注意。
// K:1-30
// K2G
// [1] 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3
for (k in 1:K) {
a[k] ~ normal(a1[K2G[k]], s_a[K2G[k]]);
b[k] ~ normal(b1[K2G[k]], s_b[K2G[k]]);
}
// image: a[k] ~ normal(a1[K2G[k]], s_a[K2G[k]])
index01: K2G[ 1] -> 1 -> a1[1],s_a[1] -> A1g,SA1g -> a[ 1] ~ normal(A1g, SA1g) -> A01
index05: K2G[ 5] -> 2 -> a1[2],s_a[2] -> A2g,SA2g -> a[ 5] ~ normal(A2g, SA2g) -> A05
index15: K2G[15] -> 2 -> a1[2],s_a[2] -> A2g,SA2g -> a[15] ~ normal(A2g, SA2g) -> A15
index23: K2G[23] -> 3 -> a1[3],s_a[3] -> A3g,SA3g -> a[23] ~ normal(A3g, SA3g) -> A23
index30: K2G[30] -> 3 -> a1[3],s_a[3] -> A3g,SA3g -> a[30] ~ normal(A3g, SA3g) -> A30
さらに、さきほどの情報を利用して、下記のイメージで\(Y\)が生成される。
// N:1-300
for (n in 1:N)
Y[n] ~ normal(a[KID[n]] + b[KID[n]]*X[n], s_Y[GID[n]]);
}
// Nは1-300
// KID
// [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 4 4 4
X Y KID GID
index1 : 7 457 1 1 | KID[ 1] -> 1 -> a[ 1] -> A01 -> Y[ 1] ~ normal(A01 + B01*X[ 1], SY1g)
index40 : 22 728 4 1 | KID[ 40] -> 4 -> a[ 4] -> A04 -> Y[ 40] ~ normal(A04 + B04*X[040], SY1g)
index41 : 20 927 5 2 | KID[ 41] -> 5 -> a[ 5] -> A05 -> Y[ 41] ~ normal(A05 + B05*X[041], SY2g)
index250: 28 1070 22 2 | KID[250] -> 22 -> a[22] -> A22 -> Y[250] ~ normal(A22 + B22*X[250], SY2g)
index251: 25 824 23 3 | KID[251] -> 23 -> a[23] -> A23 -> Y[251] ~ normal(A23 + B23*X[251], SY3g)
index300: 20 727 30 3 | KID[300] -> 30 -> a[30] -> A30 -> Y[300] ~ normal(A30 + B30*X[300], SY3g)