UPDATE: 2021-09-15 00:10:19

1 はじめに

ここでは、下記のドキュメントを参考に、prophetパッケージの基本的な使い方をおさらいすることを目的としています。ゆくゆくは外部予測変数を追加したモデルやクロスバリデーション、パラメタチューニングなどなど、モデルを発展させながら使用方法をまとめていきます。

モデルの数理部分は下記のprophetに関する論文やブログ記事を参照願います。非常にわかりやすいです。

2 ライブラリと関数の読み込み

library(prophet)
library(forecast)
library(tidyverse)

head_tail <- function(data, n = 5){
  stopifnot(is.data.frame(data))
  head(data, n = n) %>% print()
  cat(paste(rep("-", 100), collapse = ""), "\n")
  tail(data, n = n) %>% print()
}

3 外部予測変数

Prophetは一般化加法モデルをベースにしているため、季節性の周期をモデルに組みやすいことを前回の記事で説明しました。これは周期性に限った話ではなく、目的変数に影響を及ぼす外部予測変数(Additional Regressor)もモデルに組みこむことが可能です。

注意としては、外部予測変数を組み込んだモデルで予測を行う場合、その変数の未来の時点での値が分かっている必要があります。

また、外部予測変数の数は1つだけに限られるということもなく、必要に応じて変数を追加してモデルを構築していくことが可能です。下記のようなイメージでモデルを拡張できます。gがトレンド、sが季節周期、hが祝日効果です。

\[ y_{t} = g(t) + s(t) + h(t) + \beta_{1}x_{1}(t) + \beta_{2}x_{2}(t) + \ldots + \beta_{n}x_{n}(t) + \epsilon_{t} \]   ここからは、実際にProphetでモデルを作って、外部予測変数をモデルに組み込んでいきます。また、冒頭の参考文献・サイト以外に、ここでは下記の記事も参考にしています。Prophetの外部予測変数を利用してモデルを改良していく過程がまとめられているので非常に参考になりました。

3.1 サンプルデータ

今回は、1969年1月から1984年12月までに英国で死亡または重傷を負った自動車ドライバーの月間合計数を示す時系列データSeatbeltsを使用します。推移を見るとわかりますが、このデータの特徴は、シートベルトの強制着用法による、死亡または重傷を負った自動車ドライバーの月間合計の変化です。1983年1月31日に導入されたことにより、死亡または重傷を負った自動車ドライバーの月間合計がガクンと下がっていることがわかります。また、車の死亡者数と関係しそうなガソリン価格PetrolPriceやシートベルト法フラグlawなどもあわせて利用します。

df <- Seatbelts %>% 
  tibble::as_tibble() %>% 
  dplyr::mutate(dt = seq(from = as.Date("1969-01-01"), by = "1 month", length.out = n())) %>% 
  dplyr::select(ds = dt, y = drivers, PetrolPrice, law) 

df %>% 
  ggplot(aes(ds, y)) +
  geom_line(size = 1, col = "#749FC6") + 
  geom_vline(xintercept = as.Date("1983-02-01"), col = "red", linetype = "dashed") +
  scale_y_continuous(labels = scales::comma) + 
  scale_x_date(date_breaks = '12 month', date_labels = "%Y-%m") + 
  theme_bw() + 
  theme(axis.text.x = element_text(angle = 30, hjust = 1)) +
  ggtitle("UKDriverDeaths")

3.2 Prophetのadd_regressor関数

外部予測変数を利用するためにはadd_regressor()を使用します。add_regressor()を使用する場合は、prophet()でインスタンスのみを作成しておきます。ここでは、後ほどの外部予測変数の効果の説明のため周期性はないもとしておきます。

# 外部予測変数を追加する場合、prophet()でインスタンスのみを作成
m <- prophet(fit = FALSE,
             yearly.seasonality = FALSE,
             weekly.seasonality = FALSE,
             daily.seasonality = FALSE
             )

m <- add_regressor(m, name = "law", standardize = FALSE)
m <- add_regressor(m, name = "PetrolPrice", standardize = FALSE)
m <- fit.prophet(m, df)

fore_df <- predict(m, df)

res_df <- df %>%
  dplyr::left_join(fore_df %>%
      dplyr::mutate(ds = as.Date(ds)) %>%
      dplyr::select(ds, law_ef = law, PetrolPrice_ef = PetrolPrice),
    by = "ds") %>%
  dplyr::mutate(
    beta_law = if_else(is.nan(law_ef / law), 0, law_ef / law),
    beta_PetrolPrice = PetrolPrice_ef / PetrolPrice
  )

PetrolPrice_eflaw_efは、予測値に足し込まれる値なので、回帰係数を計算するために、変数の観測値で割り戻します。lawに1が立つということは、シートベルト法が施行されている時期なので、この時期では、死亡者が0, -241.4453178人減ることになります。

head_tail(res_df)
## # A tibble: 5 × 8
##   ds             y PetrolPrice   law law_ef PetrolPrice_ef beta_law
##   <date>     <dbl>       <dbl> <dbl>  <dbl>          <dbl>    <dbl>
## 1 1969-01-01  1687       0.103     0      0          -8.60        0
## 2 1969-02-01  1508       0.102     0      0          -8.54        0
## 3 1969-03-01  1507       0.102     0      0          -8.52        0
## 4 1969-04-01  1385       0.101     0      0          -8.42        0
## 5 1969-05-01  1632       0.101     0      0          -8.43        0
## # … with 1 more variable: beta_PetrolPrice <dbl>
## ---------------------------------------------------------------------------------------------------- 
## # A tibble: 5 × 8
##   ds             y PetrolPrice   law law_ef PetrolPrice_ef beta_law
##   <date>     <dbl>       <dbl> <dbl>  <dbl>          <dbl>    <dbl>
## 1 1984-08-01  1284       0.115     1  -241.          -9.58    -241.
## 2 1984-09-01  1444       0.114     1  -241.          -9.52    -241.
## 3 1984-10-01  1575       0.116     1  -241.          -9.72    -241.
## 4 1984-11-01  1737       0.116     1  -241.          -9.69    -241.
## 5 1984-12-01  1763       0.116     1  -241.          -9.69    -241.
## # … with 1 more variable: beta_PetrolPrice <dbl>

どのような形で外部予測変数が効いているかを確認するために可視化しておきます。この図を見るとわかりますが、今回のモデルでは、周期性を考慮していないので、信用区間が太い帯の様になっています。これを解消したければ、周期性を追加することで解消できます。

plot(m, fore_df)

外部予測変数の影響を可視化する場合は、周期性の影響を可視化するために利用したprophet_plot_components()を利用します。

prophet_plot_components(m = m, fcst = fore_df)

prophet_plot_components()では複数の外部予測変数を足し合わせ、トータルとしての外部予測変数の影響を可視化することになるため、分離して各変数ごとに影響を可視化しておきます。この図をみると、PetrolPriceは死亡者の数とあまり関係がないようですね。

fore_df %>% 
  dplyr::mutate(ds = as.Date(ds)) %>% 
  dplyr::select(ds, PetrolPrice, law) %>% 
  tidyr::pivot_longer(cols = -ds, names_to = "vals", values_to = "value") %>% 
  ggplot(aes(ds, value, col = vals)) + 
  geom_line(size = 1) + 
  scale_y_continuous(labels = scales::comma, breaks = seq(0, -250, -20)) + 
  scale_x_date(date_breaks = '12 month', date_labels = "%Y-%m") + 
  scale_color_brewer(palette = "Set1") + 
  theme_bw() + 
  theme(axis.text.x = element_text(angle = 30, hjust = 1)) +
  ggtitle("Extra Regressors By Each variables")

詳しくは取り上げませんが、ARIMAXでも同じ様に外部予測変数を取り込んだモデルを作成することは可能です。

auto.arima(
    y = df$y,
    # 複数変数を組み込む場合 
    xreg = as.matrix(df[,c("law", "PetrolPrice")]),
    # xreg = df$law,
    ic = "aic",
    max.order = 20)
## Series: df$y 
## Regression with ARIMA(2,0,1) errors 
## 
## Coefficients:
##          ar1      ar2      ma1  intercept        law  PetrolPrice
##       1.2752  -0.5098  -0.6114  2468.5083  -271.9515    -7375.114
## s.e.  0.1297   0.0837   0.1306   214.2586    77.0583     2089.738
## 
## sigma^2 estimated as 37957:  log likelihood=-1281.9
## AIC=2577.81   AICc=2578.42   BIC=2600.61

私の勉強不足もあって、ARIMAXなど時系列モデルに外部予測変数として「連続変数」を組み込んだ場合、回帰係数の解釈がよくわかっていません。通常の回帰分析のように解釈するものではなく、より良い予測をするための誤差を捉えるための手段であって、解釈するためのものではないのかもしれません(そんなはずはないだろうけども…)。下記のブログでも回帰係数の解釈の方法がわかりにくいという指摘がされています。

4 セッション情報

sessionInfo()
## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] ja_JP.UTF-8/ja_JP.UTF-8/ja_JP.UTF-8/C/ja_JP.UTF-8/ja_JP.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] forcats_0.5.1   stringr_1.4.0   dplyr_1.0.7     purrr_0.3.4    
##  [5] readr_1.4.0     tidyr_1.1.3     tibble_3.1.3    ggplot2_3.3.3  
##  [9] tidyverse_1.3.0 forecast_8.15   prophet_1.0     rlang_0.4.10   
## [13] Rcpp_1.0.6     
## 
## loaded via a namespace (and not attached):
##  [1] nlme_3.1-149         matrixStats_0.58.0   fs_1.5.0            
##  [4] xts_0.12.1           lubridate_1.7.9.2    RColorBrewer_1.1-2  
##  [7] httr_1.4.2           rstan_2.21.2         tools_4.0.3         
## [10] backports_1.2.1      utf8_1.1.4           R6_2.5.0            
## [13] DBI_1.1.1            colorspace_2.0-0     nnet_7.3-14         
## [16] withr_2.4.1          gridExtra_2.3        prettyunits_1.1.1   
## [19] tidyselect_1.1.0     processx_3.5.2       curl_4.3            
## [22] compiler_4.0.3       textshaping_0.3.5    cli_3.0.1           
## [25] rvest_0.3.6          xml2_1.3.2           labeling_0.4.2      
## [28] tseries_0.10-48      scales_1.1.1         lmtest_0.9-38       
## [31] fracdiff_1.5-1       quadprog_1.5-8       callr_3.7.0         
## [34] StanHeaders_2.21.0-7 systemfonts_1.0.2    digest_0.6.27       
## [37] rmarkdown_2.6        extraDistr_1.9.1     pkgconfig_2.0.3     
## [40] htmltools_0.5.1.1    dbplyr_2.1.0         highr_0.8           
## [43] readxl_1.3.1         TTR_0.24.2           rstudioapi_0.13     
## [46] quantmod_0.4.18      generics_0.1.0       farver_2.0.3        
## [49] zoo_1.8-8            jsonlite_1.7.2       inline_0.3.17       
## [52] magrittr_2.0.1       loo_2.4.1            munsell_0.5.0       
## [55] fansi_0.4.2          lifecycle_1.0.0      stringi_1.5.3       
## [58] yaml_2.2.1           pkgbuild_1.2.0       grid_4.0.3          
## [61] parallel_4.0.3       crayon_1.4.0         lattice_0.20-41     
## [64] haven_2.3.1          hms_1.0.0            ps_1.5.0            
## [67] knitr_1.33           pillar_1.6.2         codetools_0.2-16    
## [70] stats4_4.0.3         reprex_1.0.0         urca_1.3-0          
## [73] glue_1.4.2           evaluate_0.14        V8_3.4.0            
## [76] RcppParallel_5.0.2   modelr_0.1.8         vctrs_0.3.8         
## [79] cellranger_1.1.0     gtable_0.3.0         assertthat_0.2.1    
## [82] xfun_0.24            broom_0.7.9          ragg_1.1.3          
## [85] timeDate_3043.102    ellipsis_0.3.2