UPDATE: 2022-12-19 20:34:08

1 はじめに

ここでは、下記のドキュメントを参考に、prophetパッケージの基本的な使い方をおさらいすることを目的としています。ゆくゆくは外部予測変数を追加したモデルやクロスバリデーション、パラメタチューニングなどなど、モデルを発展させながら使用方法をまとめていきます。

モデルの数理部分は下記のprophetに関する論文やブログ記事を参照願います。非常にわかりやすいです。

2 ライブラリと関数の読み込み

library(prophet)
library(tidyverse)

head_tail <- function(data, n = 5){
  stopifnot(is.data.frame(data))
  head(data, n = n) %>% print()
  cat(paste(rep("-", 100), collapse = ""), "\n")
  tail(data, n = n) %>% print()
}

3 必要なデータ形式

prophet()に適した形でデータフレームを渡せば、あとは関数の内部でrstanパッケージの関数が計算し、計算計算を結果を返してくれる。そのためか、ドキュメントではR APIと表記されている。

最低限のカラムとして、日付(ds)と予測したい指標(y)を用意すればよいが、カラム名はdsyである必要があるので注意。また、dsは、yyyy-mm-ddyyyy-mm-dd hh:mm:ssフォーマットでなければいけない。

ここでは、ドキュメントにもあるWikipediaのページビューデータの期間を2年分に絞って動かしていく。

# Githubのprophetのリポジトリからサンプルデータをインポート
base_url <- 'https://raw.githubusercontent.com/facebook/prophet/master/examples/'
data_path <- paste0(base_url, 'example_wp_log_peyton_manning.csv')
df <- readr::read_csv(data_path) %>%
  dplyr::filter(ds >= "2011-01-01" & ds <= "2012-12-31")

# 先頭と末尾のデータを表示
head_tail(df, n = 5)
## # A tibble: 5 × 2
##   ds             y
##   <date>     <dbl>
## 1 2011-01-01  9.01
## 2 2011-01-02  9.40
## 3 2011-01-03  9.99
## 4 2011-01-04  9.06
## 5 2011-01-05  8.97
## ---------------------------------------------------------------------------------------------------- 
## # A tibble: 5 × 2
##   ds             y
##   <date>     <dbl>
## 1 2012-12-27  8.97
## 2 2012-12-28  8.68
## 3 2012-12-29  8.53
## 4 2012-12-30  9.49
## 5 2012-12-31 10.1

4 予測モデルの作成

モデル作成のためにprophet()に先程用意しているデータフレームを渡すことで作成できます。モデルを発展させていくためには、用意されている引数やadd_*()を利用することができます。mにはモデルを構築するための様々な要素が計算されていることがわかります。

# デフォルト設定でモデルを作成
m <- prophet(df = df)
names(m)
##  [1] "growth"                  "changepoints"           
##  [3] "n.changepoints"          "changepoint.range"      
##  [5] "yearly.seasonality"      "weekly.seasonality"     
##  [7] "daily.seasonality"       "holidays"               
##  [9] "seasonality.mode"        "seasonality.prior.scale"
## [11] "changepoint.prior.scale" "holidays.prior.scale"   
## [13] "mcmc.samples"            "interval.width"         
## [15] "uncertainty.samples"     "specified.changepoints" 
## [17] "start"                   "y.scale"                
## [19] "logistic.floor"          "t.scale"                
## [21] "changepoints.t"          "seasonalities"          
## [23] "extra_regressors"        "country_holidays"       
## [25] "stan.fit"                "params"                 
## [27] "history"                 "history.dates"          
## [29] "train.holiday.names"     "train.component.cols"   
## [31] "component.modes"         "fit.kwargs"

5 モデルで予測値を計算

構築したモデルで予測を行うための準備として、データフレームを拡張します。このデータは2012-12-31までしかないので、ここではmake_future_dataframe()を利用して、予測期間を含むデータフレームに拡張します。

make_future_dataframe
## function (m, periods, freq = "day", include_history = TRUE) 
## {
##     if (freq == "m") {
##         freq <- "month"
##     }
##     if (is.null(m$history.dates)) {
##         stop("Model must be fit before this can be used.")
##     }
##     dates <- seq(max(m$history.dates), length.out = periods + 
##         1, by = freq)
##     dates <- dates[2:(periods + 1)]
##     if (include_history) {
##         dates <- c(m$history.dates, dates)
##         attr(dates, "tzone") <- "GMT"
##     }
##     return(data.frame(ds = dates))
## }
## <bytecode: 0x1312557b0>
## <environment: namespace:prophet>

関数の中身を見るとわかりますが、この関数は、モデルに渡されたデータフレームの最大日付(m$history.dates)を取得し、指定した期間の日付型のベクトルを生成し、最大日付の末尾にアペンドして、期間を拡張してくれる便利な関数です。

例えば、今回のデータに対して、30日分追加する場合は、下記の日付が最大日付のベクトルにアペンドされることになります。

# make_future_dataframe(m, periods = 30)
dates <- seq(max(m$history.dates), length.out = 30 + 1, by = "day")
dates[2:(30 + 1)]
##  [1] "2013-01-01 GMT" "2013-01-02 GMT" "2013-01-03 GMT" "2013-01-04 GMT"
##  [5] "2013-01-05 GMT" "2013-01-06 GMT" "2013-01-07 GMT" "2013-01-08 GMT"
##  [9] "2013-01-09 GMT" "2013-01-10 GMT" "2013-01-11 GMT" "2013-01-12 GMT"
## [13] "2013-01-13 GMT" "2013-01-14 GMT" "2013-01-15 GMT" "2013-01-16 GMT"
## [17] "2013-01-17 GMT" "2013-01-18 GMT" "2013-01-19 GMT" "2013-01-20 GMT"
## [21] "2013-01-21 GMT" "2013-01-22 GMT" "2013-01-23 GMT" "2013-01-24 GMT"
## [25] "2013-01-25 GMT" "2013-01-26 GMT" "2013-01-27 GMT" "2013-01-28 GMT"
## [29] "2013-01-29 GMT" "2013-01-30 GMT"

実際にmake_future_dataframe()を利用して、予測期間を含むデータフレームに拡張します。

future_df <- make_future_dataframe(m, periods = 30)
head_tail(future_df, n = 5)
##           ds
## 1 2011-01-01
## 2 2011-01-02
## 3 2011-01-03
## 4 2011-01-04
## 5 2011-01-05
## ---------------------------------------------------------------------------------------------------- 
##             ds
## 752 2013-01-26
## 753 2013-01-27
## 754 2013-01-28
## 755 2013-01-29
## 756 2013-01-30

このデータフレームとモデルをpredict()に渡すことで、prophetクラスに対するpredict メソッドが呼び出され、30日分の予測が行われます。

# 予測値を計算
forecast_df <- predict(object = m, df = future_df)

# getS3method("predict", "prophet")
# https://github.com/facebook/prophet/blob/a794018d654402ab6a97cb262e80d347db3485bd/R/R/prophet.R#L1303
# df$yhat <- df$trend * (1 + df$multiplicative_terms) + df$additive_term
# predict.prophetの予測値の計算に最終的に必要なカラム
forecast_df %>% 
  dplyr::select(ds, trend, ends_with("terms"), yhat) %>% 
  head_tail(., n = 5)
##           ds    trend additive_terms multiplicative_terms     yhat
## 1 2011-01-01 7.749329      0.2573800                    0 8.006709
## 2 2011-01-02 7.750771      0.7048030                    0 8.455574
## 3 2011-01-03 7.752214      0.9456609                    0 8.697875
## 4 2011-01-04 7.753656      0.7461379                    0 8.499794
## 5 2011-01-05 7.755099      0.6047914                    0 8.359890
## ---------------------------------------------------------------------------------------------------- 
##             ds    trend additive_terms multiplicative_terms      yhat
## 752 2013-01-26 8.698186      0.9827461                    0  9.680932
## 753 2013-01-27 8.698558      1.3831731                    0 10.081731
## 754 2013-01-28 8.698931      1.5731877                    0 10.272118
## 755 2013-01-29 8.699303      1.3188602                    0 10.018163
## 756 2013-01-30 8.699675      1.1185997                    0  9.818275

6 予測結果を可視化

データフレームとモデルをplot()に渡すことで、prophetクラスに対するplot メソッドが呼び出され、30日分の予測が加えられた、可視化が行われます。

# getS3method("plot", "prophet")
plot(x = m, fcst = forecast_df) +
  labs(title = "Forecasting Wikipedia Page View", y = "PageView", x = "Date")

dyplot.prophet()を利用すれば、dygraphを使ったインタラクティブな可視化を行うことも可能です。

dyplot.prophet(x = m, fcst = forecast_df)

7 モデルの要素を分解して可視化

prophet_plot_components()を使うことで、トレンド、周期、祝日、外部予測変数など、各要素の効果を分解して可視化することもできます。

# weekly_start=1で月曜日始まりに設定
# render_plot=FALSEでグラフをリストに格納し、ばらして可視化
plts <- prophet_plot_components(m = m,
                                         fcst = forecast_df, 
                                         weekly_start = 1,
                                         yearly_start = 0,
                                         render_plot = FALSE)
plts[[1]] + labs(title = "Trend Components")

plts[[2]] + labs(title = "Weekly Seasonality Components")

plts[[3]] + labs(title = "Yearly Seasonality Components")

8 モデルの保存と読み込み

Pythonでよく行われるようなpickleでモデルを保存することも可能です。Rでは、学習済みのモデルをsaveRDS()で保存し、readRDS()で読み込むことで、予測に使用できます。

saveRDS(m, file = "model.RDS")  # Save model
m2 <- readRDS(file = "model.RDS")  # Load mode
# 学習済みのモデルを呼び出して予測
# 60日分の予測を行う
future_df2 <- make_future_dataframe(m2, periods = 60)
forecast_df2 <- predict(object = m2, df = future_df2)
forecast_df2 %>% 
  dplyr::select(ds, trend, ends_with("terms"), yhat) %>% 
  head_tail(., n = 5)
##           ds    trend additive_terms multiplicative_terms     yhat
## 1 2011-01-01 7.749329      0.2573800                    0 8.006709
## 2 2011-01-02 7.750771      0.7048030                    0 8.455574
## 3 2011-01-03 7.752214      0.9456609                    0 8.697875
## 4 2011-01-04 7.753656      0.7461379                    0 8.499794
## 5 2011-01-05 7.755099      0.6047914                    0 8.359890
## ---------------------------------------------------------------------------------------------------- 
##             ds    trend additive_terms multiplicative_terms     yhat
## 782 2013-02-25 8.709357      0.1142789                    0 8.823636
## 783 2013-02-26 8.709729     -0.1252476                    0 8.584482
## 784 2013-02-27 8.710102     -0.2952094                    0 8.414892
## 785 2013-02-28 8.710474     -0.2602026                    0 8.450272
## 786 2013-03-01 8.710847     -0.2300108                    0 8.480836

新たに予測した結果を可視化しておきます。

plot(x = m2, fcst = forecast_df2) +
  labs(title = "Forecasting Wikipedia Page View From Saved Model", y = "PageView", x = "Date")

9 セッション情報

sessionInfo()
## R version 4.2.2 (2022-10-31)
## Platform: aarch64-apple-darwin20 (64-bit)
## Running under: macOS Monterey 12.5
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] forcats_0.5.2   stringr_1.4.1   dplyr_1.0.10    purrr_0.3.5    
##  [5] readr_2.1.3     tidyr_1.2.1     tibble_3.1.8    ggplot2_3.4.0  
##  [9] tidyverse_1.3.2 prophet_1.0     rlang_1.0.6     Rcpp_1.0.9     
## 
## loaded via a namespace (and not attached):
##  [1] matrixStats_0.63.0   fs_1.5.2             xts_0.12.2          
##  [4] lubridate_1.9.0      bit64_4.0.5          httr_1.4.4          
##  [7] rstan_2.21.7         tools_4.2.2          backports_1.4.1     
## [10] bslib_0.4.1          utf8_1.2.2           R6_2.5.1            
## [13] DBI_1.1.3            colorspace_2.0-3     withr_2.5.0         
## [16] tidyselect_1.2.0     gridExtra_2.3        prettyunits_1.1.1   
## [19] processx_3.8.0       bit_4.0.5            curl_4.3.3          
## [22] compiler_4.2.2       textshaping_0.3.6    cli_3.4.1           
## [25] rvest_1.0.3          xml2_1.3.3           labeling_0.4.2      
## [28] sass_0.4.2           dygraphs_1.1.1.6     scales_1.2.1        
## [31] callr_3.7.3          systemfonts_1.0.4    digest_0.6.30       
## [34] StanHeaders_2.21.0-7 rmarkdown_2.18       extraDistr_1.9.1    
## [37] pkgconfig_2.0.3      htmltools_0.5.3      highr_0.9           
## [40] dbplyr_2.2.1         fastmap_1.1.0        htmlwidgets_1.5.4   
## [43] readxl_1.4.1         rstudioapi_0.14      farver_2.1.1        
## [46] jquerylib_0.1.4      generics_0.1.3       zoo_1.8-11          
## [49] jsonlite_1.8.3       vroom_1.6.0          googlesheets4_1.0.1 
## [52] inline_0.3.19        magrittr_2.0.3       loo_2.5.1           
## [55] munsell_0.5.0        fansi_1.0.3          lifecycle_1.0.3     
## [58] stringi_1.7.8        yaml_2.3.6           pkgbuild_1.3.1      
## [61] grid_4.2.2           parallel_4.2.2       crayon_1.5.2        
## [64] lattice_0.20-45      haven_2.5.1          hms_1.1.2           
## [67] knitr_1.41           ps_1.7.2             pillar_1.8.1        
## [70] codetools_0.2-18     stats4_4.2.2         reprex_2.0.2        
## [73] glue_1.6.2           evaluate_0.18        RcppParallel_5.1.5  
## [76] modelr_0.1.10        vctrs_0.5.1          tzdb_0.3.0          
## [79] cellranger_1.1.0     gtable_0.3.1         assertthat_0.2.1    
## [82] cachem_1.0.6         xfun_0.35            broom_1.0.1         
## [85] ragg_1.2.4           googledrive_2.0.0    gargle_1.2.1        
## [88] timechange_0.1.1     ellipsis_0.3.2