UPDATE: 2024-01-21 16:55:31.439467
このノートは「StanとRでベイズ統計モデリング」の内容を写経することで、ベイズ統計への理解を深めていくために作成している。ここでは制約付きデータ型であるordered
型への理解を先に深めておく。
Stan
でパラメータに大小関係に制約をつける場合、制約付きデータ型のordered
型を利用するが、これのメリットや挙動がよくわかっていない。ドキュメントには下記のような記載がある。
あるモデリングタスクでは、順序関係をもつベクトル値確率変数Xが必要な場合がある。1つの例は,順序ロジスティック回帰におけるカットポイントの集合である。制約条件ではKベクトル\(x \in \mathbb{R}^K\)は下記を満たす
\[ x_{k} \lt x_{k+1} \\ for \ k \in \{ 1, \ldots, K-1 \} \]
ドキュメントを読む限り、これ以上情報は見つけられなかったので、ネットで調べたところ、下記のありがたい記事を見つけることができた。この記事を参考にordered
型への理解を深める。
サンプルデータは、ECサイトの読み込み時間とCVの関係に関するデータ。読込み時間が短くなると、CV率は高くなることがわかっている前提。この関係がわかっていることが、あとあと重要になる。まずは何も考慮せずに素直にモデリングを行う。ここでは
library(tidyverse)
library(rstan)
options(max.print = 999999)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())
true_cvr <- c(0.209, 0.126, 0.096, 0.093, 0.086, 0.077, 0.067, 0.057)
load_time <- c("00-01", "01-03", "03-07", "07-13", "13-21", "21-35", "35-60", "60+")
session <- c(1000, 6000, 4000, 1500, 700, 500, 200, 150)
set.seed(71)
cv <- unlist(Map(function(n, p) rbinom(1, n, p), session, true_cvr))
data <- list(N = length(cv), cv = cv, session = session, load_time = load_time)
data.frame(load_time, cv, session)
## load_time cv session
## 1 00-01 208 1000
## 2 01-03 769 6000
## 3 03-07 366 4000
## 4 07-13 142 1500
## 5 13-21 54 700
## 6 21-35 35 500
## 7 35-60 17 200
## 8 60+ 11 150
モデルはこちら。
data {
int<lower=0> N;
int<lower=0> cv[N];
int<lower=0> session[N];
}
parameters {
real<lower=0, upper=1> cvr[N];
}
model {
cv ~ binomial(session, cvr);
}
ここでは、stan_model()
関数で最初にコンパイルしておいてから、
model_or1 <- stan_model('model-ordered1.stan')
sampling()
関数でサンプリングする。
fit_or1 <- sampling(object = model_or1, data = data, seed = 1989)
推定結果は下記の通り。
print(fit_or1, prob = c(0.025, 0.5, 0.975), digits_summary = 3)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
## cvr[1] 0.209 0.000 0.013 0.184 0.209 0.234 6940 0.999
## cvr[2] 0.128 0.000 0.004 0.120 0.128 0.136 6612 0.999
## cvr[3] 0.092 0.000 0.005 0.083 0.092 0.101 6636 0.999
## cvr[4] 0.095 0.000 0.008 0.081 0.095 0.111 7200 1.000
## cvr[5] 0.078 0.000 0.010 0.059 0.078 0.099 7694 1.000
## cvr[6] 0.071 0.000 0.012 0.051 0.071 0.096 6987 0.999
## cvr[7] 0.089 0.000 0.020 0.054 0.088 0.132 6698 1.000
## cvr[8] 0.079 0.000 0.021 0.043 0.077 0.124 7058 0.999
## lp__ -4940.378 0.045 1.962 -4945.090 -4940.042 -4937.533 1879 1.000
##
## Samples were drawn using NUTS(diag_e) at Sun Jan 21 16:55:33 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
推定結果を可視化すると、上手く推定できていない。07-13
が03-07
よりもcv率が高くなっているが、これは本来の関係とは異なる。35-60, 60+
も同様である。
ms1 <- rstan::extract(fit_or1, pars = 'cvr')$cvr
df <- ms1 %>%
data.frame() %>%
setNames(load_time) %>%
pivot_longer(
cols = everything(),
names_to = 'load_time',
values_to = 'cvr'
)
cvr_hat <- colMeans(ms1)
point_df <- data.frame(load_time, true_cvr, cvr_hat) %>%
pivot_longer(
cols = -load_time,
names_to = 'cvr_type',
values_to = 'cvr'
) %>%
arrange(cvr_type)
ggplot(df, aes(x = load_time, y = cvr)) +
geom_violin() + ylim(0, NA) +
geom_point(data = point_df, aes(x = load_time, y = cvr, col = cvr_type)) +
theme_bw()
ここで事前知識として持っている「読込み時間が短くなると、cv率は高くなることがわかっているとする」という関係を利用する。parameters
ブロックで、ordered
型のパラメタを作成する。モデルはこちら。
data {
int<lower=0> N;
int<lower=0> cv[N];
int<lower=0> session[N];
}
parameters {
ordered[N] cvr_rev;
}
transformed parameters {
real<lower=0, upper=1> cvr[N];
for(i in 1:N) {
cvr[i] <- inv_logit(cvr_rev[N - i + 1]);
}
}
model {
cv ~ binomial(session, cvr);
}
ordered
型は、「小さい順」という制約である一方で、今回ここで求めたいパラメタは大きい順。つまり、これを逆順にする必要がある。少し理解しにくいので、この点は後で補足する。
そして、ordered
型にはlowerやupperの制約がつけられないため、\(-\infty \sim
+\infty\)までの値を取ってしまう。cvr
は0-1の範囲である必要があるので、cvr_rev
をサンプリングすると、範囲外の値を取ってしまう。そのため、inv_logit()
関数(logistic
関数)で\(-\infty \sim
+\infty\)の値を0-1に変換する。
ここでは、stan_model()
関数で最初にコンパイルしておいてから、
model_or2 <- stan_model('model-ordered2.stan')
sampling()
関数でサンプリングする。
fit_or2 <- sampling(object = model_or2, data = data, seed = 1989)
推定結果は下記の通り。cvr_rev
がマイナスの値を取っていることがわかる。
print(fit_or2, prob = c(0.025, 0.5, 0.975), digits_summary = 3)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
## cvr_rev[1] -2.911 0.006 0.204 -3.358 -2.888 -2.572 1329 1.004
## cvr_rev[2] -2.683 0.002 0.133 -2.967 -2.670 -2.456 2922 1.001
## cvr_rev[3] -2.561 0.002 0.103 -2.775 -2.557 -2.376 3048 1.000
## cvr_rev[4] -2.447 0.002 0.085 -2.631 -2.441 -2.297 2974 1.000
## cvr_rev[5] -2.323 0.001 0.056 -2.439 -2.319 -2.222 2453 1.001
## cvr_rev[6] -2.262 0.001 0.048 -2.355 -2.263 -2.169 2518 1.001
## cvr_rev[7] -1.917 0.001 0.038 -1.991 -1.917 -1.843 3082 1.000
## cvr_rev[8] -1.337 0.001 0.077 -1.489 -1.338 -1.190 2650 1.001
## cvr[1] 0.208 0.000 0.013 0.184 0.208 0.233 2708 1.001
## cvr[2] 0.128 0.000 0.004 0.120 0.128 0.137 3096 1.000
## cvr[3] 0.094 0.000 0.004 0.087 0.094 0.103 2483 1.001
## cvr[4] 0.089 0.000 0.005 0.080 0.090 0.098 2462 1.001
## cvr[5] 0.080 0.000 0.006 0.067 0.080 0.091 2929 1.000
## cvr[6] 0.072 0.000 0.007 0.059 0.072 0.085 3011 1.000
## cvr[7] 0.064 0.000 0.008 0.049 0.065 0.079 2947 1.001
## cvr[8] 0.053 0.000 0.010 0.034 0.053 0.071 1465 1.004
## lp__ -4935.412 0.067 2.297 -4940.723 -4935.028 -4932.179 1172 1.004
##
## Samples were drawn using NUTS(diag_e) at Sun Jan 21 16:55:35 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
推定結果を可視化すると、上手く推定できている事がわかる。
ms2 <- rstan::extract(fit_or2, pars = 'cvr')$cvr
df <- ms2 %>%
data.frame() %>%
setNames(load_time) %>%
pivot_longer(
cols = everything(),
names_to = 'load_time',
values_to = 'cvr'
)
cvr_hat <- colMeans(ms2)
point_df <- data.frame(load_time, true_cvr, cvr_hat) %>%
pivot_longer(
cols = -load_time,
names_to = 'cvr_type',
values_to = 'cvr'
)
ggplot(df, aes(x = load_time, y = cvr)) +
geom_violin() + ylim(0, NA) +
geom_point(data = point_df, aes(x = load_time, y = cvr, col = cvr_type)) +
theme_bw()
下記の部分の補足をメモしておく。
ordered
型は、「小さい順」という制約である一方で、今回ここで求めたいパラメタは大きい順。つまり、これを逆順にする必要がある。
小さい順という制約がついているので、推定する際に「cv率が小さい順」で渡すと手間が省ける。今はcv率が「大きい順」にデータが並んでいるのと、小さい順にできるという事前知識をもっていることが肝要である。通常の推定では、このような順序関係がわかっているケースは少ないかもしれない。
data
## $N
## [1] 8
##
## $cv
## [1] 208 769 366 142 54 35 17 11
##
## $session
## [1] 1000 6000 4000 1500 700 500 200 150
##
## $load_time
## [1] "00-01" "01-03" "03-07" "07-13" "13-21" "21-35" "35-60" "60+"
このデータを「小さい順」にしてからStanに渡す。
df_rev <- data.frame(load_time, cv, session) %>%
mutate(index = row_number()) %>%
arrange(desc(index))
# 作図する際にcv率が小さい順に強制するためにレベルを付与
#l <- c('60+','35-60','21-35', '13-21', '07-13','03-07','01-03','00-01')
df_rev$load_time <- factor(df_rev$load_time, levels = rev(load_time))
data_rev <- list(
N = length(cv),
cv = df_rev$cv,
session = df_rev$session,
load_time = df_rev$load_time,
true_cvr = rev(true_cvr)
)
data_rev
## $N
## [1] 8
##
## $cv
## [1] 11 17 35 54 142 366 769 208
##
## $session
## [1] 150 200 500 700 1500 4000 6000 1000
##
## $load_time
## [1] 60+ 35-60 21-35 13-21 07-13 03-07 01-03 00-01
## Levels: 60+ 35-60 21-35 13-21 07-13 03-07 01-03 00-01
##
## $true_cvr
## [1] 0.057 0.067 0.077 0.086 0.093 0.096 0.126 0.209
小さい順のデータを渡し、ordered
型を利用して、さきほどと同じくinv_logit
関数で変換する。逆順にする操作は含んでいない。
data {
int<lower=0> N;
int<lower=0> cv[N];
int<lower=0> session[N];
}
parameters {
ordered[N] tmp;
}
transformed parameters {
real<lower=0, upper=1> cvr[N];
for(i in 1:N) {
cvr[i] <- inv_logit(tmp[i]);
}
}
model {
cv ~ binomial(session, cvr);
}
サンプリングを実行する。
model_or3 <- stan_model('model-ordered3.stan')
fit_or3 <- sampling(object = model_or3, data = data_rev, seed = 1989)
推定結果はこちら。
print(fit_or3, prob = c(0.025, 0.5, 0.975), digits_summary = 3)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
## tmp[1] -2.907 0.005 0.203 -3.348 -2.886 -2.574 1482 1.002
## tmp[2] -2.682 0.002 0.133 -2.968 -2.672 -2.453 3308 1.000
## tmp[3] -2.561 0.002 0.102 -2.774 -2.556 -2.382 4557 0.999
## tmp[4] -2.444 0.001 0.084 -2.622 -2.439 -2.297 3511 0.999
## tmp[5] -2.321 0.001 0.056 -2.434 -2.319 -2.215 2686 1.002
## tmp[6] -2.259 0.001 0.047 -2.351 -2.260 -2.166 3246 1.000
## tmp[7] -1.918 0.001 0.039 -1.996 -1.918 -1.841 3290 1.000
## tmp[8] -1.341 0.001 0.077 -1.491 -1.341 -1.193 3015 1.000
## cvr[1] 0.053 0.000 0.010 0.034 0.053 0.071 1606 1.001
## cvr[2] 0.065 0.000 0.008 0.049 0.065 0.079 3325 1.000
## cvr[3] 0.072 0.000 0.007 0.059 0.072 0.085 4456 0.999
## cvr[4] 0.080 0.000 0.006 0.068 0.080 0.091 3387 0.999
## cvr[5] 0.090 0.000 0.005 0.081 0.090 0.098 2653 1.002
## cvr[6] 0.095 0.000 0.004 0.087 0.094 0.103 3186 1.000
## cvr[7] 0.128 0.000 0.004 0.120 0.128 0.137 3303 1.000
## cvr[8] 0.208 0.000 0.013 0.184 0.207 0.233 3062 1.000
## lp__ -4935.273 0.058 2.132 -4940.539 -4934.974 -4932.083 1347 1.004
##
## Samples were drawn using NUTS(diag_e) at Sun Jan 21 16:55:36 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
推定決kは下記のような関係にある。
mean
cvr[1] 0.053 <- `60+`に対応(cv率が小さい、読み込み時間が長い)
cvr[2] 0.065
cvr[3] 0.072
cvr[4] 0.080
cvr[5] 0.090
cvr[6] 0.095
cvr[7] 0.128
cvr[8] 0.208 <- `00-01`に対応(cv率が大きい、読み込み時間が短い)
可視化用のデータを作成する。ラベルなども反転させている点は注意。
ms_rev <- rstan::extract(fit_or3, pars = 'cvr')$cvr
df_rev <- ms_rev %>%
data.frame() %>%
setNames(data_rev$load_time) %>%
pivot_longer(
cols = everything(),
names_to = 'load_time',
values_to = 'cvr'
)
cvr_hat_rev <- colMeans(ms_rev)
point_df_rev <- data.frame(load_time = data_rev$load_time, true_cvr = data_rev$true_cvr, cvr_hat_rev) %>%
pivot_longer(
cols = -load_time,
names_to = 'cvr_type',
values_to = 'cvr'
)
可視化すると上手く推定できていることがわかる。これがordered
型の「小さい順」という制約。データの渡し方は注意が必要。並びに関しては、Stanにデータを渡す前にload_time
のレベルを操作しているが、レベルを付与しなければ、これまでと同じ並びで可視化できる。
ggplot(df_rev, aes(x = load_time, y = cvr)) +
geom_violin() +
geom_point(data = point_df_rev, aes(x = load_time, y = cvr, col = cvr_type)) +
theme_bw()