UPDATE: 2022-12-19 20:34:08
ここでは、下記のドキュメントを参考に、prophet
パッケージの基本的な使い方をおさらいすることを目的としています。ゆくゆくは外部予測変数を追加したモデルやクロスバリデーション、パラメタチューニングなどなど、モデルを発展させながら使用方法をまとめていきます。
モデルの数理部分は下記のprophet
に関する論文やブログ記事を参照願います。非常にわかりやすいです。
library(prophet)
library(tidyverse)
<- function(data, n = 5){
head_tail stopifnot(is.data.frame(data))
head(data, n = n) %>% print()
cat(paste(rep("-", 100), collapse = ""), "\n")
tail(data, n = n) %>% print()
}
prophet()
に適した形でデータフレームを渡せば、あとは関数の内部でrstan
パッケージの関数が計算し、計算計算を結果を返してくれる。そのためか、ドキュメントではR
APIと表記されている。
最低限のカラムとして、日付(ds
)と予測したい指標(y
)を用意すればよいが、カラム名はds
とy
である必要があるので注意。また、ds
は、yyyy-mm-dd
かyyyy-mm-dd hh:mm:ss
フォーマットでなければいけない。
ここでは、ドキュメントにもあるWikipediaのページビューデータの期間を2年分に絞って動かしていく。
# Githubのprophetのリポジトリからサンプルデータをインポート
<- 'https://raw.githubusercontent.com/facebook/prophet/master/examples/'
base_url <- paste0(base_url, 'example_wp_log_peyton_manning.csv')
data_path <- readr::read_csv(data_path) %>%
df ::filter(ds >= "2011-01-01" & ds <= "2012-12-31")
dplyr
# 先頭と末尾のデータを表示
head_tail(df, n = 5)
## # A tibble: 5 × 2
## ds y
## <date> <dbl>
## 1 2011-01-01 9.01
## 2 2011-01-02 9.40
## 3 2011-01-03 9.99
## 4 2011-01-04 9.06
## 5 2011-01-05 8.97
## ----------------------------------------------------------------------------------------------------
## # A tibble: 5 × 2
## ds y
## <date> <dbl>
## 1 2012-12-27 8.97
## 2 2012-12-28 8.68
## 3 2012-12-29 8.53
## 4 2012-12-30 9.49
## 5 2012-12-31 10.1
モデル作成のためにprophet()
に先程用意しているデータフレームを渡すことで作成できます。モデルを発展させていくためには、用意されている引数やadd_*()
を利用することができます。m
にはモデルを構築するための様々な要素が計算されていることがわかります。
# デフォルト設定でモデルを作成
<- prophet(df = df)
m names(m)
## [1] "growth" "changepoints"
## [3] "n.changepoints" "changepoint.range"
## [5] "yearly.seasonality" "weekly.seasonality"
## [7] "daily.seasonality" "holidays"
## [9] "seasonality.mode" "seasonality.prior.scale"
## [11] "changepoint.prior.scale" "holidays.prior.scale"
## [13] "mcmc.samples" "interval.width"
## [15] "uncertainty.samples" "specified.changepoints"
## [17] "start" "y.scale"
## [19] "logistic.floor" "t.scale"
## [21] "changepoints.t" "seasonalities"
## [23] "extra_regressors" "country_holidays"
## [25] "stan.fit" "params"
## [27] "history" "history.dates"
## [29] "train.holiday.names" "train.component.cols"
## [31] "component.modes" "fit.kwargs"
構築したモデルで予測を行うための準備として、データフレームを拡張します。このデータは2012-12-31までしかないので、ここではmake_future_dataframe()
を利用して、予測期間を含むデータフレームに拡張します。
make_future_dataframe
## function (m, periods, freq = "day", include_history = TRUE)
## {
## if (freq == "m") {
## freq <- "month"
## }
## if (is.null(m$history.dates)) {
## stop("Model must be fit before this can be used.")
## }
## dates <- seq(max(m$history.dates), length.out = periods +
## 1, by = freq)
## dates <- dates[2:(periods + 1)]
## if (include_history) {
## dates <- c(m$history.dates, dates)
## attr(dates, "tzone") <- "GMT"
## }
## return(data.frame(ds = dates))
## }
## <bytecode: 0x1312557b0>
## <environment: namespace:prophet>
関数の中身を見るとわかりますが、この関数は、モデルに渡されたデータフレームの最大日付(m$history.dates
)を取得し、指定した期間の日付型のベクトルを生成し、最大日付の末尾にアペンドして、期間を拡張してくれる便利な関数です。
例えば、今回のデータに対して、30日分追加する場合は、下記の日付が最大日付のベクトルにアペンドされることになります。
# make_future_dataframe(m, periods = 30)
<- seq(max(m$history.dates), length.out = 30 + 1, by = "day")
dates 2:(30 + 1)] dates[
## [1] "2013-01-01 GMT" "2013-01-02 GMT" "2013-01-03 GMT" "2013-01-04 GMT"
## [5] "2013-01-05 GMT" "2013-01-06 GMT" "2013-01-07 GMT" "2013-01-08 GMT"
## [9] "2013-01-09 GMT" "2013-01-10 GMT" "2013-01-11 GMT" "2013-01-12 GMT"
## [13] "2013-01-13 GMT" "2013-01-14 GMT" "2013-01-15 GMT" "2013-01-16 GMT"
## [17] "2013-01-17 GMT" "2013-01-18 GMT" "2013-01-19 GMT" "2013-01-20 GMT"
## [21] "2013-01-21 GMT" "2013-01-22 GMT" "2013-01-23 GMT" "2013-01-24 GMT"
## [25] "2013-01-25 GMT" "2013-01-26 GMT" "2013-01-27 GMT" "2013-01-28 GMT"
## [29] "2013-01-29 GMT" "2013-01-30 GMT"
実際にmake_future_dataframe()
を利用して、予測期間を含むデータフレームに拡張します。
<- make_future_dataframe(m, periods = 30)
future_df head_tail(future_df, n = 5)
## ds
## 1 2011-01-01
## 2 2011-01-02
## 3 2011-01-03
## 4 2011-01-04
## 5 2011-01-05
## ----------------------------------------------------------------------------------------------------
## ds
## 752 2013-01-26
## 753 2013-01-27
## 754 2013-01-28
## 755 2013-01-29
## 756 2013-01-30
このデータフレームとモデルをpredict()
に渡すことで、prophet
クラスに対するpredict
メソッドが呼び出され、30日分の予測が行われます。
# 予測値を計算
<- predict(object = m, df = future_df)
forecast_df
# getS3method("predict", "prophet")
# https://github.com/facebook/prophet/blob/a794018d654402ab6a97cb262e80d347db3485bd/R/R/prophet.R#L1303
# df$yhat <- df$trend * (1 + df$multiplicative_terms) + df$additive_term
# predict.prophetの予測値の計算に最終的に必要なカラム
%>%
forecast_df ::select(ds, trend, ends_with("terms"), yhat) %>%
dplyrhead_tail(., n = 5)
## ds trend additive_terms multiplicative_terms yhat
## 1 2011-01-01 7.749329 0.2573800 0 8.006709
## 2 2011-01-02 7.750771 0.7048030 0 8.455574
## 3 2011-01-03 7.752214 0.9456609 0 8.697875
## 4 2011-01-04 7.753656 0.7461379 0 8.499794
## 5 2011-01-05 7.755099 0.6047914 0 8.359890
## ----------------------------------------------------------------------------------------------------
## ds trend additive_terms multiplicative_terms yhat
## 752 2013-01-26 8.698186 0.9827461 0 9.680932
## 753 2013-01-27 8.698558 1.3831731 0 10.081731
## 754 2013-01-28 8.698931 1.5731877 0 10.272118
## 755 2013-01-29 8.699303 1.3188602 0 10.018163
## 756 2013-01-30 8.699675 1.1185997 0 9.818275
データフレームとモデルをplot()
に渡すことで、prophet
クラスに対するplot
メソッドが呼び出され、30日分の予測が加えられた、可視化が行われます。
# getS3method("plot", "prophet")
plot(x = m, fcst = forecast_df) +
labs(title = "Forecasting Wikipedia Page View", y = "PageView", x = "Date")
dyplot.prophet()
を利用すれば、dygraph
を使ったインタラクティブな可視化を行うことも可能です。
dyplot.prophet(x = m, fcst = forecast_df)
prophet_plot_components()
を使うことで、トレンド、周期、祝日、外部予測変数など、各要素の効果を分解して可視化することもできます。
# weekly_start=1で月曜日始まりに設定
# render_plot=FALSEでグラフをリストに格納し、ばらして可視化
<- prophet_plot_components(m = m,
plts fcst = forecast_df,
weekly_start = 1,
yearly_start = 0,
render_plot = FALSE)
1]] + labs(title = "Trend Components") plts[[
2]] + labs(title = "Weekly Seasonality Components") plts[[
3]] + labs(title = "Yearly Seasonality Components") plts[[
Pythonでよく行われるようなpickle
でモデルを保存することも可能です。Rでは、学習済みのモデルをsaveRDS()
で保存し、readRDS()
で読み込むことで、予測に使用できます。
saveRDS(m, file = "model.RDS") # Save model
<- readRDS(file = "model.RDS") # Load mode
m2 # 学習済みのモデルを呼び出して予測
# 60日分の予測を行う
<- make_future_dataframe(m2, periods = 60)
future_df2 <- predict(object = m2, df = future_df2)
forecast_df2 %>%
forecast_df2 ::select(ds, trend, ends_with("terms"), yhat) %>%
dplyrhead_tail(., n = 5)
## ds trend additive_terms multiplicative_terms yhat
## 1 2011-01-01 7.749329 0.2573800 0 8.006709
## 2 2011-01-02 7.750771 0.7048030 0 8.455574
## 3 2011-01-03 7.752214 0.9456609 0 8.697875
## 4 2011-01-04 7.753656 0.7461379 0 8.499794
## 5 2011-01-05 7.755099 0.6047914 0 8.359890
## ----------------------------------------------------------------------------------------------------
## ds trend additive_terms multiplicative_terms yhat
## 782 2013-02-25 8.709357 0.1142789 0 8.823636
## 783 2013-02-26 8.709729 -0.1252476 0 8.584482
## 784 2013-02-27 8.710102 -0.2952094 0 8.414892
## 785 2013-02-28 8.710474 -0.2602026 0 8.450272
## 786 2013-03-01 8.710847 -0.2300108 0 8.480836
新たに予測した結果を可視化しておきます。
plot(x = m2, fcst = forecast_df2) +
labs(title = "Forecasting Wikipedia Page View From Saved Model", y = "PageView", x = "Date")
sessionInfo()
## R version 4.2.2 (2022-10-31)
## Platform: aarch64-apple-darwin20 (64-bit)
## Running under: macOS Monterey 12.5
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] forcats_0.5.2 stringr_1.4.1 dplyr_1.0.10 purrr_0.3.5
## [5] readr_2.1.3 tidyr_1.2.1 tibble_3.1.8 ggplot2_3.4.0
## [9] tidyverse_1.3.2 prophet_1.0 rlang_1.0.6 Rcpp_1.0.9
##
## loaded via a namespace (and not attached):
## [1] matrixStats_0.63.0 fs_1.5.2 xts_0.12.2
## [4] lubridate_1.9.0 bit64_4.0.5 httr_1.4.4
## [7] rstan_2.21.7 tools_4.2.2 backports_1.4.1
## [10] bslib_0.4.1 utf8_1.2.2 R6_2.5.1
## [13] DBI_1.1.3 colorspace_2.0-3 withr_2.5.0
## [16] tidyselect_1.2.0 gridExtra_2.3 prettyunits_1.1.1
## [19] processx_3.8.0 bit_4.0.5 curl_4.3.3
## [22] compiler_4.2.2 textshaping_0.3.6 cli_3.4.1
## [25] rvest_1.0.3 xml2_1.3.3 labeling_0.4.2
## [28] sass_0.4.2 dygraphs_1.1.1.6 scales_1.2.1
## [31] callr_3.7.3 systemfonts_1.0.4 digest_0.6.30
## [34] StanHeaders_2.21.0-7 rmarkdown_2.18 extraDistr_1.9.1
## [37] pkgconfig_2.0.3 htmltools_0.5.3 highr_0.9
## [40] dbplyr_2.2.1 fastmap_1.1.0 htmlwidgets_1.5.4
## [43] readxl_1.4.1 rstudioapi_0.14 farver_2.1.1
## [46] jquerylib_0.1.4 generics_0.1.3 zoo_1.8-11
## [49] jsonlite_1.8.3 vroom_1.6.0 googlesheets4_1.0.1
## [52] inline_0.3.19 magrittr_2.0.3 loo_2.5.1
## [55] munsell_0.5.0 fansi_1.0.3 lifecycle_1.0.3
## [58] stringi_1.7.8 yaml_2.3.6 pkgbuild_1.3.1
## [61] grid_4.2.2 parallel_4.2.2 crayon_1.5.2
## [64] lattice_0.20-45 haven_2.5.1 hms_1.1.2
## [67] knitr_1.41 ps_1.7.2 pillar_1.8.1
## [70] codetools_0.2-18 stats4_4.2.2 reprex_2.0.2
## [73] glue_1.6.2 evaluate_0.18 RcppParallel_5.1.5
## [76] modelr_0.1.10 vctrs_0.5.1 tzdb_0.3.0
## [79] cellranger_1.1.0 gtable_0.3.1 assertthat_0.2.1
## [82] cachem_1.0.6 xfun_0.35 broom_1.0.1
## [85] ragg_1.2.4 googledrive_2.0.0 gargle_1.2.1
## [88] timechange_0.1.1 ellipsis_0.3.2