UPDATE: 2024-01-07 11:28:22.084485

はじめに

このノートは「ベイズ統計」に関する何らかの内容をまとめ、ベイズ統計への理解を深めていくために作成している。

今回はStanのチュートリアルにもある「Eight Schools」の問題を例に階層ベイスモデルへの理解を深める。Bayesian Data Analysisのチャプター5「Hierarchical models」と、非常にわかりやすい解説がされているこちらの記事を、参考にさせていただいた。

Eight Schoolsの問題と分析目的

TensorFlowのチュートリアルに、ここで扱うEight Schoolsの問題が説明されていたので、内容をお借りする。

『Bayesian Data Analysis』第 5.5 項(Gelman et al. 2013)の抄訳: 8つの高校で実施された SAT-V(Scholastic Aptitude Test-Verbal)特別指導プログラムの効果を分析するために、Educational Testing Service の調査が実施されました。各調査の結果変数は、SAT-V の特別実施のスコアでした。これは、Educational Testing Service が運営し、大学が入学決定を下す際に使用する標準化された多肢選択式テストです。スコアは 200 点から 800 点の間で変動し、平均は約 500 点、標準偏差は約 100 点です。SA​​T 試験は、特に試験の成績を上げるための短期的な取り組みを評価するものではなく、長期にわたる学習で得た知識と能力の開発を反映するように設計されています。それにもかかわらず、この調査の対象となった 8 つの高校では、その短期指導プログラムが SAT のスコアの引き上げに非常に役立つと考えました。また、8 つのプログラムのいずれも、他のプログラムよりも効果的であったり、いくつかのプログラムの効果の類似性が他のプログラムのものよりも高かったりすることを予め信じる理由はありませんでした。

要するに、8つの高校で、特別指導プログラムを実施。特別指導プログラムの前後でテストを受け、「どのくらい」テストの点数が上がったのか(下がったのか)が記録されている。特別指導プログラムは何点くらいの引き上げ効果があるのかを知りたい。

library(dplyr)
library(rstan)
library(ggplot2)

options(max.print = 999999)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

d <- data.frame(
  school = c('A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'),
  y = c(28, 8, -3, 7, -1, 1, 18, 12),
  s = c(15, 10, 16, 11, 9, 11, 10, 18)
  )
j <- nrow(d)
d
##   school  y  s
## 1      A 28 15
## 2      B  8 10
## 3      C -3 16
## 4      D  7 11
## 5      E -1  9
## 6      F  1 11
## 7      G 18 10
## 8      H 12 18

高校Aでは前後で28点の点数(標準誤差は15)の上昇が見られ、特別指導プログラムは有効に見える一方で、高校Cでは点数が-3(標準誤差は16)の下降が見られ、特別指導プログラムの効果はないようにも見える。さて、このような状況下で、どのようにすれば特別指導プログラムの効果を妥当な形で推定することができるだろうか。これがEight Schools問題。

ggplot(data = d) +
  theme_bw(base_size = 15) +
  geom_point(aes(school, y)) + 
  geom_errorbar(aes(school, y, ymax = y + s, ymin = y - s), width = 0.2) + 
  geom_hline(aes(yintercept = mean(d$y)), col = 'tomato', linetype = 'dashed') +
  labs(x = 'school', y = 'y', title = '8 School Problem')

メカニズムの想像

参考にしている記事にもある通り、3つのメカニズムを想定できる。ここでも同じように3つのモデルを想定する。

Models
Models

モデル1

各高校\(i=A,B,...,H\)ごとの点数を\(Y_{i}\)、標準誤差を\(S_{i}\)、指導効果を\(\theta_{i}\)とすると、\(Y_{i}\)は平均\(\theta_{i}\)、標準偏差\(S_{i}\)の正規分布から生成されると考える。8つの高校は互いに無関係。各高校の効果はその高校のデータのみで推定する。

モデル1

\[ \begin{eqnarray} Y_{i} \sim Normal(\theta_{i}, S_{i}) \end{eqnarray} \]

無情報事前分布を使い、8つ高校を個別に推定すると、各校の事後分布の95%信用区間は下記の通り。

# 各列ごとに各パラメタから乱数が生成され1000×8列
# matrix(rnorm(30, mean = c(10,100,1000), sd = c(1,5,10)), nrow = 10, byrow = TRUE)
m <- matrix(rnorm(1000 * j, mean = d$y, sd = d$s), nrow = 1000, byrow = TRUE)
tibble(
  median = apply(m, 2, median),
  lower  = apply(m, 2, function(x) quantile(x, .025)),
  upper  = apply(m, 2, function(x) quantile(x, .975))
) %>% 
  bind_cols(d) %>% 
  ggplot() +
  theme_bw(base_size = 15) +
  geom_point(aes(x = school, y = median), col = 'tomato', size = 3) +
  geom_errorbar(aes(school, median, ymax = upper, ymin = lower), width = 0.2, col = 'tomato') +
  geom_point(aes(x = school, y = y), alpha = 1/2) +
  labs(x = 'school', y = 'y', title = '8 School Problem')

推定を別々に行っており、事後分布の中央値の点推定値は、\(θ_{i}=y_{i}\)の近くに分布している。推定のばらつきを見ると、すべての高校の分布が重なっているので、各高校が別々の\(θ_{i}\)を持っているという前提はあまり妥当ではない。

モデル2

全高校共通で効果\(\theta\)が1つあると仮定する。そして、この\(\theta\)の正規分布から8つ高校のデータが生成されたと考える。モデル1とは異なり、全ての高校が\(\theta\)を通して関係を持っている。

モデル2

\[ \begin{eqnarray} Y_{i} \sim Normal(\theta, S) \end{eqnarray} \]

Bayesian Data Analysis](http://www.stat.columbia.edu/~gelman/book/)のp114,p120あたり、今回のようなデータ生成過程に対する事後分布の平均と標準偏差に関する計算方法が記載されている

# Bayesian Data Analysis: http://www.stat.columbia.edu/~gelman/book/
# p114, p120に計算式の理由は記載されている
# 標準誤差が大きい=特別指導プログラムの生徒が少ない=Yが大きくばらつく。
# 標準誤差の逆数で、標準誤差が大きいものの影響は小さくするというようなニュアンス
pool_s2 <- sum(1/(d$s)^2)
pool_y <- sum(1/d$s^2 * d$y)/pool_s2
set.seed(1989)
pool_post <- rnorm(10000, mean = pool_y, sd = sqrt(pool_s2^(-1)))
quantile(pool_post, c(.025, .5, .975))
##       2.5%        50%      97.5% 
## -0.2516728  7.6848589 15.4872032

1つの\(\theta\)によって生成されると仮定した場合、高校Aで観察された効果\(\theta_{H}=28\)が、このような仮定のもとでの分布から出てくる可能性は低いため、このモデルではデータのばらつきをうまく説明できない。

sum(pool_post > 28)
## [1] 0

モデル3

どの高校でも共通効果\(\mu\)を考える。そして、各高校ごとの効果である\(\theta_{i}\)は、共通効果\(\mu\)と標準偏差\(\sigma\)をもつ正規分布から生成されたと考える。同じ特別指導プログラムを実施しているので、同じような共通効果\(\mu\)が存在しており、違いは各高校ごとの生徒、指導者、何らかの高校に由来する差によるものと考える。

このモデルでは各高校の\(\theta_{i}\)が1つ1つの高校差を表しているのと同時に、各\(\theta_{i}\)が同じ正規分布から生成されたと考えることで、高校同士が互いに関係を持つことになる。パラメタの関係性が階層構造になっているので、階層ベイスモデルと呼ばれる。

高校に由来する差の大きさは\(\sigma\)によってコントロールされる。この\(\sigma\)が仮に\(\infty\)だとどうなるかというと、高校差が無限に大きくなる。つまり、各高校の\(\theta_{i}\)はどんな値でも取れる自由な状態となり、モデル1と同じになる。一方で、この\(\sigma\)が仮に0だとどうなるかというと、\(\theta_{i}\)はどの高校でも同じになるため、高校差はなくなり、モデル2と同じになる。

モデル3

\[ \begin{eqnarray} Y_{i} &\sim& Normal(\theta_{i}, S_{i}) \\ \theta_{i} &\sim& Normal(\mu, \sigma) \\ \end{eqnarray} \]

モデル3を実装する

モデル3のモデル構造をStanのモデルで書き直す。reparameterizationしたバージョンのモデル式ではなく、素直にモデル化している。

data {
  int<lower=1> J;
  real Y[J];
  real<lower=0> S[J];
}
parameters {
  real theta[J];
  real mu;
  real<lower=0> sigma;
}
model {
  for (j in 1:J) {
    theta[j] ~ normal(mu, sigma);
  }

  for (j in 1:J) {
    Y[j] ~ normal(theta[j], S[j]);
  }
}

データを用意して、

data <- list(Y = d$y, S = d$s, J = j)
data
## $Y
## [1] 28  8 -3  7 -1  1 18 12
## 
## $S
## [1] 15 10 16 11  9 11 10 18
## 
## $J
## [1] 8

ここでは、stan_model()関数で最初にコンパイルしておいてから、

model <- stan_model('note_bayes01−001.stan')

sampling()関数でサンプリングする。

fit <- sampling(object = model, data = data, seed = 1989)

推定結果を確認する。

  • 特別指導プログラムの効果\(\mu\)は7.64程度の点数の上昇を期待できる
  • ただ、95%信用区間が[-1.73, 17.71]なので、点数が下がる可能性もある
  • 高校差を表す\(\sigma\)は6.89程度のばらつきがある
print(fit)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##            mean se_mean   sd   2.5%    25%    50%    75% 97.5% n_eff Rhat
## theta[1]  11.24    0.37 8.55  -2.51   5.86  10.05  15.67 31.49   536 1.01
## theta[2]   7.83    0.21 6.23  -3.83   3.92   7.55  11.69 20.73   852 1.01
## theta[3]   5.99    0.21 7.78 -10.91   1.75   6.47  10.75 20.55  1393 1.01
## theta[4]   7.30    0.20 6.54  -5.84   3.19   7.18  11.40 20.65  1068 1.01
## theta[5]   4.76    0.24 6.35  -9.37   0.93   5.18   8.96 16.31   678 1.01
## theta[6]   5.78    0.22 6.75  -9.45   1.87   5.93   9.91 18.57   937 1.01
## theta[7]  10.47    0.26 6.86  -1.13   6.08   9.61  14.57 25.18   716 1.01
## theta[8]   8.12    0.20 7.90  -7.61   3.32   7.66  12.61 24.78  1627 1.00
## mu         7.64    0.20 5.08  -1.73   4.27   7.50  10.94 17.71   667 1.01
## sigma      6.89    0.34 5.42   1.31   2.93   5.52   9.27 20.59   251 1.02
## lp__     -17.76    0.49 5.38 -27.62 -21.67 -18.19 -13.88 -7.21   119 1.03
## 
## Samples were drawn using NUTS(diag_e) at Sun Jan  7 11:28:26 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

事後分布を可視化しておく。

plot(fit)

# 事後分布を可視化する関数
plot_posterior <- function(posterior_samples, 
                           title = 'Posterior Distribution',
                           x_title = 'Parameter',
                           xlim = NULL) {
  
  if (is.null(xlim)) {
    range <- range(posterior_samples) 
    xlim <- c(range[1] - abs(diff(range)) * 0.05, 
              range[2] + abs(diff(range)) * 0.05)
  }
  
  q <- quantile(posterior_samples, probs = c(0.025, 0.975))
  m <- mean(posterior_samples)
  
  ggplot(data.frame(x = posterior_samples), aes(x = x)) +
    geom_histogram(aes(y = ..density..), color = 'black', bins = 50, alpha = 0.2) +
    geom_density(fill = "steelblue", alpha = 0.5) + 
    geom_vline(xintercept = q[[1]], linetype = 3) + 
    geom_vline(xintercept = q[[2]], linetype = 3) + 
    geom_vline(xintercept = m, linetype = 3) + 
    labs(title = title,  
         x = x_title,
         y = "Density") + 
    scale_x_continuous(breaks = scales::pretty_breaks(n = 10)) + 
    xlim(xlim[[1]], xlim[[2]]) + 
    theme_minimal() +
    theme(plot.title = element_text(size = 14, face = "bold"),
          axis.title = element_text(size = 12))
}

out <- rstan::extract(fit)
plot_posterior(
    out$mu, 
    title = 'Posterior Probability Distribution of mu', 
    x_title = 'mu',
    xlim = c(-10, 30)
)