UPDATE: 2024-01-21 16:55:31.439467

はじめに

このノートは「StanとRでベイズ統計モデリング」の内容を写経することで、ベイズ統計への理解を深めていくために作成している。ここでは制約付きデータ型であるordered型への理解を先に深めておく。

制約付きデータ型のordered型

Stan でパラメータに大小関係に制約をつける場合、制約付きデータ型のordered型を利用するが、これのメリットや挙動がよくわかっていない。ドキュメントには下記のような記載がある。

あるモデリングタスクでは、順序関係をもつベクトル値確率変数Xが必要な場合がある。1つの例は,順序ロジスティック回帰におけるカットポイントの集合である。制約条件ではKベクトル\(x \in \mathbb{R}^K\)は下記を満たす

\[ x_{k} \lt x_{k+1} \\ for \ k \in \{ 1, \ldots, K-1 \} \]

ドキュメントを読む限り、これ以上情報は見つけられなかったので、ネットで調べたところ、下記のありがたい記事を見つけることができた。この記事を参考にordered型への理解を深める。

サンプルデータは、ECサイトの読み込み時間とCVの関係に関するデータ。読込み時間が短くなると、CV率は高くなることがわかっている前提。この関係がわかっていることが、あとあと重要になる。まずは何も考慮せずに素直にモデリングを行う。ここでは

library(tidyverse)
library(rstan)

options(max.print = 999999)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

true_cvr <- c(0.209, 0.126, 0.096, 0.093, 0.086, 0.077, 0.067, 0.057)
load_time <- c("00-01", "01-03", "03-07", "07-13", "13-21", "21-35", "35-60", "60+")
session <- c(1000, 6000, 4000, 1500, 700, 500, 200, 150)
set.seed(71)
cv <- unlist(Map(function(n, p) rbinom(1, n, p), session, true_cvr))
data <- list(N = length(cv), cv = cv, session = session, load_time = load_time)
data.frame(load_time, cv, session)
##   load_time  cv session
## 1     00-01 208    1000
## 2     01-03 769    6000
## 3     03-07 366    4000
## 4     07-13 142    1500
## 5     13-21  54     700
## 6     21-35  35     500
## 7     35-60  17     200
## 8       60+  11     150

モデルはこちら。

data {
  int<lower=0> N;
  int<lower=0> cv[N];
  int<lower=0> session[N];
}
parameters {
  real<lower=0, upper=1> cvr[N];
}
model {
  cv ~ binomial(session, cvr);
}

ここでは、stan_model()関数で最初にコンパイルしておいてから、

model_or1 <- stan_model('model-ordered1.stan')

sampling()関数でサンプリングする。

fit_or1 <- sampling(object = model_or1, data = data, seed = 1989)

推定結果は下記の通り。

print(fit_or1, prob = c(0.025, 0.5, 0.975), digits_summary = 3)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##             mean se_mean    sd      2.5%       50%     97.5% n_eff  Rhat
## cvr[1]     0.209   0.000 0.013     0.184     0.209     0.234  6940 0.999
## cvr[2]     0.128   0.000 0.004     0.120     0.128     0.136  6612 0.999
## cvr[3]     0.092   0.000 0.005     0.083     0.092     0.101  6636 0.999
## cvr[4]     0.095   0.000 0.008     0.081     0.095     0.111  7200 1.000
## cvr[5]     0.078   0.000 0.010     0.059     0.078     0.099  7694 1.000
## cvr[6]     0.071   0.000 0.012     0.051     0.071     0.096  6987 0.999
## cvr[7]     0.089   0.000 0.020     0.054     0.088     0.132  6698 1.000
## cvr[8]     0.079   0.000 0.021     0.043     0.077     0.124  7058 0.999
## lp__   -4940.378   0.045 1.962 -4945.090 -4940.042 -4937.533  1879 1.000
## 
## Samples were drawn using NUTS(diag_e) at Sun Jan 21 16:55:33 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

推定結果を可視化すると、上手く推定できていない。07-1303-07よりもcv率が高くなっているが、これは本来の関係とは異なる。35-60, 60+も同様である。

ms1 <- rstan::extract(fit_or1, pars = 'cvr')$cvr 

df <- ms1 %>%
  data.frame() %>%
  setNames(load_time) %>%
  pivot_longer(
    cols = everything(),
    names_to = 'load_time',
    values_to = 'cvr'
  )
  
cvr_hat <- colMeans(ms1)
point_df <- data.frame(load_time, true_cvr, cvr_hat) %>%
  pivot_longer(
    cols = -load_time,
    names_to = 'cvr_type',
    values_to = 'cvr'
  ) %>% 
  arrange(cvr_type)

ggplot(df, aes(x = load_time, y = cvr)) + 
  geom_violin() + ylim(0, NA) +
  geom_point(data = point_df, aes(x = load_time, y = cvr, col = cvr_type)) + 
  theme_bw()

ここで事前知識として持っている「読込み時間が短くなると、cv率は高くなることがわかっているとする」という関係を利用する。parametersブロックで、ordered型のパラメタを作成する。モデルはこちら。

data {
  int<lower=0> N;
  int<lower=0> cv[N];
  int<lower=0> session[N];
}
parameters {
  ordered[N] cvr_rev;
}
transformed parameters {
  real<lower=0, upper=1> cvr[N];
  for(i in 1:N) {
    cvr[i] <- inv_logit(cvr_rev[N - i + 1]);
  }
}
model {
  cv ~ binomial(session, cvr);
}

ordered型は、「小さい順」という制約である一方で、今回ここで求めたいパラメタは大きい順。つまり、これを逆順にする必要がある。少し理解しにくいので、この点は後で補足する。

そして、ordered型にはlowerやupperの制約がつけられないため、\(-\infty \sim +\infty\)までの値を取ってしまう。cvrは0-1の範囲である必要があるので、cvr_revをサンプリングすると、範囲外の値を取ってしまう。そのため、inv_logit()関数(logistic関数)で\(-\infty \sim +\infty\)の値を0-1に変換する。

ここでは、stan_model()関数で最初にコンパイルしておいてから、

model_or2 <- stan_model('model-ordered2.stan')

sampling()関数でサンプリングする。

fit_or2 <- sampling(object = model_or2, data = data, seed = 1989)

推定結果は下記の通り。cvr_revがマイナスの値を取っていることがわかる。

print(fit_or2, prob = c(0.025, 0.5, 0.975), digits_summary = 3)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##                 mean se_mean    sd      2.5%       50%     97.5% n_eff  Rhat
## cvr_rev[1]    -2.911   0.006 0.204    -3.358    -2.888    -2.572  1329 1.004
## cvr_rev[2]    -2.683   0.002 0.133    -2.967    -2.670    -2.456  2922 1.001
## cvr_rev[3]    -2.561   0.002 0.103    -2.775    -2.557    -2.376  3048 1.000
## cvr_rev[4]    -2.447   0.002 0.085    -2.631    -2.441    -2.297  2974 1.000
## cvr_rev[5]    -2.323   0.001 0.056    -2.439    -2.319    -2.222  2453 1.001
## cvr_rev[6]    -2.262   0.001 0.048    -2.355    -2.263    -2.169  2518 1.001
## cvr_rev[7]    -1.917   0.001 0.038    -1.991    -1.917    -1.843  3082 1.000
## cvr_rev[8]    -1.337   0.001 0.077    -1.489    -1.338    -1.190  2650 1.001
## cvr[1]         0.208   0.000 0.013     0.184     0.208     0.233  2708 1.001
## cvr[2]         0.128   0.000 0.004     0.120     0.128     0.137  3096 1.000
## cvr[3]         0.094   0.000 0.004     0.087     0.094     0.103  2483 1.001
## cvr[4]         0.089   0.000 0.005     0.080     0.090     0.098  2462 1.001
## cvr[5]         0.080   0.000 0.006     0.067     0.080     0.091  2929 1.000
## cvr[6]         0.072   0.000 0.007     0.059     0.072     0.085  3011 1.000
## cvr[7]         0.064   0.000 0.008     0.049     0.065     0.079  2947 1.001
## cvr[8]         0.053   0.000 0.010     0.034     0.053     0.071  1465 1.004
## lp__       -4935.412   0.067 2.297 -4940.723 -4935.028 -4932.179  1172 1.004
## 
## Samples were drawn using NUTS(diag_e) at Sun Jan 21 16:55:35 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

推定結果を可視化すると、上手く推定できている事がわかる。

ms2 <- rstan::extract(fit_or2, pars = 'cvr')$cvr 

df <- ms2 %>%
  data.frame() %>%
  setNames(load_time) %>%
  pivot_longer(
    cols = everything(),
    names_to = 'load_time',
    values_to = 'cvr'
  )
  
cvr_hat <- colMeans(ms2)
point_df <- data.frame(load_time, true_cvr, cvr_hat) %>%
  pivot_longer(
    cols = -load_time,
    names_to = 'cvr_type',
    values_to = 'cvr'
  )

ggplot(df, aes(x = load_time, y = cvr)) + 
  geom_violin() + ylim(0, NA) +
  geom_point(data = point_df, aes(x = load_time, y = cvr, col = cvr_type)) + 
  theme_bw()

下記の部分の補足をメモしておく。

ordered型は、「小さい順」という制約である一方で、今回ここで求めたいパラメタは大きい順。つまり、これを逆順にする必要がある。

小さい順という制約がついているので、推定する際に「cv率が小さい順」で渡すと手間が省ける。今はcv率が「大きい順」にデータが並んでいるのと、小さい順にできるという事前知識をもっていることが肝要である。通常の推定では、このような順序関係がわかっているケースは少ないかもしれない。

data
## $N
## [1] 8
## 
## $cv
## [1] 208 769 366 142  54  35  17  11
## 
## $session
## [1] 1000 6000 4000 1500  700  500  200  150
## 
## $load_time
## [1] "00-01" "01-03" "03-07" "07-13" "13-21" "21-35" "35-60" "60+"

このデータを「小さい順」にしてからStanに渡す。

df_rev <- data.frame(load_time, cv, session) %>% 
  mutate(index = row_number()) %>% 
  arrange(desc(index))

# 作図する際にcv率が小さい順に強制するためにレベルを付与
#l <- c('60+','35-60','21-35', '13-21', '07-13','03-07','01-03','00-01')
df_rev$load_time <- factor(df_rev$load_time, levels = rev(load_time))

data_rev <- list(
  N = length(cv), 
  cv = df_rev$cv, 
  session = df_rev$session, 
  load_time = df_rev$load_time, 
  true_cvr = rev(true_cvr)
  )

data_rev
## $N
## [1] 8
## 
## $cv
## [1]  11  17  35  54 142 366 769 208
## 
## $session
## [1]  150  200  500  700 1500 4000 6000 1000
## 
## $load_time
## [1] 60+   35-60 21-35 13-21 07-13 03-07 01-03 00-01
## Levels: 60+ 35-60 21-35 13-21 07-13 03-07 01-03 00-01
## 
## $true_cvr
## [1] 0.057 0.067 0.077 0.086 0.093 0.096 0.126 0.209

小さい順のデータを渡し、ordered型を利用して、さきほどと同じくinv_logit関数で変換する。逆順にする操作は含んでいない。

data {
  int<lower=0> N;
  int<lower=0> cv[N];
  int<lower=0> session[N];
}
parameters {
  ordered[N] tmp;
}
transformed parameters {
  real<lower=0, upper=1> cvr[N];
  for(i in 1:N) {
    cvr[i] <- inv_logit(tmp[i]);
  }
}
model {
  cv ~ binomial(session, cvr);
}

サンプリングを実行する。

model_or3 <- stan_model('model-ordered3.stan')
fit_or3 <- sampling(object = model_or3, data = data_rev, seed = 1989)

推定結果はこちら。

print(fit_or3, prob = c(0.025, 0.5, 0.975), digits_summary = 3)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##             mean se_mean    sd      2.5%       50%     97.5% n_eff  Rhat
## tmp[1]    -2.907   0.005 0.203    -3.348    -2.886    -2.574  1482 1.002
## tmp[2]    -2.682   0.002 0.133    -2.968    -2.672    -2.453  3308 1.000
## tmp[3]    -2.561   0.002 0.102    -2.774    -2.556    -2.382  4557 0.999
## tmp[4]    -2.444   0.001 0.084    -2.622    -2.439    -2.297  3511 0.999
## tmp[5]    -2.321   0.001 0.056    -2.434    -2.319    -2.215  2686 1.002
## tmp[6]    -2.259   0.001 0.047    -2.351    -2.260    -2.166  3246 1.000
## tmp[7]    -1.918   0.001 0.039    -1.996    -1.918    -1.841  3290 1.000
## tmp[8]    -1.341   0.001 0.077    -1.491    -1.341    -1.193  3015 1.000
## cvr[1]     0.053   0.000 0.010     0.034     0.053     0.071  1606 1.001
## cvr[2]     0.065   0.000 0.008     0.049     0.065     0.079  3325 1.000
## cvr[3]     0.072   0.000 0.007     0.059     0.072     0.085  4456 0.999
## cvr[4]     0.080   0.000 0.006     0.068     0.080     0.091  3387 0.999
## cvr[5]     0.090   0.000 0.005     0.081     0.090     0.098  2653 1.002
## cvr[6]     0.095   0.000 0.004     0.087     0.094     0.103  3186 1.000
## cvr[7]     0.128   0.000 0.004     0.120     0.128     0.137  3303 1.000
## cvr[8]     0.208   0.000 0.013     0.184     0.207     0.233  3062 1.000
## lp__   -4935.273   0.058 2.132 -4940.539 -4934.974 -4932.083  1347 1.004
## 
## Samples were drawn using NUTS(diag_e) at Sun Jan 21 16:55:36 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

推定決kは下記のような関係にある。

            mean
cvr[1]     0.053 <- `60+`に対応(cv率が小さい、読み込み時間が長い)
cvr[2]     0.065
cvr[3]     0.072
cvr[4]     0.080
cvr[5]     0.090
cvr[6]     0.095
cvr[7]     0.128
cvr[8]     0.208 <- `00-01`に対応(cv率が大きい、読み込み時間が短い)

可視化用のデータを作成する。ラベルなども反転させている点は注意。

ms_rev <- rstan::extract(fit_or3, pars = 'cvr')$cvr 

df_rev <- ms_rev %>%
  data.frame() %>%
  setNames(data_rev$load_time) %>%
  pivot_longer(
    cols = everything(),
    names_to = 'load_time',
    values_to = 'cvr'
  )
cvr_hat_rev <- colMeans(ms_rev)
point_df_rev <- data.frame(load_time = data_rev$load_time, true_cvr = data_rev$true_cvr, cvr_hat_rev) %>%
  pivot_longer(
    cols = -load_time,
    names_to = 'cvr_type',
    values_to = 'cvr'
  )

可視化すると上手く推定できていることがわかる。これがordered型の「小さい順」という制約。データの渡し方は注意が必要。並びに関しては、Stanにデータを渡す前にload_timeのレベルを操作しているが、レベルを付与しなければ、これまでと同じ並びで可視化できる。

ggplot(df_rev, aes(x = load_time, y = cvr)) + 
  geom_violin() +
  geom_point(data = point_df_rev, aes(x = load_time, y = cvr, col = cvr_type)) + 
  theme_bw()