UPDATE: 2024-01-26 18:04:07.556405
このノートは「ベイズ統計」に関する何らかの内容をまとめ、ベイズ統計への理解を深めていくために作成している。今回は「たのしいベイズモデリング」の第18章に記載されている「本当に麻雀が強いのは誰か?」の内容を参考にさせていただき、写経しながら、ところどころ自分用の補足をメモすることで、自分用の補足資料になることを目指す。私の解釈がおかしく、メモが誤っている場合があるので注意。
まずは必要なライブラリや設定を行う。使用するデータは4人の麻雀の143ゲーム分の結果データを利用する。
library(tidyverse)
library(rstan)
library(patchwork)
library(MCMCpack)
options(max.print = 999999)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())
pointData <- read_csv('~/Desktop/pointData.csv')
head(pointData, 10)
## # A tibble: 10 × 4
## player_A player_B player_C player_D
## <dbl> <dbl> <dbl> <dbl>
## 1 10 -26 66 -50
## 2 5 -38 -24 57
## 3 -23 45 -34 12
## 4 -39 44 9 -14
## 5 -20 55 2 -37
## 6 -19 6 54 -41
## 7 -3 97 -63 -31
## 8 -32 114 3 -85
## 9 9 -14 -36 41
## 10 -14 42 -34 6
各プレイヤーの点数は下記の通りである。プレイヤーCが少し弱くみえる。
pointData %>%
pivot_longer(cols = everything()) %>%
ggplot(., aes(name, value, group = 1)) +
theme_bw(base_size = 18) +
geom_jitter(width = 0.1) +
stat_summary(fun.y = mean, geom = 'point', col = 'tomato') +
stat_summary(fun.y = mean, geom = 'line', col = 'tomato')
タイトルの通り「本当に麻雀が強いのは誰か?」を推定すること。そのためにディリクレ分布を使ったモデルを考える。ディリクレ分布は下記の通り。\(V\)は次元数、\(\boldsymbol{\beta} = (\beta_1, \beta_2, \cdots, \beta_V)\)はパラメタ、確率変数の実現値\(\boldsymbol{\phi} = (\phi_1, \phi_2, \cdots, \phi_V)\)である。\(\boldsymbol{\phi}\)の総和は1となる。\(C\)は正規化係数で、\(\Gamma(x)\)はガンマ関数。
\[ \begin{aligned} C &= \frac{\Gamma(\sum_{v=1}^V \beta_v)}{\prod_{v=1}^V \Gamma(\beta_v)} \\ \mathrm{Dirichlet}(\boldsymbol{\phi} | \boldsymbol{\beta}) &= C \prod_{v=1}^V \phi_v^{\beta_v-1} \end{aligned} \]
Wikipediaによると、
ベータ分布を多変量に拡張して一般化した形をしており、そのため多変量ベータ分布とも呼ばれる。
とある。パラメタがc(2.5, 15, 7.5)
の3次元のディリクレ分布を可視化しておく。パラメタの総和は1になるので、\(X3=X1+X2\)である。また、パラメタの値が相対的に大きくなれば、その値が出やすくなる。
data.frame(rdirichlet(500, c(2.5, 15, 7.5))) %>%
ggplot(., aes(x = X1, y = X2, col = X3)) +
geom_point(alpha = 0.2) +
labs(title = 'Dirichlet Distribution') +
scale_x_continuous(limits = c(0, 1)) +
scale_y_continuous(limits = c(0, 1)) +
coord_fixed() +
theme_bw()
なぜディリクレ分布をここで利用するのかというと、麻雀の点数ルールが関わっている。25000点の30000点返しというルールであれば、30000点を基準として考える。例えば、下記の結果となった場合、
であり、これらの総和は0になる。
- 25 - 12 + 2 + 35
## [1] 0
麻雀のポイント合計は常に0であるため、ポイントにソフトマックス関数を適用して、各ゲームごとの総和が1になるように調整する。そうすることでディリクレ分布のパラメタの総和が1になるという条件の元、推定が可能になる。
\[ \begin{eqnarray} point_{it} = \frac{exp[\alpha_{it}]}{\sum^{4}_{i} exp[\alpha_{it}]} \end{eqnarray} \]
変換したデータはこちら。
# 前処理
playerNum <- ncol(pointData) # num of player
hantyanNum <- nrow(pointData) # num of game
# 値が大きいので、小さくしてから変換
point_mod <- as.matrix(pointData/100)
point <- matrix(0, hantyanNum, playerNum)
for(h in 1:hantyanNum){
point[h,] <- exp(point_mod[h,])/sum(exp(point_mod[h,]))
}
head(point)
## [,1] [,2] [,3] [,4]
## [1,] 0.2501776 0.1745430 0.4379790 0.1373004
## [2,] 0.2450500 0.1594072 0.1833620 0.4121808
## [3,] 0.1890795 0.3732199 0.1693839 0.2683166
## [4,] 0.1614617 0.3702832 0.2609341 0.2073209
## [5,] 0.1920587 0.4065883 0.2393199 0.1620332
## [6,] 0.1937374 0.2487638 0.4020208 0.1554779
総和が1になっていることがわかる。
apply(point, 1, sum)
## [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [38] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [75] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [112] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
モデルはこちら。パラメタ\(\theta\)をもつディリクレ分布に従い、ポイント\(\overrightarrow{ point_{t}}\)が生成される。
\[ \begin{eqnarray} \overrightarrow{ point_{t}} &\sim& Dirichlet( \overrightarrow{ \theta } ) \\ \overrightarrow{ point_{t}} &=& (point1_{t},point2_{t},point3_{t},point4_{t}) \end{eqnarray} \]
モデルはこちら。総和が1になるためsimplex
型を利用している。
data {
int N;
int G;
simplex [N] point[G];
}
parameters {
vector <lower=0> [N] theta;
}
model {
for(t in 1:G){
point[t,] ~ dirichlet(theta);
}
}
データを用意する。
data <- list(
N = playerNum,
G = hantyanNum,
point = point
)
map(.x = data, .f = function(x){head(x, 50)})
## $N
## [1] 4
##
## $G
## [1] 143
##
## $point
## [,1] [,2] [,3] [,4]
## [1,] 0.2501776 0.1745430 0.4379790 0.13730038
## [2,] 0.2450500 0.1594072 0.1833620 0.41218083
## [3,] 0.1890795 0.3732199 0.1693839 0.26831664
## [4,] 0.1614617 0.3702832 0.2609341 0.20732095
## [5,] 0.1920587 0.4065883 0.2393199 0.16203317
## [6,] 0.1937374 0.2487638 0.4020208 0.15547794
## [7,] 0.1990891 0.5411802 0.1092624 0.15046829
## [8,] 0.1367310 0.5887580 0.1940305 0.08048053
## [9,] 0.2625161 0.2085779 0.1673877 0.36151827
## [10,] 0.2087331 0.3654234 0.1708962 0.25494724
## [11,] 0.1328743 0.2596679 0.4072404 0.20021737
## [12,] 0.3737486 0.2051176 0.1629728 0.25816099
## [13,] 0.2706772 0.3690481 0.1577367 0.20253792
## [14,] 0.1522056 0.1877727 0.2664622 0.39355951
## [15,] 0.2035715 0.1841991 0.3635865 0.24864281
## [16,] 0.2080430 0.2566581 0.4065654 0.12873355
## [17,] 0.2671102 0.1808488 0.4024864 0.14955457
## [18,] 0.1593905 0.3692069 0.2006089 0.27079368
## [19,] 0.4735884 0.1440757 0.1035794 0.27875650
## [20,] 0.3866301 0.2565871 0.1685898 0.18819305
## [21,] 0.2055410 0.3562545 0.1769108 0.26129378
## [22,] 0.1880236 0.3862819 0.2589325 0.16676197
## [23,] 0.4209048 0.2604489 0.1968431 0.12180321
## [24,] 0.3739299 0.1680176 0.1971704 0.26088207
## [25,] 0.1332970 0.4425615 0.1502919 0.27384969
## [26,] 0.3744686 0.1477480 0.2638836 0.21389986
## [27,] 0.2688386 0.2031838 0.1614364 0.36654125
## [28,] 0.3916435 0.1849993 0.1608307 0.26252651
## [29,] 0.1427067 0.2679476 0.3764523 0.21289337
## [30,] 0.2517987 0.2041041 0.1721954 0.37190184
## [31,] 0.1536675 0.2386006 0.1858223 0.42190958
## [32,] 0.2564553 0.1685031 0.2037624 0.37127921
## [33,] 0.1777644 0.2375693 0.4285722 0.15609413
## [34,] 0.2598664 0.1560486 0.1925137 0.39157127
## [35,] 0.1765560 0.4704258 0.1049662 0.24805197
## [36,] 0.2750949 0.1668535 0.1393678 0.41868386
## [37,] 0.2731049 0.3723580 0.1639983 0.19053881
## [38,] 0.1895500 0.1461528 0.4134980 0.25079922
## [39,] 0.3921413 0.2735877 0.1471750 0.18709606
## [40,] 0.1998985 0.3678994 0.2618599 0.17034224
## [41,] 0.3972524 0.1839331 0.1551779 0.26363667
## [42,] 0.1811013 0.5719529 0.1044864 0.14245935
## [43,] 0.1893221 0.1645887 0.2289376 0.41715158
## [44,] 0.2611288 0.3780452 0.1491591 0.21166688
## [45,] 0.1517271 0.4249976 0.1853199 0.23795545
## [46,] 0.3749100 0.2078229 0.2488091 0.16845794
## [47,] 0.1609359 0.2734192 0.1965676 0.36907733
## [48,] 0.1727035 0.4290511 0.2525416 0.14570383
## [49,] 0.2629568 0.4207294 0.1762652 0.14004864
## [50,] 0.1443187 0.3769167 0.2656087 0.21315596
先にコンパイルしてから、
model <- stan_model('model01.stan')
sampling()
関数でサンプリングする。
fit <- sampling(object = model, data = data, seed = 1989)
推定結果を確認する。プレイヤーC以外の雀力は拮抗している感じで、プレイヤーCだけ想定的に弱い。
print(fit, prob = c(0.025, 0.5, 0.975), digits = 2)
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
## theta[1] 5.21 0.01 0.38 4.48 5.21 5.95 879 1
## theta[2] 5.50 0.01 0.39 4.76 5.50 6.30 816 1
## theta[3] 4.56 0.01 0.33 3.93 4.56 5.20 1009 1
## theta[4] 5.14 0.01 0.37 4.42 5.14 5.87 878 1
## lp__ 464.29 0.04 1.45 460.66 464.66 466.06 1659 1
##
## Samples were drawn using NUTS(diag_e) at Fri Jan 26 18:04:12 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
事後分布を可視化しておく。
stan_plot(
fit,
point_est = 'mean',
show_density = TRUE,
ci_level = 0.9,
outer_level = 1,
pars = 'theta',
fill_color = 'tomato'
)
事後分布の値をもとに、各プレイヤーごとの強さに差がある確率を計算する。事後分布の値を引き算し、0より大きければ、そのプレイヤーの方が強いことになるため、MCMCサンプルの回数分、計算する。
theta <- rstan::extract(fit)$theta
computeProb <- function(parameter, i, j){
temp <- parameter[, i] - parameter[, j] > 0
return(sum(temp) / length(temp))
}
mat <- matrix(0, nrow = playerNum, ncol = playerNum)
rownames(mat) <- colnames(mat) <- c('playerA', 'playerB', 'playerC', 'playerD')
for (i in 1:playerNum) {
for (j in 1:playerNum) {
if (i != j) {
mat[i, j] <- computeProb(theta, i, j)
} else {
mat[i, j] <- 1
}
}
}
round(mat, 2)
## playerA playerB playerC playerD
## playerA 1.00 0.12 1.00 0.61
## playerB 0.88 1.00 1.00 0.92
## playerC 0.00 0.00 1.00 0.01
## playerD 0.39 0.09 0.99 1.00
事後分布の可視化からプレイヤーBが強かったので、プレイヤーBを参考にすると、プレイヤーBは、プレイヤーAよりも88%強く、プレイヤーCよりも100%強く、プレイヤーDよりも92%強いことがわかる。まとめると、プレイヤーBは他のプレイヤーよりも93%強い。
(
sum(theta[,2] - theta[,1] > 0) +
sum(theta[,2] - theta[,3] > 0) +
sum(theta[,2] - theta[,4] > 0)
) / (nrow(theta)*3)
## [1] 0.9304167