UPDATE: 2024-01-04 04:59:15.903385

はじめに

このノートは「StanとRでベイズ統計モデリング」の内容を写経することで、ベイズ統計への理解を深めていくために作成している。

基本的には気になった部分を写経しながら、ところどころ自分用の補足をメモすることで、「StanとRでベイズ統計モデリング」を読み進めるための自分用の補足資料になることを目指す。私の解釈がおかしく、メモが誤っている場合があるので注意。

今回は第10章「離散値をとるパラメータを使う」のチャプターから写経していく。

11.1 離散パラメタを扱うテクニック

stanでは離散値を取るパラメタを扱うことはできない。そのため、場合の数を足し上げて、離散バラメタを消去した形で対数尤度を表現することで対処する。

11.1.1 log_sum_exp関数

離散バラメタを消去した形で対数尤度を表現するためにはtarget記法と和の対数を計算するlog_sum_exp関数が必要になる。log_sum_exp関数は複数の実数値を引数に取る。

\[ log\_sum\_exp(x, y) = log(exp(x) + exp(y)) \]

ベクトルも渡すことができ、右辺がlog_sum_expという形になっている。

\[ log\_sum\_exp\left( \left( \begin{array}{c} x_1 \\ \vdots \\ x_n \end{array} \right) \right) = log\left( \sum_{k=1}^{K} exp(x_{k}) \right) \]

また、\(exp(x_{k}) = u_{k} \iff x_k = log \ u_k\)とおくと、下記の通り表現ができる。

\[ log\_sum\_exp\left( \left( \begin{array}{c} log \ u_1 \\ \vdots \\ log \ u_n \end{array} \right) \right) = log\left( \sum_{k=1}^{K} u_{k} \right) \]

これは右辺から左辺への変形とみると、logの引数である\(\Sigma\)をベクトルに変換している。そのベクトルの各要素は、\(\Sigma\)の各項目のlogを取ったものである。

11.1.2 周辺化消去

stanでは離散値を取るパラメタを扱うことはできない。そのため、モデル式を変形し、離散パラメタを消去する必要がある。その方法の1つに「場合の数をすべて数え上げて各々の場合の確率を算出して和を取ることで、離散パラメタを消去する」という方法がある。これを周辺化消去という。周辺化消去は「同時に起こることは掛け算、排他的に起こることは足し算」という確率の基本ルールを利用する。下記の例を通りsて理解を深める。

m <- matrix(
  c(3/10, 1/10,
    1/10, 5/10),
  2, 2, byrow = TRUE
)
colnames(m) <- c('y_Red', 'y_White')
rownames(m) <- c('x_BoxA', 'x_BoxB')
m
##        y_Red y_White
## x_BoxA   0.3     0.1
## x_BoxB   0.1     0.5

例えば、同時確率\(p(x_{BoxA},y_{Red})\)\(0.3\)であり、周辺確率\(p(x_{BoxA}),p(x_{BoxB})\)

apply(m, 2, sum)
##   y_Red y_White 
##     0.4     0.6

である。周辺確率は「求めたい変数以外の変数の全ての取り得る確率の足し合わせ」を行った。つまり、各箱において\(p(x_{BoxA}),p(x_{BoxB})\)を消去したことになる。例えば、赤玉の個数は「箱A、箱B」の赤玉を足せば良く、箱が\(N\)個あっても同じように足し合わせればよい。

ベルヌイ分布に従う離散パラメタを扱う例で理解を深める。高校の生徒を対象に喫煙経験を回答するアンケートを行う。ただ、喫煙経験は素直に回答できないので、コインを振って表であれば素直に回答し、裏であれば常にYesと回答する。Yesの場合は\(Y=1\)と回答する。このアンケートから喫煙経験\(q\)を推定したい。

モデル11-1

\[ \begin{eqnarray} coin[n] &\sim& Bernoulli(0.5) \\ \theta[1] &=& q \\ \theta[2] &=& 1.0 \\ Y[n] &\sim& Bernoulli(\theta[coin[n] + 1]) \\ \end{eqnarray} \]

このモデルでは\(coin[n]\)は表なら0、裏なら1をとる確率変数であり、分析者は知ることができない離散パラメタとなる。つまり、\(coin[n]\)は離散パラメタであり、int型のパラメタとして宣言できない。そこでコイントスを場合の数を数え上げて和を取ることで、int型のパラメタを消去する。1人の具体的な例を通して考える。

coin
coin

回答データ(Y=1 or Y=0)を生成するには、図中の上側のルートと下側のルートの2通りの場合があることを表している。この2つは同時に起こらず排他的なので、回答データを生成する確率分布はこれらの和である。

\[ \begin{eqnarray} p(y|q) &=& 0.5 * Bernoulli(y|q) + 0.5 * Bernoulli(y|1.0) \\ \end{eqnarray} \]

途中のコイントスを結果を知らなくても回答がYes、Noになる確率を求めることができる。結果として、1人がアンケートで回答した場合の尤度は下記のとおりである。

\[ \begin{eqnarray} p(Y|q) &=& 0.5 * Bernoulli(Y|q) + 0.5 * Bernoulli(Y|1.0) \\ \end{eqnarray} \]

Stanで実装するには、この尤度の対数をとって対数尤度をを求めて、targetを使った記法で足し込む。

\[ \begin{eqnarray} log \ p(Y|q) &=& \log\_sum\_exp(log \ 0.5 * log \ Bernoulli(Y|q) ,\ log \ 0.5 * log \ Bernoulli(Y|1.0)) \\ \end{eqnarray} \]

Stanのモデルは下記の通り。

data {
  int N;
  int<lower=0, upper=1> Y[N];
}

parameters {
  real<lower=0, upper=1> q;
}

model {
  for (n in 1:N)
    target += log_sum_exp(
      log(0.5) + bernoulli_lpmf(Y[n] | q),
      log(0.5) + bernoulli_lpmf(Y[n] | 1)
    );
}
library(dplyr)
library(rstan)
library(ggplot2)

options(max.print = 999999)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

d <- read.csv('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap11/input/data-coin.txt')
data <- list(N = nrow(d), Y = d$Y)
data
## $N
## [1] 100
## 
## $Y
##   [1] 1 0 1 0 0 1 1 0 1 1 0 1 0 1 1 0 1 1 1 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 1 1 0
##  [38] 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 0 1 1 1 1 0 0 1 0 1 1 1 0 1 0 0 0 1 0 0 0 1
##  [75] 1 1 1 0 1 1 1 0 1 0 1 1 0 0 0 1 1 0 1 0 1 1 0 1 1 1

ここでは、stan_model()関数で最初にコンパイルしておいてから、

model111 <- stan_model('note_ahirubayes14-111.stan')

sampling()関数でサンプリングする。

fit <- sampling(object = model111, data = data, seed = 1989)

推定結果は下記の通り。

print(fit, prob = c(0.025, 0.5, 0.975), digits_summary = 1)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##       mean se_mean  sd  2.5%   50% 97.5% n_eff Rhat
## q      0.2       0 0.1   0.0   0.2   0.4  1066    1
## lp__ -69.8       0 1.0 -72.5 -69.4 -69.1   617    1
## 
## Samples were drawn using NUTS(diag_e) at Thu Jan  4 04:59:19 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

似たような例として、ポアソン分布を使った例を考える。パラメタ\(\lambda\)のポアソン分布から乱数を発生させ、出力させた値\(m\)の枚数のコインを一度に投げ、そのうち表は出た枚数を\(Y\)とする。ここでの\(m[n]\)は離散パラメタである。

モデル11-2

\[ \begin{eqnarray} m[n] &\sim& Poisson(\lambda) \\ Y[n] &\sim& Binomial(m[n], 0.5) \\ \end{eqnarray} \]

このケースでは、\(m\)の場合の数を数え上げて和をとって消去する。1つ分の尤度は下記の通り。

\[ \begin{eqnarray} p(Y|\lambda) &=& Poisson(0|\lambda)×binomial(Y|0,0.5) \\ &+& Poisson(1|\lambda)×binomial(Y|1,0.5) \\ &+& Poisson(2|\lambda)×binomial(Y|2,0.5) \\ &+& ... \end{eqnarray} \]

\(Y\)の最大値は9なので、\(m\)について高々40までを考慮すればよい。なぜなら、\(Binomial(9|40,0.5)\)の確率は\(Binomial(9|m,0.5)\)と比べて非常に小さいため、40以上は無視できる。

\[ \begin{eqnarray} p(Y|\lambda) &=& \sum_{m=Y}^{40} \left[ Poisson(m|\lambda)×Binomial(Y|m,0.5)\right] \end{eqnarray} \]

これに対数をとって対数尤度を計算する。

\[ \begin{eqnarray} \log \ p(Y|\lambda) &=& \log \left(\sum_{m=Y}^{40} \left[ Poisson(m|\lambda)×Binomial(Y|m,0.5)\right]\right) \\ &=& log\_sum\_exp \left( \begin{array}{c} log \ Poisson(Y|\lambda)×log \ Binomial(Y|Y,0.5) \\ log \ Poisson(Y+1|\lambda)×log \ Binomial(Y+1|Y,0.5) \\ \vdots \\ log \ Poisson(40|\lambda)×log \ Binomial(40|Y,0.5) \\ \end{array} \right) \end{eqnarray} \]

Stanのモデルは下記の通り。

  int N;
  int M_max;
  int<lower=0> Y[N];
}

parameters {
  real<lower=0> lambda;
}

model {
  for (n in 1:N) {
    vector[M_max-Y[n]+1] lp;
    for (m in Y[n]:M_max)
      lp[m-Y[n]+1] = poisson_lpmf(m | lambda) + binomial_lpmf(Y[n] | m, 0.5);
    target += log_sum_exp(lp);
  }
}
d <- read.csv('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap11/input/data-poisson-binomial.txt')
data <- list(N = nrow(d), M_max = 40, Y = d$Y)
data
## $N
## [1] 100
## 
## $M_max
## [1] 40
## 
## $Y
##   [1] 6 5 8 2 5 7 6 3 3 7 5 6 3 3 7 4 2 2 4 4 4 3 2 2 4 4 5 9 7 5 5 6 4 5 5 2 6
##  [38] 4 9 5 6 5 9 5 5 4 4 3 5 9 0 1 3 7 5 7 3 3 7 7 4 6 3 2 3 2 4 2 4 4 3 3 7 5
##  [75] 8 8 3 4 7 5 2 8 7 6 4 5 1 8 7 4 3 9 3 6 5 5 3 5 6 9

ここでは、stan_model()関数で最初にコンパイルしておいてから、

model112 <- stan_model('note_ahirubayes14-112.stan')

sampling()関数でサンプリングする。

fit <- sampling(object = model112, data = data, seed = 1989)

推定結果は下記の通り。

print(fit, prob = c(0.025, 0.5, 0.975), digits_summary = 1)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##          mean se_mean  sd   2.5%    50%  97.5% n_eff Rhat
## lambda    9.6       0 0.4    8.8    9.6   10.5  1374    1
## lp__   -212.0       0 0.7 -214.1 -211.7 -211.5  1714    1
## 
## Samples were drawn using NUTS(diag_e) at Thu Jan  4 04:59:22 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

11.1.3 公式の活用

ポアソン分布と二項分布は、すでに知られる公式によって合成が可能であり、

モデル11-3

\[ \begin{eqnarray} m &\sim& Poisson(\lambda) \\ Y &\sim& Binomial(m, p) \\ Y &\sim& Poisson(\lambda p) \\ \end{eqnarray} \]

合成した結果、下記の通り表現できる。

モデル11-4

\[ \begin{eqnarray} Y &\sim& Poisson(\lambda p) \\ \end{eqnarray} \]

Stanのモデルは下記の通りシンプルに表現できる。\(p=0.5\)とした場合である。

data {
  int N;
  int<lower=0> Y[N];
}

parameters {
  real<lower=0> lambda;
}

model {
  for (n in 1:N)
    Y[n] ~ poisson(lambda*0.5);
}

11.2 混合正規分布

100人分の社員の能力測定値のスコアを使って混合正規分布と離散パラメタへの理解を深める。混合正規分布は確率\(a\)で1つ目の正規分布から生成され、確率\(1-a\)で2つ目の正規分布から生成される。

\[ Normal\_Mixture(a, \mu_{1}, \mu_{2}, \sigma_{1}, \sigma_{2}) = a × N(y|\mu_{1},\sigma_{1}) + (1-a) × N(y|\mu_{2},\sigma_{2}) \]

混合正規分布を使ってモデル式を表現すると、下記のようになる。

モデル11-5

\[ \begin{eqnarray} Y[n] \sim Normal\_Mixture(a, \mu_{1}, \mu_{2}, \sigma_{1}, \sigma_{2}) \end{eqnarray} \]

Stanのモデル式は下記の通り。log1m(a)関数は、log(1-a)をより安定して計算するための便利関数。

data {
  int N;
  vector[N] Y;
}

parameters {
  real<lower=0, upper=1> a;
  ordered[2] mu;
  vector<lower=0>[2] sigma;
}

model {
  for (n in 1:N)
    target += log_sum_exp(
      log(a)   + normal_lpdf(Y[n] | mu[1], sigma[1]),
      log1m(a) + normal_lpdf(Y[n] | mu[2], sigma[2])
    );
}

ここでは混合正規分布の混合数を\(K\)個に拡張した例をメモしておく。

d <- read.csv('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap11/input/data-mix2.txt')
K <- 5
data <- list(N = nrow(d), K = K, Y = d$Y)
init <- list(a = rep(1,K)/K, mu = seq(10,40,len = K), s_mu = 20, sigma = rep(1,K))
data
## $N
## [1] 200
## 
## $K
## [1] 5
## 
## $Y
##   [1]  9.44  9.77 11.56 10.07 10.13 11.72 10.46  8.73  9.31  9.55 11.22 10.36
##  [13] 10.40 10.11  9.44 11.79 10.50  8.03 10.70  9.53  8.93  9.78  8.97  9.27
##  [25]  9.37  8.31 10.84 10.15  8.86 11.25 10.43  9.70 10.90 10.88 17.23 17.03
##  [37] 16.83 15.91 15.54 15.43 14.96 15.69 14.10 19.25 17.81 14.32 15.40 15.30
##  [49] 17.17 15.87 16.38 15.96 15.94 18.05 20.80 22.36 19.61 21.53 21.11 21.19
##  [61] 21.34 20.55 20.70 20.08 20.04 21.27 21.40 21.05 21.83 22.85 20.56 18.92
##  [73] 21.91 20.36 20.38 21.92 20.74 19.90 21.16 20.87 21.01 21.35 20.67 21.58
##  [85] 20.80 21.30 21.99 21.39 20.71 22.03 21.89 21.49 21.21 20.43 22.22 20.46
##  [97] 22.97 22.38 20.79 20.08 20.36 21.23 20.78 20.69 20.14 20.96 20.29 19.50
## [109] 20.66 21.83 20.48 21.55 19.54 20.95 21.47 21.27 21.10 20.42 20.24 20.08
## [121] 21.11 20.15 20.56 20.77 22.66 20.41 21.21 21.07 20.13 20.94 22.30 21.41
## [133] 21.04 20.62 22.24 25.74 22.89 25.31 26.60 22.91 25.27 24.21 22.77 22.83
## [145] 22.74 23.92 22.89 25.26 26.81 23.08 25.37 25.35 24.87 23.39 24.37 24.19
## [157] 25.12 24.09 25.57 24.09 25.66 23.35 23.11 28.07 24.04 24.83 25.20 23.97
## [169] 25.07 24.91 24.26 24.57 24.46 26.84 30.15 29.26 32.09 32.78 33.09 30.85
## [181] 29.34 35.16 31.13 29.84 31.41 31.51 34.77 32.21 33.89 30.75 32.54 31.19
## [193] 32.24 29.76 28.72 36.99 33.50 28.87 30.47 29.04

可視化するとこのような分布である。

dens <- density(d$Y)

ggplot(data = d, aes(x = Y)) +
  theme_bw(base_size = 18) +
  geom_histogram(color = 'black', fill = 'white') +
  geom_density(aes(y = after_stat(count)), alpha = 0.35, colour = 'black', fill = 'gray20') +
  geom_rug(sides = 'b') +
  labs(x = 'Y') + xlim(range(dens$x)) 

\[ Normal\_Mixture(y|\overrightarrow{ a }, \overrightarrow{ \mu }, \overrightarrow{ \sigma }) = \sum_{k=1}^{K} a_{k} Normal(y|\mu_{k}, \sigma_{k}) \]

モデル式は下記の通り。

モデル11-6

\[ \begin{eqnarray} \log Normal\_Mixture(y|\overrightarrow{ a }, \overrightarrow{ \mu }, \overrightarrow{ \sigma }) &=& \log \left[ \sum_{k=1}^{K} a_{k} Normal(y|\mu_{k}, \sigma_{k})\right] \\ &=& log\_sum\_exp \left( \begin{array}{c} \log{a_{1} Normal(y|\mu_{1},\sigma_{1})} \\ \vdots \\ \log{a_{k} Normal(y|\mu_{k},\sigma_{k})} \\ \end{array} \right) \\ &=& log\_sum\_exp \left( \begin{array}{c} \log{a_{1} + Normal\_lpdf(y|\mu_{1},\sigma_{1})} \\ \vdots \\ \log{a_{k} + Normal\_lpdf(y|\mu_{k},\sigma_{k})} \\ \end{array} \right) \end{eqnarray} \]

Stanのモデルは下記の通り。\(\overrightarrow{ a }\)は合計すると1になるためsimplex型で定義できる。また、位置パラメタ\(\mu\)は混合正規分布の場合、順序関係を持たせることが可能なのでorder型で定義できる。

data {
  int N;
  int K;
  vector[N] Y;
}

parameters {
  simplex[K] a;
  ordered[K] mu;
  vector<lower=0>[K] sigma;
  real<lower=0> s_mu;
}

model {
  mu ~ normal(mean(Y), s_mu);
  sigma ~ gamma(1.5, 1.0);
  for (n in 1:N) {
    vector[K] lp;
    for (k in 1:K)
      lp[k] = log(a[k]) + normal_lpdf(Y[n] | mu[k], sigma[k]);
    target += log_sum_exp(lp);
  }
}

ここでは、stan_model()関数で最初にコンパイルしておいてから、

model114 <- stan_model('note_ahirubayes14-114.stan')

sampling()関数でサンプリングする。

fit <- sampling(object = model114, data = data, seed = 1989)

推定結果は下記の通り。

print(fit, prob = c(0.025, 0.5, 0.975), digits_summary = 1)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##            mean se_mean  sd   2.5%    50%  97.5% n_eff Rhat
## a[1]        0.2     0.0 0.0    0.1    0.2    0.2  3961    1
## a[2]        0.1     0.0 0.0    0.1    0.1    0.2   865    1
## a[3]        0.3     0.0 0.1    0.2    0.3    0.4  1243    1
## a[4]        0.3     0.0 0.1    0.1    0.3    0.4   859    1
## a[5]        0.2     0.0 0.0    0.1    0.1    0.2   564    1
## mu[1]      10.0     0.0 0.2    9.7   10.0   10.3  3653    1
## mu[2]      16.1     0.0 0.5   15.4   16.1   17.3   541    1
## mu[3]      20.9     0.0 0.1   20.6   20.9   21.2  2701    1
## mu[4]      23.5     0.0 0.6   22.2   23.5   24.7  1178    1
## mu[5]      30.9     0.0 1.1   27.9   31.1   32.2   526    1
## sigma[1]    1.0     0.0 0.1    0.8    1.0    1.3  3484    1
## sigma[2]    1.3     0.0 0.5    0.7    1.2    2.7   621    1
## sigma[3]    0.7     0.0 0.1    0.4    0.7    1.0  1451    1
## sigma[4]    1.8     0.0 0.5    0.9    1.8    2.8  1083    1
## sigma[5]    2.7     0.0 0.7    1.8    2.5    4.6   707    1
## s_mu       10.0     0.2 6.3    4.7    8.6   23.0  1253    1
## lp__     -601.1     0.1 3.2 -608.7 -600.7 -596.0   978    1
## 
## Samples were drawn using NUTS(diag_e) at Thu Jan  4 04:59:33 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).