UPDATE: 2024-01-16 21:23:48.385572

はじめに

このノートは「ベイズ統計」に関する何らかの内容をまとめ、ベイズ統計への理解を深めていくために作成している。今回は「RとStanではじめるベイズ統計モデリングによるデータ分析入門」を写経していく。基本的には気になった部分を写経しながら、ところどころ自分用の補足をメモすることで、自分用の補足資料になることを目指す。私の解釈がおかしく、メモが誤っている場合があるので注意。

5.2 brmsとは

brmsとは、Bayesian Regression Models using ’Stan’の頭文字をとったもので、Stanを使ってベイジアンな回帰分析ができるパッケージのこと。Stanのコードを書かなくても一般化線形モデルや一般化線形混合モデルも推定できる。

5.4 分析の準備

ここでも参考書に従って、ビールと気温に関するサンプルデータを読み込んでおく。

library(tidyverse)
library(rstan)
library(brms)
library(patchwork)

options(max.print = 999999)
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

file_beer_sales_2 <- read.csv('https://raw.githubusercontent.com/logics-of-blue/book-r-stan-bayesian-model-intro/master/book-data/3-2-1-beer-sales-2.csv')

head(file_beer_sales_2, 10)
##     sales temperature
## 1   41.68        13.7
## 2  110.99        24.0
## 3   65.32        21.5
## 4   72.64        13.4
## 5   76.54        28.9
## 6   62.76        28.9
## 7   46.66        12.6
## 8  100.79        26.7
## 9   85.59        19.4
## 10  97.57        21.0

5.5 brmsによる単回帰モデルの推定

brmsパッケージで単回帰モデルを推定するためには、brm()関数を利用する。glm()関数のようにモデル式、リンク関数や確率分布、データを指定すればOK。逆に、これらの情報を決めればStanのコードが自動生成できる。

simple_lm_brms <- brm(
  formula = sales ~ temperature,
  family = gaussian(link = "identity"),  
  data = file_beer_sales_2,              
  seed = 1                               
)

推定結果を確認すると、Stanでモデルを自作したときと同じようにパラメタの事後分布が得られている。

simple_lm_brms
##  Family: gaussian 
##   Links: mu = identity; sigma = identity 
## Formula: sales ~ temperature 
##    Data: file_beer_sales_2 (Number of observations: 100) 
##   Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
##          total post-warmup draws = 4000
## 
## Population-Level Effects: 
##             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept      21.18      5.96     9.64    32.79 1.00     3705     2453
## temperature     2.46      0.29     1.90     3.02 1.00     3705     2760
## 
## Family Specific Parameters: 
##       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma    16.98      1.22    14.79    19.57 1.00     3884     2948
## 
## Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).

MCMCサンプルを取得する時は、推定結果に対してas.mcmc()関数を利用すれば良い。

list(
  Size = dim(as.mcmc(simple_lm_brms, combine_chains = TRUE)),
  MCMC = head(as.mcmc(simple_lm_brms, combine_chains = TRUE), 50)
  )
## $Size
## [1] 4000    5
## 
## $MCMC
## Markov Chain Monte Carlo (MCMC) output:
## Start = 4001 
## End = 4051 
## Thinning interval = 1 
##           parameters
## iterations b_Intercept b_temperature    sigma    lprior      lp__
##       [1,]   17.291494      2.665912 16.86284 -7.728407 -428.5302
##       [2,]   22.176072      2.393405 16.23336 -7.702903 -428.3797
##       [3,]   20.649078      2.457055 17.13983 -7.745576 -428.3710
##       [4,]   20.497947      2.399533 15.96360 -7.704604 -429.1765
##       [5,]   18.387851      2.554971 15.21855 -7.662978 -429.4203
##       [6,]   32.763268      1.928922 16.74897 -7.720640 -430.2209
##       [7,]    9.973281      2.972899 16.71802 -7.729838 -430.1046
##       [8,]   23.986324      2.352285 16.49417 -7.710022 -428.4250
##       [9,]   13.656594      2.955341 18.71022 -7.813028 -431.4581
##      [10,]   24.397969      2.172297 15.46809 -7.690510 -430.7690
##      [11,]   25.672652      2.214695 18.96808 -7.830608 -430.0525
##      [12,]   23.523418      2.367624 15.84476 -7.681783 -428.6530
##      [13,]   23.243416      2.278205 17.17390 -7.755351 -428.9207
##      [14,]   30.604967      1.963894 17.18123 -7.746286 -429.8200
##      [15,]   19.561305      2.644317 19.27587 -7.840006 -430.8558
##      [16,]   18.739174      2.491874 18.36868 -7.814204 -429.5899
##      [17,]   22.521248      2.319539 18.49512 -7.816448 -429.6271
##      [18,]    5.924973      3.276812 16.70196 -7.718709 -432.6439
##      [19,]   14.525595      2.876314 16.94828 -7.729309 -429.7959
##      [20,]   15.957501      2.847406 18.13478 -7.785934 -430.6366
##      [21,]   11.209621      3.089915 18.05999 -7.782584 -431.9746
##      [22,]   11.510068      3.097923 18.01243 -7.782193 -432.4038
##      [23,]   38.073220      1.546244 17.33478 -7.759727 -433.5256
##      [24,]   43.642904      1.298038 16.91420 -7.735196 -436.6372
##      [25,]   45.493322      1.259228 17.69126 -7.765399 -436.7272
##      [26,]   33.791912      1.865228 17.04723 -7.734681 -430.5502
##      [27,]   31.796534      1.998478 17.76470 -7.767063 -430.1476
##      [28,]   23.144278      2.404004 19.67770 -7.860185 -430.7826
##      [29,]   23.579783      2.370542 19.43733 -7.848907 -430.4281
##      [30,]   16.802035      2.702568 14.97702 -7.644856 -430.0474
##      [31,]   17.661225      2.646508 15.50333 -7.668291 -429.1020
##      [32,]   22.241695      2.380265 18.47862 -7.807939 -429.2763
##      [33,]   23.994133      2.301813 14.14357 -7.614204 -431.6037
##      [34,]   17.081272      2.623425 14.62513 -7.637674 -430.5794
##      [35,]   19.268381      2.589057 16.61017 -7.715299 -428.4392
##      [36,]   20.112310      2.411279 15.91553 -7.704256 -429.3199
##      [37,]   19.284037      2.651970 18.59386 -7.806681 -430.0518
##      [38,]   20.439210      2.611772 17.71983 -7.765848 -429.5511
##      [39,]   19.428455      2.620386 16.23283 -7.697139 -428.9045
##      [40,]   15.515091      2.829066 17.29885 -7.745406 -429.6431
##      [41,]   22.250251      2.395467 17.20843 -7.746286 -428.3660
##      [42,]   23.919751      2.349645 16.12630 -7.694018 -428.5190
##      [43,]   23.778799      2.337470 16.17207 -7.697539 -428.4623
##      [44,]   22.431422      2.407983 18.02559 -7.782222 -428.8236
##      [45,]   22.206173      2.477126 17.50667 -7.755056 -428.7762
##      [46,]   28.090255      2.191537 15.94652 -7.684494 -429.4760
##      [47,]   11.089470      2.980302 19.48794 -7.852873 -431.6765
##      [48,]   11.326685      2.958818 19.45156 -7.851960 -431.5223
##      [49,]    6.291949      3.199711 19.03379 -7.832876 -432.4982
##      [50,]   39.173221      1.790383 17.61729 -7.776925 -434.6981
##      [51,]   33.340715      1.958910 18.77913 -7.816174 -431.4895

トレースプロットや事後分布がはplot()関数を利用すれば確認できる。

plot(simple_lm_brms)

5.6 brmsの基本的な使い方

bf()関数を利用することで、モデル式は別で指定できる。brmsformula, bformというクラスなので、通常のformulaクラスではないので注意。

simple_lm_formula <- bf(sales ~ temperature)
class(simple_lm_formula)
## [1] "brmsformula" "bform"

MCMCの設定に関してもStanでサンプリングする時のように指定できる。

simple_lm_brms_2 <- brm(
  formula = simple_lm_formula, # bf関数で作成済みのformulaを指定
  family = gaussian(),       # 正規分布を使う(リンク関数省略)
  data = file_beer_sales_2,  # データ
  seed = 1,                  # 乱数の種
  chains = 4,                # チェーン数
  iter = 2000,               # 乱数生成の繰り返し数
  warmup = 1000,             # バーンイン期間
  thin = 1                   # 間引き数(1なら間引き無し) 
)

5.7 事前分布の変更

事前分布には無情報事前分布がデフォルトで指定されている。確認する場合はprior_summary()関数を利用する。事前分布の一覧を確認できる。

prior_summary(simple_lm_brms)
##                   prior     class        coef group resp dpar nlpar lb ub
##                  (flat)         b                                        
##                  (flat)         b temperature                            
##  student_t(3, 71.5, 20) Intercept                                        
##     student_t(3, 0, 20)     sigma                                    0   
##        source
##       default
##  (vectorized)
##       default
##       default

事前分布を変更する際は、prior引数にset_prior()関数を利用して事前分布を指定する。

simple_lm_brms_3 <- brm(
  formula = sales ~ temperature,
  family = gaussian(),
  data = file_beer_sales_2, 
  seed = 1,
  prior = c(
    set_prior("normal(50,100000)", class = "b", coef = ""),
    set_prior("normal(5,100000)", class = "b", coef = "temperature")
  )
)

指定した事前分布を利用していることがわかる。

prior_summary(simple_lm_brms_3)
##                   prior     class        coef group resp dpar nlpar lb ub
##       normal(50,100000)         b                                        
##        normal(5,100000)         b temperature                            
##  student_t(3, 71.5, 20) Intercept                                        
##     student_t(3, 0, 20)     sigma                                    0   
##   source
##     user
##     user
##  default
##  default

実際にサンプリングで使用されているStanコードはstancode()関数で確認できる。理由は分からないが、interceptの事前分布はStanコードを見ても変更されていない模様。temperatureのパラメタは指定した通りとなっている。

stancode(simple_lm_brms_3)
## // generated with brms 2.20.4
## functions {
## }
## data {
##   int<lower=1> N;  // total number of observations
##   vector[N] Y;  // response variable
##   int<lower=1> K;  // number of population-level effects
##   matrix[N, K] X;  // population-level design matrix
##   int<lower=1> Kc;  // number of population-level effects after centering
##   int prior_only;  // should the likelihood be ignored?
## }
## transformed data {
##   matrix[N, Kc] Xc;  // centered version of X without an intercept
##   vector[Kc] means_X;  // column means of X before centering
##   for (i in 2:K) {
##     means_X[i - 1] = mean(X[, i]);
##     Xc[, i - 1] = X[, i] - means_X[i - 1];
##   }
## }
## parameters {
##   vector[Kc] b;  // regression coefficients
##   real Intercept;  // temporary intercept for centered predictors
##   real<lower=0> sigma;  // dispersion parameter
## }
## transformed parameters {
##   real lprior = 0;  // prior contributions to the log posterior
##   lprior += normal_lpdf(b[1] | 5,100000);
##   lprior += student_t_lpdf(Intercept | 3, 71.5, 20);
##   lprior += student_t_lpdf(sigma | 3, 0, 20)
##     - 1 * student_t_lccdf(0 | 3, 0, 20);
## }
## model {
##   // likelihood including constants
##   if (!prior_only) {
##     target += normal_id_glm_lpdf(Y | Xc, Intercept, b, sigma);
##   }
##   // priors including constants
##   target += lprior;
## }
## generated quantities {
##   // actual population-level intercept
##   real b_Intercept = Intercept - dot_product(means_X, b);
## }

Stanに渡しているデータについては、standata()関数で確認できる。

standata(simple_lm_brms_3)
## $N
## [1] 100
## 
## $Y
##   [1]  41.68 110.99  65.32  72.64  76.54  62.76  46.66 100.79  85.59  97.57
##  [11]  45.93  87.47  72.45  56.37  72.84  75.45  63.77  49.06  68.51  35.32
##  [21]  64.18  69.46  84.63  59.02  61.44  55.89  72.05  74.33 109.04  30.35
##  [31]  60.44  27.81  77.50  67.92  37.63 103.58  77.45  54.98  72.45  58.30
##  [41] 118.01  84.97  33.00  32.87  69.56  65.95 123.14  62.61  57.36  76.48
##  [51]  61.37  49.66  74.54  80.26  45.82 116.22  98.35 124.63  69.43  75.24
##  [61]  68.09  85.49  83.33  78.38  96.46  73.59  82.96  85.21  57.60  36.50
##  [71]  77.37  62.58  72.66  47.79  38.59  90.81  49.46  98.08  39.39  47.61
##  [81]  72.85  83.46  59.32  34.76  73.19 105.76 108.98  73.05 113.40  40.19
##  [91]  85.00  70.90  51.90  74.28  79.44  44.94  76.11  58.00  38.65  52.04
## 
## $K
## [1] 2
## 
## $Kc
## [1] 1
## 
## $X
##     Intercept temperature
## 1           1        13.7
## 2           1        24.0
## 3           1        21.5
## 4           1        13.4
## 5           1        28.9
## 6           1        28.9
## 7           1        12.6
## 8           1        26.7
## 9           1        19.4
## 10          1        21.0
## 11          1        21.1
## 12          1        14.8
## 13          1        25.2
## 14          1        13.6
## 15          1        18.1
## 16          1        27.1
## 17          1        29.5
## 18          1        14.5
## 19          1        18.9
## 20          1        11.5
## 21          1        23.2
## 22          1        17.8
## 23          1        26.7
## 24          1        13.0
## 25          1        16.9
## 26          1        19.8
## 27          1        13.0
## 28          1        17.1
## 29          1        29.3
## 30          1        12.6
## 31          1        10.2
## 32          1        13.3
## 33          1        26.2
## 34          1        27.4
## 35          1        20.3
## 36          1        22.5
## 37          1        26.9
## 38          1        15.7
## 39          1        23.3
## 40          1        13.0
## 41          1        29.6
## 42          1        15.9
## 43          1        12.3
## 44          1        13.3
## 45          1        28.9
## 46          1        25.9
## 47          1        29.5
## 48          1        17.0
## 49          1        20.0
## 50          1        26.2
## 51          1        10.1
## 52          1        10.3
## 53          1        23.7
## 54          1        28.6
## 55          1        15.5
## 56          1        26.2
## 57          1        25.7
## 58          1        29.8
## 59          1        22.3
## 60          1        24.2
## 61          1        25.4
## 62          1        27.7
## 63          1        22.5
## 64          1        15.2
## 65          1        27.2
## 66          1        18.7
## 67          1        17.8
## 68          1        19.2
## 69          1        14.4
## 70          1        11.3
## 71          1        15.5
## 72          1        16.2
## 73          1        10.8
## 74          1        13.7
## 75          1        13.7
## 76          1        25.1
## 77          1        15.8
## 78          1        27.4
## 79          1        18.1
## 80          1        21.5
## 81          1        17.0
## 82          1        23.4
## 83          1        10.5
## 84          1        18.0
## 85          1        14.0
## 86          1        27.1
## 87          1        29.4
## 88          1        16.5
## 89          1        24.7
## 90          1        16.8
## 91          1        29.5
## 92          1        17.9
## 93          1        17.6
## 94          1        21.2
## 95          1        19.3
## 96          1        13.9
## 97          1        18.5
## 98          1        11.9
## 99          1        12.3
## 100         1        18.8
## attr(,"assign")
## [1] 0 1
## 
## $prior_only
## [1] 0
## 
## attr(,"class")
## [1] "standata" "list"

5.9 補足: make_stancode関数によるStanコードの生成

make_stancode()関数でStanコードを生成できる。気になった点を下記にまとめておく。

  • data : 説明変数にはデザイン行列を使用し、データに合わせて柔軟にモデリングできるようになっている
  • transformed data: 説明変数を中心化して変換
  • generated quantities: 中心化しているため切片が推定できず、推定後に構築
make_stancode(
  formula = sales ~ temperature,
  family = gaussian(),
  data = file_beer_sales_2, 
  prior = c(prior("", class = "Intercept"),
            prior("", class = "sigma"))
)
## // generated with brms 2.20.4
## functions {
## }
## data {
##   int<lower=1> N;  // total number of observations
##   vector[N] Y;  // response variable
##   int<lower=1> K;  // number of population-level effects
##   matrix[N, K] X;  // population-level design matrix
##   int<lower=1> Kc;  // number of population-level effects after centering
##   int prior_only;  // should the likelihood be ignored?
## }
## transformed data {
##   matrix[N, Kc] Xc;  // centered version of X without an intercept
##   vector[Kc] means_X;  // column means of X before centering
##   for (i in 2:K) {
##     means_X[i - 1] = mean(X[, i]);
##     Xc[, i - 1] = X[, i] - means_X[i - 1];
##   }
## }
## parameters {
##   vector[Kc] b;  // regression coefficients
##   real Intercept;  // temporary intercept for centered predictors
##   real<lower=0> sigma;  // dispersion parameter
## }
## transformed parameters {
##   real lprior = 0;  // prior contributions to the log posterior
## }
## model {
##   // likelihood including constants
##   if (!prior_only) {
##     target += normal_id_glm_lpdf(Y | Xc, Intercept, b, sigma);
##   }
##   // priors including constants
##   target += lprior;
## }
## generated quantities {
##   // actual population-level intercept
##   real b_Intercept = Intercept - dot_product(means_X, b);
## }

5.9 補足: make_standata関数によるStanに渡すデータの作成

make_standata()関数でStanにわたすデータを生成できる。

make_standata(
  formula = sales ~ temperature,
  family = gaussian(),
  data = file_beer_sales_2
)
## $N
## [1] 100
## 
## $Y
##   [1]  41.68 110.99  65.32  72.64  76.54  62.76  46.66 100.79  85.59  97.57
##  [11]  45.93  87.47  72.45  56.37  72.84  75.45  63.77  49.06  68.51  35.32
##  [21]  64.18  69.46  84.63  59.02  61.44  55.89  72.05  74.33 109.04  30.35
##  [31]  60.44  27.81  77.50  67.92  37.63 103.58  77.45  54.98  72.45  58.30
##  [41] 118.01  84.97  33.00  32.87  69.56  65.95 123.14  62.61  57.36  76.48
##  [51]  61.37  49.66  74.54  80.26  45.82 116.22  98.35 124.63  69.43  75.24
##  [61]  68.09  85.49  83.33  78.38  96.46  73.59  82.96  85.21  57.60  36.50
##  [71]  77.37  62.58  72.66  47.79  38.59  90.81  49.46  98.08  39.39  47.61
##  [81]  72.85  83.46  59.32  34.76  73.19 105.76 108.98  73.05 113.40  40.19
##  [91]  85.00  70.90  51.90  74.28  79.44  44.94  76.11  58.00  38.65  52.04
## 
## $K
## [1] 2
## 
## $Kc
## [1] 1
## 
## $X
##     Intercept temperature
## 1           1        13.7
## 2           1        24.0
## 3           1        21.5
## 4           1        13.4
## 5           1        28.9
## 6           1        28.9
## 7           1        12.6
## 8           1        26.7
## 9           1        19.4
## 10          1        21.0
## 11          1        21.1
## 12          1        14.8
## 13          1        25.2
## 14          1        13.6
## 15          1        18.1
## 16          1        27.1
## 17          1        29.5
## 18          1        14.5
## 19          1        18.9
## 20          1        11.5
## 21          1        23.2
## 22          1        17.8
## 23          1        26.7
## 24          1        13.0
## 25          1        16.9
## 26          1        19.8
## 27          1        13.0
## 28          1        17.1
## 29          1        29.3
## 30          1        12.6
## 31          1        10.2
## 32          1        13.3
## 33          1        26.2
## 34          1        27.4
## 35          1        20.3
## 36          1        22.5
## 37          1        26.9
## 38          1        15.7
## 39          1        23.3
## 40          1        13.0
## 41          1        29.6
## 42          1        15.9
## 43          1        12.3
## 44          1        13.3
## 45          1        28.9
## 46          1        25.9
## 47          1        29.5
## 48          1        17.0
## 49          1        20.0
## 50          1        26.2
## 51          1        10.1
## 52          1        10.3
## 53          1        23.7
## 54          1        28.6
## 55          1        15.5
## 56          1        26.2
## 57          1        25.7
## 58          1        29.8
## 59          1        22.3
## 60          1        24.2
## 61          1        25.4
## 62          1        27.7
## 63          1        22.5
## 64          1        15.2
## 65          1        27.2
## 66          1        18.7
## 67          1        17.8
## 68          1        19.2
## 69          1        14.4
## 70          1        11.3
## 71          1        15.5
## 72          1        16.2
## 73          1        10.8
## 74          1        13.7
## 75          1        13.7
## 76          1        25.1
## 77          1        15.8
## 78          1        27.4
## 79          1        18.1
## 80          1        21.5
## 81          1        17.0
## 82          1        23.4
## 83          1        10.5
## 84          1        18.0
## 85          1        14.0
## 86          1        27.1
## 87          1        29.4
## 88          1        16.5
## 89          1        24.7
## 90          1        16.8
## 91          1        29.5
## 92          1        17.9
## 93          1        17.6
## 94          1        21.2
## 95          1        19.3
## 96          1        13.9
## 97          1        18.5
## 98          1        11.9
## 99          1        12.3
## 100         1        18.8
## attr(,"assign")
## [1] 0 1
## 
## $prior_only
## [1] 0
## 
## attr(,"class")
## [1] "standata" "list"

5.12 brmsによる事後分布の可視化

パラメタの事後分布に関する95%ベイズ信用区間を確認したいのであれば、stanplot()関数が便利。

stanplot(simple_lm_brms, 
         type = "intervals",
         pars = "^b_",
         prob = 0.8,        # 太い線の範囲
         prob_outer = 0.95  # 細い線の範囲
)

5.13 brmsによる予測

推定したモデルを利用して予測値を計算したい時は、fitted()関数を利用する。指定した値で得られる予測値の信用区間が計算される。

new_data <- data.frame(temperature = c(20:25))
fitted(simple_lm_brms, new_data)
##      Estimate Est.Error     Q2.5    Q97.5
## [1,] 70.36463  1.652491 67.17275 73.58784
## [2,] 72.82405  1.684042 69.51305 76.10401
## [3,] 75.28347  1.762855 71.89943 78.71577
## [4,] 77.74289  1.883004 74.11427 81.40969
## [5,] 80.20231  2.037190 76.30567 84.19963
## [6,] 82.66174  2.218325 78.43960 86.98622

MCMCの結果を利用して予測値の信用区間を自ら計算することも可能。

mcmc_sample <- as.mcmc(simple_lm_brms, combine_chains = TRUE)
mcmc_b_Intercept   <- mcmc_sample[,"b_Intercept"]
mcmc_b_temperature <- mcmc_sample[,"b_temperature"]
mcmc_sigma         <- mcmc_sample[,"sigma"]

map_dfr(.x = 20:25, .f = function(x){
  out <- mcmc_b_Intercept + x * mcmc_b_temperature
  m <- mean(out)
  q <- quantile(out, probs = c(0.025, 0.975))
  return(tibble(x = x, Estimate = m, Q2.5 = q[1], Q97.5 = q[2]))
})
## # A tibble: 6 × 4
##       x Estimate  Q2.5 Q97.5
##   <int>    <dbl> <dbl> <dbl>
## 1    20     70.4  67.2  73.6
## 2    21     72.8  69.5  76.1
## 3    22     75.3  71.9  78.7
## 4    23     77.7  74.1  81.4
## 5    24     80.2  76.3  84.2
## 6    25     82.7  78.4  87.0

5.14 predict関数を使わない予測の実装

予測値といえばpredict()関数かもしれないが、fitted()関数で得られる結果よりもQ2.5, Q97.5の幅が広くなっている。これは信用区間ではなく、予測区間が計算されているためであり、予測区間の算出にあたっては乱数(rnorm())を使用しているため、毎回結果が異なる。

set.seed(1)
predict(simple_lm_brms, new_data)
##      Estimate Est.Error     Q2.5    Q97.5
## [1,] 70.45431  16.85742 37.32262 103.7980
## [2,] 72.63553  17.19176 38.34796 106.0014
## [3,] 75.00771  17.00605 41.83472 108.7187
## [4,] 77.92388  17.45858 43.87041 111.6380
## [5,] 79.98086  17.12105 45.82413 113.7551
## [6,] 82.99458  16.88340 50.20925 115.6032

MCMCの結果を利用して予測値の予測区間を自ら計算することも可能。

set.seed(1)
map_dfr(.x = 20:25, .f = function(x){
  out <- mcmc_b_Intercept + x * mcmc_b_temperature
  pred <- rnorm(n = 4000, mean = out, sd = mcmc_sigma)
  m <- mean(pred)
  q <- quantile(pred, probs = c(0.025, 0.975))
  return(tibble(x = x, Estimate = m, Q2.5 = q[1], Q97.5 = q[2]))
})
## # A tibble: 6 × 4
##       x Estimate  Q2.5 Q97.5
##   <int>    <dbl> <dbl> <dbl>
## 1    20     70.4  35.2  105.
## 2    21     72.6  38.4  106.
## 3    22     75.1  41.8  108.
## 4    23     77.4  45.2  111.
## 5    24     80.5  47.3  115.
## 6    25     83.2  49.2  117.
# mcmc_sigma、nなどの直書きを避ける例
# simulate_predictions <- function(x, n_sim, mcmc_b_Intercept, mcmc_b_temperature, mcmc_sigma) {
#   out <- mcmc_b_Intercept + x * mcmc_b_temperature
#   pred <- rnorm(n = n_sim, mean = out, sd = mcmc_sigma)
#   m <- mean(pred)
#   q <- quantile(pred, probs = c(0.025, 0.975))
#   return(tibble(x = x, Estimate = m, Q2.5 = q[1], Q97.5 = q[2]))
# }
# 
# result <- map_dfr(.x = 20:25, .f = function(x) {
#   simulate_predictions(x, n_sim = 4000, mcmc_b_Intercept, mcmc_b_temperature, mcmc_sigma)
# })

ちなみに、interval='confidence'としても、信用区間は計算できないので注意。

回帰直線の図示

ちなみにpredict()関数を使用せずとも、予測値を算出できる。Stanではgenerated quantitiesブロックで予測値を生成するが、

generated quantities {
  vector[N_pred] mu_pred;           // ビールの売り上げの期待値
  vector[N_pred] sales_pred;        // ビールの売り上げの予測値

  for (i in 1:N_pred) {
    mu_pred[i] = Intercept + beta*temperature_pred[i];
    sales_pred[i] = normal_rng(mu_pred[i], sigma);
  }
}

brmsではconditional_effects()関数を利用することで、回帰直線の95%ベイズ信用区間付きのグラフが得られる。marginal_effects()関数は非推奨になっている。

eff <- conditional_effects(simple_lm_brms, method = "posterior_epred")
plot(eff, points = TRUE)

conditional_effects(method = "predict")関数を利用することで、回帰直線の95%ベイズ予測区間付きのグラフが得られる。

set.seed(1)
eff_pre <- conditional_effects(simple_lm_brms, method = "predict")
plot(eff_pre, points = TRUE)

他にも、複数の説明変数があった時に、交互作用を表現したりもできる。

# 参考:複数の説明変数があるときは、特定の要因だけを切り出せる
conditional_effects(simple_lm_brms, effects = "temperature")

# 参考:複数の説明変数を同時に図示
conditional_effects(brms_model, effects = "x1:x2")