UPDATE: 2023-12-24 20:19:37.077966

はじめに

このノートは「StanとRでベイズ統計モデリング」の内容を写経することで、ベイズ統計への理解を深めていくために作成している。

基本的には気になった部分を写経しながら、ところどころ自分用の補足をメモすることで、「StanとRでベイズ統計モデリング」を読み進めるための自分用の補足資料になることを目指す。私の解釈がおかしく、メモが誤っている場合があるので注意。

今回は第8章「階層モデル」のチャプターを写経していく。

8.2 階層モデルの導入

この章で使用するデータの説明しておく。4業界\(GID\)の30企業\(KID\)の年齢\(X-23\)、年収\(Y\)が記録されている300人のデータを使用する。\(X\)から23が引かれているのは解釈がしやすくするため。

options(max.print = 999999)
library(dplyr)
library(ggplot2)
library(rstan)

d <- read.csv('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap08/input/data-salary-3.txt')
head(d, 10)
##     X   Y KID GID
## 1   7 457   1   1
## 2  10 482   1   1
## 3  16 518   1   1
## 4  25 535   1   1
## 5   5 427   1   1
## 6  25 603   1   1
## 7  26 610   1   1
## 8  18 484   1   1
## 9  17 508   1   1
## 10  1 380   1   1

8.2.1 解析の目的とデータの分布の確認

ここでも分析の問題は、年功序列で賃金は上昇する傾向にあるが、業界や会社によって「新卒時点の基本年収」や「年齢の伴う昇給額は異なる」と考えられるため、その会社による業界差、会社差を検討したい。

可視化してみると、どの業界でも年齢に伴う昇給は確認できるが、新卒時点での基本年収や昇給額は業界によって異なることがわかる。

d$GID <- as.factor(d$GID)
res_lm <- lm(Y ~ X, data = d)
coef <- coef(res_lm)
ggplot(d, aes(X, Y, shape = GID)) +
  theme_bw(base_size = 15) +
  geom_abline(intercept = coef[1], slope = coef[2], linewidth = 2, alpha = 0.3) +
  facet_wrap(~ GID) +
  geom_line(stat = 'smooth', method = 'lm', se = FALSE) +
  geom_point(size = 3, alpha = 0.8) +
  scale_shape_manual(values = c(16, 2, 4)) +
  labs(x = 'X', y = 'Y')

業界ごと、企業ごとで回帰係数\(a,b\)を推定したものを集めてヒストグラムにしたものが下記のグラフ。\(a\)に関しては業界3のばらつきは小さいが、業界1,2のばらつきが大きい。このグラフから業界3は他の業界1,2よりも平均は高そうで、ばらつきも小さいことが想定される。

KIDGID <- unique(d[,3:4])
N <- nrow(d)
K <- 30
G <- 3
coefs <- as.data.frame(t(sapply(1:K, function(k) {
  d_sub <- subset(d, KID == k)
  coef(lm(Y ~ X, data = d_sub))
})))
colnames(coefs) <- c('a', 'b')
d_plot <- data.frame(coefs, KIDGID)
d_plot$GID_label <- factor(paste0('GID  =  ', d_plot$GID), levels  =  paste0('GID  =  ', 1:3))

bw <- diff(range(d_plot$a))/20
ggplot(data = d_plot, aes(x = a)) +
  theme_bw(base_size = 15) +
  facet_wrap(~GID_label, nrow = 3) +
  geom_histogram(binwidth = bw, color = 'black', fill = 'white') +
  geom_density(aes(y = after_stat(count)*bw), alpha = 0.2, color = 'black', fill = 'gray20') +
  geom_rug(sides = 'b') +
  labs(x = 'a', y = 'count')

\(b\)に関しても業界3のばらつきは小さいが、業界1,2のばらつきが大きい。このグラフから業界3は他の業界1,2よりも昇給額の平均は低く、ばらつきも小さいことが想定される。

bw <- diff(range(d_plot$b))/20
ggplot(data = d_plot, aes(x = b)) +
  theme_bw(base_size = 15) +
  facet_wrap(~GID_label, nrow = 3) +
  geom_histogram(binwidth = bw, color = 'black', fill = 'white') +
  geom_density(aes(y = after_stat(count)*bw), alpha = 0.2, color = 'black', fill = 'gray20') +
  geom_rug(sides = 'b') +
  labs(x = 'b', y = 'count')

8.2.2 メカニズムの想像ーその1

ここで想定しているモデルは「新卒年収」と「年収昇給額」の平均は業界ごとに異なるが、「新卒年収」と「年収昇給額」の会社差のばらつきは共通としている。

\(a\)について考えることで、モデルへの理解を深める。各業界の\(a_{業界平均}[g]\)を「すべての業界で共通の平均 」と「業界差」に分けて考える。つまり\(a_{業界平均}[g] = a_{全体平均} + a_{業界差}[g]\)であり、\(a_{業界差}[g]\)には平均0、標準偏差\(\sigma_{ag}\)の正規分布から生成されると考える。さらに\(\sigma_{ag}\)には無情報事前分布を設定する。

そして、各社の\(a[k]\)は会社の属する業界\(a_{業界平均}[g]\)を平均とする正規分布から生成されると考える。

8.2.3 モデル式の記述ーその1

ここで想定しているモデルは下記の通り。

モデル8-5

\[ \begin{eqnarray} Y[n] &\sim& Normal(a[KID[n]] + b[KID[n]] X[n], \sigma_{Y}) \\ a_{業界平均}[g] &\sim& Normal(a_{全体平均}, \sigma_{ag}) \\ b_{業界平均}[g] &\sim& Normal(b_{全体平均}, \sigma_{bg}) \\ a[k] &\sim& Normal(a_{業界平均}[K2G[k]], \sigma_{a}) \\ b[k] &\sim& Normal(b_{業界平均}[K2G[k]], \sigma_{b}) \\ \end{eqnarray} \]

データから\(\sigma_{Y}, a_{業界平均}[g], a_{全体平均}, \sigma_{ag}, b_{業界平均}[g], b_{全体平均}, \sigma_{bg}, a[k], \sigma_{a}, b[k], \sigma_{b}\)を推定する。

N <- nrow(d)
K <- 30
G <- 3
K2G <- unique(d[ , c('KID','GID')])$GID
data <- list(N = N, G = G, K = K, X = d$X, Y = d$Y, KID = d$KID, K2G = as.numeric(K2G))
data
## $N
## [1] 300
## 
## $G
## [1] 3
## 
## $K
## [1] 30
## 
## $X
##   [1]  7 10 16 25  5 25 26 18 17  1  5  4 19 10 21 12 17 22  9 18 21  6 15  4  7
##  [26] 10  2 15 27 14 18 20 18 11 26 22 25 28 24 22 20 19 12 15 10 16  1 12 10  7
##  [51] 17 15 12  9 14 13 15  9 14 11 10 13 26  3 14 23 10  7  0  3  7  4 24  6 34
##  [76] 33 24 23 23 22 14 26 21 20 12 17 24 15 26 27 20  7 26  4 22 15 27 16 20  4
## [101] 18  8 10 16  2  8 13 19  7  2 19  4  1  8  7  5 13 16  9 19  5 13  7 26 20
## [126] 16 22 24 26 18 34 19 30 22 21  2  9 27 20 20 16 14  5 13 12  1  8 18  8 10
## [151]  3  4 11  8 13 12  2  6  0 20  8 20 10  9 12  6 10 11 12  9  6  5 13 12 13
## [176] 15 24 28 31 18 25 19 24 13 24 24 26 28 29  8 25 30 27 30 10 29 13 15 15 14
## [201] 12 15 14 12 14 13 15 12 15 20 17 19 12 12 17 14  8 20 13 26  3 32 21 21 24
## [226]  9  7 10 15  5 12  8 13 14 15 14  5  8 26  6 27 32 16  2 21 27  5 32 22 28
## [251] 25  6 20 33 20 14 22 11 11  8 13 32  8  7  8 20 19 25 11 12  8 10 13 12 15
## [276] 16 12 15  2  8 19 14 14 12 14 18 13  8  7  5 14 20 25 16 22 17 23 20 26 20
## 
## $Y
##   [1]  457  482  518  535  427  603  610  484  508  380  453  391  559  453  517
##  [16]  553  653  763  538  708  740  437  646  422  444  504  376  522  623  515
##  [31]  542  529  540  411  666  641  592  722  726  728  927 1036  640  851  568
##  [46]  832  191  730  742  562  722  849  899  788  766  787  923  675  811  660
##  [61]  598  757 1097  358  701 1097  583  560  388  420  493  423 1007  557 1331
##  [76] 1398 1003  863  920  886  637  873  860  779  700  683  880  567  863  927
##  [91]  681  484  837  303  740  726  849  632  814  434  864  622  568  833  271
## [106]  267  589  741  244  406  714  378  161  561  671  338  625  913  627  988
## [121]  426  621  428 1198  828  809  974 1190 1172  818 1521  960 1299 1081 1016
## [136]  301  639 1228  821  765  781  858  532  810  763  370  466  724  565  358
## [151]  330  448  580  524  622  603  347  392  238  762  463  714  551  515  816
## [166]  451  624  686  637  605  476  353  741  691  830  516  814  852  773  675
## [181]  767  674  824  548  918  769  671  913  735  453  968 1109 1031 1199  644
## [196] 1137  598  736  766  698  662  675  626  739  724  653  756  781  608  798
## [211]  739  609  591  458  472  443  572 1104  764 1176  415 1471  972  985 1044
## [226]  478  438  537  726  446  609  475  624  670  661  688  453  521 1040  555
## [241] 1033 1253  798  519  825 1065  507 1301  989 1070  824  571  756  940  785
## [256]  701  807  663  662  620  687  943  637  583  654  801  751  815  621  650
## [271]  590  614  644  645  673  695  653  695  534  610  708  663  668  643  665
## [286]  704  637  579  577  546  651  718  800  670  746  690  740  720  805  727
## 
## $KID
##   [1]  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  2  2  2  2  2  2  2  2  2  2
##  [26]  2  2  3  3  3  3  3  3  3  3  3  3  4  4  4  5  5  5  5  5  5  5  5  5  5
##  [51]  5  6  6  6  6  6  6  6  6  6  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
##  [76]  7  7  8  8  8  8  8  8  8  8  8  9  9  9  9  9  9  9  9  9  9  9 10 10 10
## [101] 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 11 11
## [126] 11 11 11 12 12 12 12 12 12 12 12 12 12 13 13 13 13 13 13 13 13 13 14 14 14
## [151] 14 14 14 14 14 14 14 14 14 14 14 14 14 14 15 15 15 15 15 15 15 15 15 15 15
## [176] 16 16 16 16 16 16 16 16 16 16 16 16 16 16 17 17 17 17 17 17 17 18 18 18 18
## [201] 18 18 18 18 18 18 18 18 19 19 19 19 19 19 19 19 20 20 20 20 20 20 20 20 20
## [226] 21 21 21 21 21 21 21 21 21 21 21 21 22 22 22 22 22 22 22 22 22 22 22 22 22
## [251] 23 23 23 23 24 24 24 24 24 24 24 24 25 25 25 25 25 25 26 26 26 26 26 26 26
## [276] 27 27 27 27 27 28 28 28 28 28 28 29 29 29 29 30 30 30 30 30 30 30 30 30 30
## 
## $K2G
##  [1] 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3

K2Gの部分が直感的にわかりにくいかもしれないが、この変数は階層モデルにおける「企業と業界」をつなぐ調整役として機能する。

# K2G <- unique(d[ , c('KID','GID')])$GID
unique(d[ , c('KID','GID')])
##     KID GID
## 1     1   1
## 16    2   1
## 28    3   1
## 38    4   1
## 41    5   2
## 52    6   2
## 61    7   2
## 78    8   2
## 87    9   2
## 98   10   2
## 124  11   2
## 129  12   2
## 139  13   2
## 148  14   2
## 165  15   2
## 176  16   2
## 190  17   2
## 197  18   2
## 209  19   2
## 217  20   2
## 226  21   2
## 238  22   2
## 251  23   3
## 255  24   3
## 263  25   3
## 269  26   3
## 276  27   3
## 281  28   3
## 287  29   3
## 291  30   3

8.2.4 Stanで実装ーその1

Stanファイルは下記の通り。

data {
  int N;
  int G;
  int K;
  real X[N];
  real Y[N];
  int<lower=1, upper=K> KID[N];
  int<lower=1, upper=G> K2G[K];
}

parameters {
  real a0;
  real b0;
  real a1[G];
  real b1[G];
  real a[K];
  real b[K];
  real<lower=0> s_ag;
  real<lower=0> s_bg;
  real<lower=0> s_a;
  real<lower=0> s_b;
  real<lower=0> s_Y;
}

model {
  s_ag ~ normal(0, 1e5);
  s_bg ~ normal(0, 1e5);
  s_a  ~ normal(0, 1e5);
  s_b  ~ normal(0, 1e5);
  s_Y  ~ normal(0, 1e5);
  
  for (g in 1:G) {
    a1[g] ~ normal(a0, s_ag);
    b1[g] ~ normal(b0, s_bg);
  }

  for (k in 1:K) {
    a[k] ~ normal(a1[K2G[k]], s_a);
    b[k] ~ normal(b1[K2G[k]], s_b);
  }

  for (n in 1:N)
    Y[n] ~ normal(a[KID[n]] + b[KID[n]]*X[n], s_Y);
}

ここでは、stan_model()関数で最初にコンパイルしておいてから、

model85 <- stan_model('note_ahirubayes08-85.stan')

sampling()関数でサンプリングする。

fit <- sampling(object = model85, data = data, seed = 1989)

推定結果はこちら。a1[g]a[k]についてまとめておく。

  • 業界ごとの平均切片a1[g]):

a1[g]の事後平均は、各業界ごとに異なる年収のベースラインを示す。例えば、業界1(g=1)の平均切片は360.3で、業界2(g=2)は299.9、これは業界ごとに異なる年収の平均値が存在し、業界が異なると切片も異なることを示す。

  • 業界ごとの切片のばらつきs_ag:

s_agは業界ごとの切片の標準偏差を表す。この値が大きいほど、業界間で切片にばらつきがあり、業界ごとに異なる企業があることを示唆する。つまり、業界内でも企業による異なる年収の傾向が存在する可能性がある。

  • 企業ごとの切片a[k]:

a[k]は企業ごとに異なる年収のベースラインを示す。企業ごとに切片が異なるため、同じ年齢でも異なる企業では異なる年収が期待される。

  • 企業ごとの切片のばらつきs_a:

s_aは企業ごとの切片の標準偏差す。この値が大きいほど、企業間で切片にばらつきがあり、企業ごとに異なる年収の傾向が存在することを示唆する。

print(fit, prob = c(0.025, 0.5, 0.975), digits_summary = 1)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##          mean se_mean     sd    2.5%     50%   97.5% n_eff Rhat
## a0      547.3   167.2  929.2  -185.2   389.2  3162.4    31  1.1
## b0      -42.8    62.1  323.6 -1284.2    18.1    94.6    27  1.2
## a1[1]   360.3     0.8   30.5   299.9   361.4   419.2  1349  1.0
## a1[2]   299.9     0.4   14.2   272.8   299.8   328.2  1409  1.0
## a1[3]   497.4     1.1   30.2   437.2   497.4   555.3   783  1.0
## b1[1]    13.0     0.1    2.9     7.3    13.0    18.8  2402  1.0
## b1[2]    28.5     0.0    1.4    25.8    28.5    31.4  2240  1.0
## b1[3]    12.7     0.1    2.6     7.5    12.7    17.6  1107  1.0
## a[1]    366.5     0.7   27.9   311.3   366.9   420.7  1620  1.0
## a[2]    357.3     0.8   30.4   297.7   357.5   417.1  1460  1.0
## a[3]    351.5     1.0   38.3   273.7   353.5   423.2  1558  1.0
## a[4]    365.6     1.0   41.8   284.1   365.6   448.4  1788  1.0
## a[5]    293.5     0.5   25.1   241.4   294.3   341.1  2395  1.0
## a[6]    332.1     1.2   36.0   275.7   327.6   411.8   884  1.0
## a[7]    310.4     0.5   19.9   274.0   309.9   350.8  1798  1.0
## a[8]    302.3     0.6   30.0   245.1   301.1   366.9  2254  1.0
## a[9]    286.4     0.6   26.5   229.6   288.1   335.6  2083  1.0
## a[10]   271.6     0.7   21.4   228.6   272.4   309.9  1029  1.0
## a[11]   300.2     0.6   30.9   238.8   299.6   364.6  2820  1.0
## a[12]   294.7     0.5   27.2   237.7   294.4   349.5  3215  1.0
## a[13]   323.3     0.9   29.1   273.9   320.3   388.2  1132  1.0
## a[14]   288.6     0.4   20.7   246.2   289.3   328.3  2375  1.0
## a[15]   296.1     0.5   26.2   242.6   296.4   349.1  2601  1.0
## a[16]   292.4     0.6   29.6   229.1   294.2   348.1  2629  1.0
## a[17]   296.5     0.6   28.9   236.6   296.9   354.9  2626  1.0
## a[18]   305.6     0.6   29.9   250.2   303.9   369.8  2412  1.0
## a[19]   275.5     0.9   33.8   199.8   279.9   333.0  1535  1.0
## a[20]   306.9     0.6   27.5   254.6   305.0   365.7  2284  1.0
## a[21]   290.7     0.5   25.6   238.4   291.6   341.3  2834  1.0
## a[22]   327.5     0.9   27.9   280.1   324.8   385.5   933  1.0
## a[23]   497.4     1.1   37.6   423.1   497.4   569.6  1105  1.0
## a[24]   502.8     1.0   34.0   438.0   502.9   570.1  1141  1.0
## a[25]   505.1     1.1   35.4   435.7   504.8   578.7  1101  1.0
## a[26]   495.3     1.1   37.2   423.6   495.9   570.4  1123  1.0
## a[27]   499.7     1.1   34.9   431.9   499.4   567.9  1075  1.0
## a[28]   495.9     1.2   39.5   417.4   496.9   573.9  1045  1.0
## a[29]   493.9     1.0   35.1   424.5   493.8   561.2  1274  1.0
## a[30]   493.3     1.1   39.8   413.8   493.7   570.9  1204  1.0
## b[1]      8.8     0.0    1.8     5.3     8.8    12.2  1861  1.0
## b[2]     17.6     0.1    2.4    12.9    17.6    22.2  1974  1.0
## b[3]     10.8     0.1    2.1     7.0    10.7    15.0  1668  1.0
## b[4]     14.3     0.0    2.2     9.9    14.3    18.5  2071  1.0
## b[5]     33.1     0.0    2.1    29.2    33.1    37.2  2812  1.0
## b[6]     35.8     0.1    3.1    29.3    35.9    41.3  1256  1.0
## b[7]     30.8     0.0    1.2    28.4    30.8    33.2  2411  1.0
## b[8]     25.1     0.0    1.7    21.6    25.2    28.6  2924  1.0
## b[9]     22.4     0.0    1.5    19.7    22.4    25.4  2319  1.0
## b[10]    29.2     0.0    1.8    25.7    29.2    32.9  1407  1.0
## b[11]    32.4     0.0    1.8    28.9    32.4    35.9  3470  1.0
## b[12]    34.2     0.0    1.4    31.4    34.2    37.0  3735  1.0
## b[13]    28.7     0.1    2.3    24.0    28.7    33.2  1814  1.0
## b[14]    24.1     0.0    2.0    20.1    24.1    28.2  2836  1.0
## b[15]    33.7     0.1    2.8    28.3    33.7    39.3  2916  1.0
## b[16]    19.5     0.0    1.4    16.8    19.4    22.3  2904  1.0
## b[17]    28.2     0.0    1.5    25.3    28.2    31.1  2699  1.0
## b[18]    28.7     0.0    2.4    23.7    28.7    33.3  2779  1.0
## b[19]    20.9     0.1    2.4    16.5    20.8    26.1  1976  1.0
## b[20]    33.9     0.0    1.6    30.6    33.9    36.9  2731  1.0
## b[21]    26.5     0.0    2.5    21.6    26.5    31.5  3000  1.0
## b[22]    28.1     0.0    1.4    25.3    28.1    30.6  1257  1.0
## b[23]    13.2     0.0    2.0     9.4    13.2    17.0  1594  1.0
## b[24]    13.9     0.1    2.1     9.7    13.9    17.9  1486  1.0
## b[25]    13.4     0.1    2.4     8.7    13.4    18.3  1318  1.0
## b[26]    12.1     0.1    3.4     5.3    12.0    18.6  1340  1.0
## b[27]    12.7     0.1    3.2     6.4    12.7    18.9  1410  1.0
## b[28]    11.9     0.1    2.9     6.0    11.8    17.8  1366  1.0
## b[29]    11.6     0.1    4.2     3.4    11.6    19.7  1633  1.0
## b[30]    11.6     0.1    2.1     7.5    11.6    15.8  1425  1.0
## s_ag    690.6   279.7 2071.9    61.3   217.1  5292.0    55  1.1
## s_bg    168.4   117.0  718.2     5.3    20.3  1978.0    38  1.1
## s_a      28.1     0.7   12.3     7.4    27.3    54.0   309  1.0
## s_b       4.7     0.0    0.8     3.4     4.7     6.5  3084  1.0
## s_Y      65.2     0.0    2.8    60.1    65.1    71.0  4888  1.0
## lp__  -1585.3     0.8   13.4 -1607.6 -1586.8 -1554.0   259  1.0
## 
## Samples were drawn using NUTS(diag_e) at Sun Dec 24 20:20:13 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

前回同様、Stanの挙動を確認しておく。前回同様a1[g],a[k]に焦点を当てる。

まず、この部分では「業界の全体平均」と「業界差」によって、各企業が属する業界の情報が作られる。

// Gは1-3
for (g in 1:G) {
  a1[g] ~ normal(a0, s_ag);
  b1[g] ~ normal(b0, s_bg);
}
// image
a1[1] -> A1g
a1[2] -> A2g
a1[3] -> A3g

次に、この部分では「各企業が属する業界の平均」と「企業差」によって各企業の情報が作られる。

// Kは1-30
// K2G
// [1] 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3
for (k in 1:K) {
  a[k] ~ normal(a1[K2G[k]], s_a);
  b[k] ~ normal(b1[K2G[k]], s_b);
}

// image: a[k] ~ normal(a1[K2G[k]], s_a)
index01: K2G[ 1] -> 1 -> a1[1] -> A1g -> a[ 1] ~ normal(A1g, s_a) -> A01 
index05: K2G[ 5] -> 2 -> a1[2] -> A2g -> a[ 5] ~ normal(A2g, s_a) -> A05
index15: K2G[15] -> 2 -> a1[2] -> A2g -> a[15] ~ normal(A2g, s_a) -> A15
index23: K2G[23] -> 3 -> a1[3] -> A3g -> a[23] ~ normal(A3g, s_a) -> A23
index30: K2G[30] -> 3 -> a1[3] -> A3g -> a[30] ~ normal(A3g, s_a) -> A30

さらに、さきほどの情報を利用して、下記のイメージで\(Y\)が生成される。

// Nは1-300
// KID
// [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 4 4 4
for (n in 1:N){}
    Y[n] ~ normal(a[KID[n]] + b[KID[n]]*X[n], s_Y);
}

            X    Y KID GID
index1  :   7  457   1   1 | KID[  1] ->  1 -> a[ 1] -> A01 -> Y[  1] ~ normal(A01 + B01*X[  1], s_Y) 
index40 :  22  728   4   1 | KID[ 40] ->  4 -> a[ 4] -> A04 -> Y[ 40] ~ normal(A04 + B04*X[040], s_Y)  
index41 :  20  927   5   2 | KID[ 41] ->  5 -> a[ 5] -> A05 -> Y[ 41] ~ normal(A05 + B05*X[041], s_Y)  
index250:  28 1070  22   2 | KID[250] -> 22 -> a[22] -> A22 -> Y[250] ~ normal(A22 + B22*X[250], s_Y)  
index251:  25  824  23   3 | KID[251] -> 23 -> a[23] -> A23 -> Y[251] ~ normal(A23 + B23*X[251], s_Y)  
index300:  20  727  30   3 | KID[300] -> 30 -> a[30] -> A30 -> Y[300] ~ normal(A30 + B30*X[300], s_Y)  

8.2.5 メカニズムの想像ーその2

さきほどのモデルは、「新卒年収」と「年収昇給額」の平均は業界ごとに異なるが、「新卒年収」と「年収昇給額」の会社差のばらつきは共通としていた。ここでは、「新卒年収」と「年収昇給額」の会社差のばらつきは業界ごとに異なると仮定する。

8.2.6 モデル式の記述ーその2

ここで想定しているモデルは下記の通り。

モデル8-6

\[ \begin{eqnarray} Y[n] &\sim& Normal(a[KID[n]] + b[KID[n]] X[n], \sigma_{Y}[GID[n]]) \\ a_{業界平均}[g] &\sim& Normal(a_{全体平均}, \sigma_{ag}) \\ b_{業界平均}[g] &\sim& Normal(b_{全体平均}, \sigma_{bg}) \\ a[k] &\sim& Normal(a_{業界平均}[K2G[k]], \sigma_{a}[K2G[n]]) \\ b[k] &\sim& Normal(b_{業界平均}[K2G[k]], \sigma_{b}[K2G[n]]) \\ \end{eqnarray} \]

8.2.7 Stanで実装ーその2

先程のモデルとの違いは、s_a[K2G[k]], s_b[K2G[k]]の部分に現れている。

data {
  int N;
  int G;
  int K;
  real X[N];
  real Y[N];
  int<lower=1, upper=K> KID[N];
  int<lower=1, upper=G> K2G[K];
  int<lower=1, upper=G> GID[N];
}

parameters {
  real a0;
  real b0;
  real a1[G];
  real b1[G];
  real a[K];
  real b[K];
  real<lower=0> s_ag;
  real<lower=0> s_bg;
  real<lower=0> s_a[G];
  real<lower=0> s_b[G];
  real<lower=0> s_Y[G];
}

model {
  s_ag ~ normal(0, 1e5);
  s_bg ~ normal(0, 1e5);

  for (g in 1:G) {
    a1[g] ~ normal(a0, s_ag);
    b1[g] ~ normal(b0, s_bg);
  }

  for (k in 1:K) {
    a[k] ~ normal(a1[K2G[k]], s_a[K2G[k]]);
    b[k] ~ normal(b1[K2G[k]], s_b[K2G[k]]);
  }

  for (n in 1:N)
    Y[n] ~ normal(a[KID[n]] + b[KID[n]]*X[n], s_Y[GID[n]]);
}

ここでは、stan_model()関数で最初にコンパイルしておいてから、

model86 <- stan_model('note_ahirubayes08-86.stan')

sampling()関数でサンプリングする。

K2G <- unique(d[ , c('KID','GID')])$GID
data <- list(N = N, G = G, K = K, X = d$X, Y = d$Y, KID = d$KID, K2G = as.numeric(K2G), GID = as.numeric(d$GID))
fit <- sampling(object = model86, data = data, seed = 1989)
print(fit, prob = c(0.025, 0.5, 0.975), digits_summary = 1)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##           mean se_mean     sd    2.5%     50%   97.5% n_eff Rhat
## a0       275.4    85.9 1000.1 -1322.1   394.2  1092.1   136    1
## b0       -12.1    34.1  473.6   -63.2    18.2   102.0   192    1
## a1[1]    373.7     2.7   89.8   197.0   368.0   552.1  1074    1
## a1[2]    300.1     0.6   17.1   266.5   299.6   335.0   847    1
## a1[3]    500.5     0.3   10.0   479.2   500.9   519.7  1239    1
## b1[1]     13.1     0.1    6.1     1.1    13.1    25.8  1945    1
## b1[2]     28.5     0.1    1.7    25.1    28.5    31.9   637    1
## b1[3]     12.4     0.0    0.6    10.9    12.4    13.4  1071    1
## a[1]     382.4     0.3   14.9   353.7   382.5   410.9  1826    1
## a[2]     335.3     0.4   17.4   302.4   335.2   369.7  2133    1
## a[3]     324.3     1.4   34.6   247.4   326.1   387.3   603    1
## a[4]     469.9     4.0  124.6   302.7   440.3   764.7   948    1
## a[5]     292.7     0.7   29.3   231.0   293.3   350.4  1567    1
## a[6]     333.3     1.9   42.4   269.5   325.7   431.2   524    1
## a[7]     310.8     0.8   23.8   267.5   309.6   361.4   956    1
## a[8]     302.2     1.2   36.8   227.3   300.6   382.5   975    1
## a[9]     285.3     0.9   33.1   209.8   288.8   344.9  1447    1
## a[10]    273.1     1.3   25.4   220.9   274.4   321.7   379    1
## a[11]    300.4     0.8   39.1   223.8   298.6   385.5  2143    1
## a[12]    294.7     0.7   30.9   229.5   295.3   356.0  2022    1
## a[13]    324.0     1.3   33.4   267.6   319.8   401.5   692    1
## a[14]    289.1     0.7   24.6   236.7   289.9   335.8  1413    1
## a[15]    295.4     0.7   31.1   230.0   296.2   356.0  1964    1
## a[16]    292.9     1.0   37.7   213.5   294.0   370.2  1404    1
## a[17]    295.3     0.8   33.9   222.7   296.4   361.6  1735    1
## a[18]    307.3     1.2   35.5   242.3   304.0   384.8   926    1
## a[19]    274.1     1.3   41.0   176.9   281.6   338.5  1058    1
## a[20]    306.0     0.7   30.3   247.8   304.2   370.4  1777    1
## a[21]    289.8     0.7   29.4   227.1   291.0   346.2  1836    1
## a[22]    328.0     1.4   33.7   272.6   323.9   401.5   572    1
## a[23]    497.7     0.2   10.9   475.4   498.2   518.8  2067    1
## a[24]    515.4     0.2    8.8   498.6   515.4   532.7  1587    1
## a[25]    524.1     0.3    9.6   505.9   523.9   543.2  1263    1
## a[26]    494.1     0.3   11.0   474.6   493.5   516.5  1082    1
## a[27]    506.9     0.2    8.6   491.1   506.5   525.3  1389    1
## a[28]    497.3     0.5   14.4   472.1   496.6   527.9   758    1
## a[29]    488.5     0.4   10.8   468.8   487.8   510.6   795    1
## a[30]    482.5     0.6   15.0   454.6   481.9   511.7   713    1
## b[1]       7.8     0.0    0.9     6.0     7.8     9.6  1989    1
## b[2]      19.4     0.0    1.3    16.8    19.4    21.9  2409    1
## b[3]      11.9     0.1    1.7     8.8    11.9    15.7   761    1
## b[4]      10.2     0.2    5.0    -1.8    11.5    17.2   979    1
## b[5]      33.2     0.1    2.5    28.5    33.1    38.2  1820    1
## b[6]      35.8     0.1    3.5    28.0    36.0    41.9   921    1
## b[7]      30.7     0.1    1.5    27.7    30.7    33.6   710    1
## b[8]      25.1     0.1    2.1    20.9    25.1    29.2  1153    1
## b[9]      22.5     0.1    1.8    19.2    22.4    26.3  1092    1
## b[10]     29.1     0.1    2.2    24.7    29.1    33.6   474    1
## b[11]     32.4     0.0    2.3    27.8    32.4    36.6  2415    1
## b[12]     34.2     0.0    1.6    31.2    34.2    37.4  2421    1
## b[13]     28.7     0.1    2.6    23.0    28.8    33.5  1295    1
## b[14]     24.0     0.1    2.4    19.4    23.9    28.9  2222    1
## b[15]     33.7     0.1    3.4    27.3    33.6    40.4  1491    1
## b[16]     19.5     0.0    1.7    16.1    19.4    23.2  1615    1
## b[17]     28.2     0.0    1.7    25.0    28.2    31.7  1800    1
## b[18]     28.6     0.1    2.9    22.8    28.7    34.2   751    1
## b[19]     21.0     0.1    2.9    15.8    20.7    27.3  1428    1
## b[20]     34.0     0.0    1.8    30.4    34.0    37.3  2224    1
## b[21]     26.6     0.1    2.9    21.0    26.5    32.5  1922    1
## b[22]     28.0     0.0    1.6    24.7    28.2    30.9  1095    1
## b[23]     13.1     0.0    0.5    12.2    13.1    14.1  2167    1
## b[24]     13.3     0.0    0.5    12.3    13.3    14.2  1600    1
## b[25]     12.5     0.0    0.6    11.4    12.5    13.5  1534    1
## b[26]     12.1     0.0    0.9    10.2    12.2    13.7  1094    1
## b[27]     12.2     0.0    0.7    10.7    12.3    13.5  1403    1
## b[28]     11.7     0.0    0.9     9.7    11.8    13.3   860    1
## b[29]     11.8     0.0    1.1     9.4    12.0    13.6   827    1
## b[30]     12.1     0.0    0.7    10.6    12.1    13.4   773    1
## s_ag     663.7   140.9 1928.0    66.4   222.3  4297.3   187    1
## s_bg     133.2    67.6 1299.6     6.0    18.5   428.4   369    1
## s_a[1]   149.2     7.4  197.2     9.9    91.0   664.3   718    1
## s_a[2]    33.5     1.2   18.8     3.1    31.5    74.9   228    1
## s_a[3]    20.5     0.3    9.4     5.7    19.0    42.7   923    1
## s_b[1]    10.6     0.3    9.2     3.1     7.8    33.7  1146    1
## s_b[2]     5.7     0.0    1.3     3.7     5.5     8.8  1691    1
## s_b[3]     1.0     0.0    0.6     0.2     0.9     2.5   632    1
## s_Y[1]    28.6     0.1    3.8    22.1    28.3    37.1  2469    1
## s_Y[2]    76.2     0.1    4.0    68.8    76.1    84.4  1844    1
## s_Y[3]    11.6     0.0    1.4     9.3    11.5    14.8  2705    1
## lp__   -1474.7     1.2   13.7 -1498.0 -1476.3 -1439.1   141    1
## 
## Samples were drawn using NUTS(diag_e) at Sun Dec 24 20:20:57 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

書籍のモデルでは、事前分布が省略されていたので、省略しないバージョンもメモしておく。結果は似たような結果と鳴る。

model86_ <- stan_model('note_ahirubayes08-86_.stan')
fit <- sampling(object = model86_, data = data, seed = 1989)

結果はこちら。

print(fit, prob = c(0.025, 0.5, 0.975), digits_summary = 1)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##           mean se_mean     sd    2.5%     50%   97.5% n_eff Rhat
## a0       413.5    35.1  698.9  -515.7   396.6  1263.7   396    1
## b0        15.5     2.2   47.8   -58.7    18.3    83.1   477    1
## a1[1]    381.3     2.8   91.4   235.2   367.3   580.0  1051    1
## a1[2]    299.7     0.5   16.9   266.5   299.4   334.7  1069    1
## a1[3]    500.6     0.3   10.1   479.3   500.8   519.9  1200    1
## b1[1]     12.9     0.2    6.0    -0.8    13.1    24.1   928    1
## b1[2]     28.5     0.0    1.7    25.2    28.6    31.8  1258    1
## b1[3]     12.4     0.0    0.6    11.0    12.4    13.4  1165    1
## a[1]     383.4     0.5   15.1   353.5   383.6   411.7   961    1
## a[2]     334.3     0.4   17.3   300.3   334.6   367.9  1893    1
## a[3]     322.3     1.0   33.4   255.9   323.2   384.8  1116    1
## a[4]     470.0     4.4  128.0   301.4   437.8   785.9   851    1
## a[5]     294.5     0.7   29.5   231.0   296.0   353.2  2013    1
## a[6]     329.7     1.8   42.5   267.8   321.1   432.1   564    1
## a[7]     309.6     0.6   23.0   267.1   308.1   356.9  1417    1
## a[8]     301.2     0.7   34.0   230.1   300.3   373.8  2244    1
## a[9]     286.7     0.9   32.0   213.2   290.2   343.8  1197    1
## a[10]    274.0     1.6   26.4   217.9   276.0   319.9   282    1
## a[11]    299.9     0.8   36.4   222.7   299.6   376.0  2162    1
## a[12]    295.3     0.7   30.5   230.0   296.6   354.8  2125    1
## a[13]    321.5     1.2   32.9   267.4   316.9   395.7   753    1
## a[14]    289.6     0.7   24.1   239.3   291.4   334.2  1243    1
## a[15]    295.7     0.7   29.6   232.5   296.5   353.9  2074    1
## a[16]    294.1     0.7   35.8   216.1   296.1   364.2  2292    1
## a[17]    296.3     0.8   33.9   222.0   297.7   363.6  1712    1
## a[18]    305.3     0.8   34.6   238.9   303.2   381.0  1755    1
## a[19]    276.9     1.5   40.1   180.6   283.2   342.8   677    1
## a[20]    306.4     0.7   31.2   248.4   304.4   374.5  2086    1
## a[21]    291.2     0.8   29.8   227.8   293.2   347.1  1415    1
## a[22]    325.7     1.5   32.2   273.0   321.1   397.3   459    1
## a[23]    497.6     0.2   10.7   475.2   498.0   518.2  2537    1
## a[24]    516.1     0.2    9.1   498.4   516.3   533.5  1513    1
## a[25]    525.1     0.2    9.5   505.6   525.1   543.7  1472    1
## a[26]    494.1     0.3   11.0   473.6   493.2   516.9  1257    1
## a[27]    507.2     0.2    8.3   491.6   507.3   524.0  2328    1
## a[28]    497.3     0.5   14.3   471.5   496.4   526.9   867    1
## a[29]    488.5     0.4   10.4   469.7   487.8   510.0   850    1
## a[30]    481.6     0.5   14.5   455.1   480.5   511.4   960    1
## b[1]       7.8     0.0    0.9     6.0     7.7     9.7  1593    1
## b[2]      19.4     0.0    1.3    16.9    19.4    22.0  1294    1
## b[3]      12.1     0.0    1.7     9.0    12.1    15.2  1337    1
## b[4]      10.3     0.2    5.1    -2.0    11.5    17.3   857    1
## b[5]      33.1     0.0    2.5    28.3    33.1    38.0  2545    1
## b[6]      36.0     0.1    3.6    28.3    36.3    42.1   647    1
## b[7]      30.8     0.0    1.5    27.9    30.9    33.6  1376    1
## b[8]      25.2     0.0    2.0    21.1    25.2    29.1  2401    1
## b[9]      22.4     0.0    1.8    19.1    22.4    26.2  1534    1
## b[10]     29.0     0.1    2.3    24.7    29.0    33.7   336    1
## b[11]     32.4     0.0    2.2    28.4    32.3    36.7  1999    1
## b[12]     34.1     0.0    1.6    31.0    34.1    37.4  2109    1
## b[13]     28.8     0.1    2.7    23.4    28.9    33.8   920    1
## b[14]     23.9     0.1    2.3    19.5    23.8    28.7  1724    1
## b[15]     33.8     0.1    3.2    27.5    33.6    40.6  1980    1
## b[16]     19.4     0.0    1.7    16.2    19.4    22.9  2306    1
## b[17]     28.2     0.0    1.7    24.8    28.2    31.6  2477    1
## b[18]     28.8     0.1    2.9    22.8    28.9    34.2  1545    1
## b[19]     20.9     0.1    2.8    15.5    20.7    27.1  1065    1
## b[20]     34.0     0.0    1.8    30.4    34.0    37.5  2250    1
## b[21]     26.5     0.1    2.9    20.7    26.5    32.6  1711    1
## b[22]     28.1     0.0    1.6    24.8    28.1    30.9  1008    1
## b[23]     13.1     0.0    0.5    12.3    13.1    14.1  1336    1
## b[24]     13.2     0.0    0.5    12.3    13.2    14.2  1413    1
## b[25]     12.4     0.0    0.6    11.3    12.4    13.5  1449    1
## b[26]     12.1     0.0    0.9    10.2    12.2    13.7  1164    1
## b[27]     12.2     0.0    0.7    10.8    12.3    13.5  2015    1
## b[28]     11.7     0.0    0.9     9.8    11.8    13.4   851    1
## b[29]     11.9     0.0    1.1     9.4    12.1    13.5   795    1
## b[30]     12.1     0.0    0.7    10.7    12.2    13.4   904    1
## s_ag     628.0    77.7 1842.8    62.9   241.7  3408.0   562    1
## s_bg      51.1     6.8  211.8     5.8    19.6   301.1   982    1
## s_a[1]   142.2     6.1  170.7    11.9    90.2   590.2   785    1
## s_a[2]    31.1     1.7   19.1     2.3    29.6    71.4   129    1
## s_a[3]    21.1     0.3    9.7     6.6    19.4    45.0  1046    1
## s_b[1]    10.6     0.3    9.2     3.2     7.8    34.5  1018    1
## s_b[2]     5.7     0.0    1.2     3.7     5.6     8.6  2518    1
## s_b[3]     1.0     0.0    0.6     0.1     0.9     2.4   664    1
## s_Y[1]    28.6     0.1    3.9    22.1    28.3    37.4  1903    1
## s_Y[2]    76.1     0.1    4.0    68.7    75.8    84.5  3336    1
## s_Y[3]    11.6     0.0    1.4     9.3    11.5    14.8  2066    1
## lp__   -1472.9     1.8   15.8 -1496.8 -1475.7 -1431.3    75    1
## 
## Samples were drawn using NUTS(diag_e) at Sun Dec 24 20:21:34 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

先程と異なる点は標準偏差の部分だけではあるが、こちらのモデルでもaに関して同様に挙動を確認しておく。

// 事前分布
a0 ~ normal(0, 1e5);
b0 ~ normal(0, 1e5);
s_ag ~ normal(0, 1e5);
s_bg ~ normal(0, 1e5);

この部分で、a1[1],a1[2],a1[3],s_a[1],s_a[2],s_a[3]を生成する。

// G:1-3
  for (g in 1:G) {
    a1[g] ~ normal(a0, s_ag);
    b1[g] ~ normal(b0, s_bg);
    s_a[g] ~ normal(0, 1e5);
    s_b[g] ~ normal(0, 1e5);
    s_Y[g] ~ normal(0, 1e5);
  }

// image
a1[1] -> A1g
a1[2] -> A2g
a1[3] -> A3g
s_a[1] -> SA1g
s_a[2] -> SA2g
s_a[3] -> SA3g
s_Y[1] -> SY1g
s_Y[2] -> SY2g
s_Y[3] -> SY3g

次に、この部分では「各企業が属する業界の平均」と「企業差」によって各企業の情報が作られる。このとき、業界ごとに標準偏差も異なるので注意。

// K:1-30
// K2G
// [1] 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3
for (k in 1:K) {
  a[k] ~ normal(a1[K2G[k]], s_a[K2G[k]]);
  b[k] ~ normal(b1[K2G[k]], s_b[K2G[k]]);
}
// image: a[k] ~ normal(a1[K2G[k]], s_a[K2G[k]])
index01: K2G[ 1] -> 1 -> a1[1],s_a[1] -> A1g,SA1g -> a[ 1] ~ normal(A1g, SA1g) -> A01 
index05: K2G[ 5] -> 2 -> a1[2],s_a[2] -> A2g,SA2g -> a[ 5] ~ normal(A2g, SA2g) -> A05
index15: K2G[15] -> 2 -> a1[2],s_a[2] -> A2g,SA2g -> a[15] ~ normal(A2g, SA2g) -> A15
index23: K2G[23] -> 3 -> a1[3],s_a[3] -> A3g,SA3g -> a[23] ~ normal(A3g, SA3g) -> A23
index30: K2G[30] -> 3 -> a1[3],s_a[3] -> A3g,SA3g -> a[30] ~ normal(A3g, SA3g) -> A30

さらに、さきほどの情報を利用して、下記のイメージで\(Y\)が生成される。

// N:1-300
for (n in 1:N)
  Y[n] ~ normal(a[KID[n]] + b[KID[n]]*X[n], s_Y[GID[n]]);
}

// Nは1-300
// KID
// [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 4 4 4
            X    Y KID GID
index1  :   7  457   1   1 | KID[  1] ->  1 -> a[ 1] -> A01 -> Y[  1] ~ normal(A01 + B01*X[  1], SY1g) 
index40 :  22  728   4   1 | KID[ 40] ->  4 -> a[ 4] -> A04 -> Y[ 40] ~ normal(A04 + B04*X[040], SY1g)  
index41 :  20  927   5   2 | KID[ 41] ->  5 -> a[ 5] -> A05 -> Y[ 41] ~ normal(A05 + B05*X[041], SY2g)  
index250:  28 1070  22   2 | KID[250] -> 22 -> a[22] -> A22 -> Y[250] ~ normal(A22 + B22*X[250], SY2g)  
index251:  25  824  23   3 | KID[251] -> 23 -> a[23] -> A23 -> Y[251] ~ normal(A23 + B23*X[251], SY3g)  
index300:  20  727  30   3 | KID[300] -> 30 -> a[30] -> A30 -> Y[300] ~ normal(A30 + B30*X[300], SY3g)