UPDATE: 2021-09-18 04:55:48
ここでは、下記のドキュメントを参考に、prophet
パッケージの基本的な使い方をおさらいすることを目的としています。ゆくゆくは外部予測変数を追加したモデルやクロスバリデーション、パラメタチューニングなどなど、モデルを発展させながら使用方法をまとめていきます。
モデルの数理部分は下記のprophet
に関する論文やブログ記事を参照願います。非常にわかりやすいです。
library(prophet)
library(tidyverse)
<- function(data, n = 5){
head_tail stopifnot(is.data.frame(data))
head(data, n = n) %>% print()
cat(paste(rep("-", 100), collapse = ""), "\n")
tail(data, n = n) %>% print()
}
今回はProphetのモデル検証についてまとめていきます。Prophetには時系列クロスバリデーションを行うための関数が予め用意されているので、ここではそのcross_validation()
を使うことにします。
cross_validation()
は、下記の手順で行われます。
initial
引数でモデルの学習する期間を指定します。initial
の終了時点がカットオフポイントになります。horizon
引数で予測範囲を指定します。period
引数でカットオフポイントを置く間隔を指定します。例えば、下記のように指定した場合、730日分のデータで学習し、そこから365日分の予測を行います。そして、カットオフポイントを180日ごとにずらして、再度予測を行います。これを繰り返すことでクロスバリデーションを行います。
cross_validation(m, initial = 730, period = 180, horizon = 365, units = 'days')
実際にcross_validation()
を使ってみます。cross_validation()
の説明のために、わかりやすいサンプルデータを作って使うことにします。
<- tibble(
df ds = seq(as.Date("2021-01-01"), by = "day", length.out = 300),
y = runif(300)
)head_tail(df)
## # A tibble: 5 × 2
## ds y
## <date> <dbl>
## 1 2021-01-01 0.861
## 2 2021-01-02 0.395
## 3 2021-01-03 0.599
## 4 2021-01-04 0.969
## 5 2021-01-05 0.775
## ----------------------------------------------------------------------------------------------------
## # A tibble: 5 × 2
## ds y
## <date> <dbl>
## 1 2021-10-23 0.463
## 2 2021-10-24 0.681
## 3 2021-10-25 0.301
## 4 2021-10-26 0.347
## 5 2021-10-27 0.558
cross_validation()
を使うためには、まずはモデルを学習させる必要があるので、学習されたモデルを作ってきます。
<- prophet(df = df) m
このサンプルデータは2021-10-27
まであるということを念頭に、1日目(2021-01-01
)から200日目(2021-07-29
)までのデータで学習します。そして、201日目(2021-07-30
)から20日分(horizon
)の予測を行います。予測が終わったら、次に10日分カットオフをずらして(period
)して、211日目(2021-08-09
)から20日分(horizon
)の予測を行います。予測が終わったら、先程同様、10日分カットオフをずらして(period
)して、221日目(2021-08-19
)から20日分(horizon
)の予測を行います。
<- cross_validation(m, initial = 200, period = 10, horizon = 20, units = 'days')
df_cv head_tail(df_cv)
## # A tibble: 5 × 6
## y ds yhat yhat_lower yhat_upper cutoff
## <dbl> <dttm> <dbl> <dbl> <dbl> <dttm>
## 1 0.226 2021-07-30 00:00:00 0.422 0.0540 0.835 2021-07-29 00:00:00
## 2 0.177 2021-07-31 00:00:00 0.484 0.101 0.869 2021-07-29 00:00:00
## 3 0.693 2021-08-01 00:00:00 0.527 0.150 0.917 2021-07-29 00:00:00
## 4 0.717 2021-08-02 00:00:00 0.393 0.00344 0.758 2021-07-29 00:00:00
## 5 0.810 2021-08-03 00:00:00 0.518 0.140 0.888 2021-07-29 00:00:00
## ----------------------------------------------------------------------------------------------------
## # A tibble: 5 × 6
## y ds yhat yhat_lower yhat_upper cutoff
## <dbl> <dttm> <dbl> <dbl> <dbl> <dttm>
## 1 0.463 2021-10-23 00:00:00 0.454 0.0976 0.828 2021-10-07 00:00:00
## 2 0.681 2021-10-24 00:00:00 0.535 0.139 0.905 2021-10-07 00:00:00
## 3 0.301 2021-10-25 00:00:00 0.441 0.0636 0.801 2021-10-07 00:00:00
## 4 0.347 2021-10-26 00:00:00 0.487 0.113 0.884 2021-10-07 00:00:00
## 5 0.558 2021-10-27 00:00:00 0.493 0.0834 0.880 2021-10-07 00:00:00
言葉ではわかりにくいかもしれないのですが、下記の結果を見たほうが早いかもしれません。
%>%
df_cv ::group_by(cutoff) %>%
dplyr::summarise(
dplyrstart = min(ds),
end = max(ds),
cnt = n()
%>%
) ::mutate(
dplyrlag_cutoff = lag(cutoff),
diff = difftime(cutoff, lag_cutoff)
)
## # A tibble: 8 × 6
## cutoff start end cnt
## <dttm> <dttm> <dttm> <int>
## 1 2021-07-29 00:00:00 2021-07-30 00:00:00 2021-08-18 00:00:00 20
## 2 2021-08-08 00:00:00 2021-08-09 00:00:00 2021-08-28 00:00:00 20
## 3 2021-08-18 00:00:00 2021-08-19 00:00:00 2021-09-07 00:00:00 20
## 4 2021-08-28 00:00:00 2021-08-29 00:00:00 2021-09-17 00:00:00 20
## 5 2021-09-07 00:00:00 2021-09-08 00:00:00 2021-09-27 00:00:00 20
## 6 2021-09-17 00:00:00 2021-09-18 00:00:00 2021-10-07 00:00:00 20
## 7 2021-09-27 00:00:00 2021-09-28 00:00:00 2021-10-17 00:00:00 20
## 8 2021-10-07 00:00:00 2021-10-08 00:00:00 2021-10-27 00:00:00 20
## # … with 2 more variables: lag_cutoff <dttm>, diff <drtn>
performance_metrics()
を使用することで、クロスバリデーションの結果を一般的な精度指標と共に表示してくれます。ただの乱数を振っているサンプルデータでは、精度指標のイメージがつきにくいかもしれないので、データを変更しておきます。
# Githubのprophetのリポジトリからサンプルデータをインポート
<- 'https://raw.githubusercontent.com/facebook/prophet/master/examples/'
base_url <- paste0(base_url, 'example_wp_log_R.csv')
data_path <- readr::read_csv(data_path) %>% dplyr::arrange(ds) %>% dplyr::filter(ds >= "2011-01-01" & ds <= "2012-12-31")
df <- prophet(df)
m <- cross_validation(m, initial = 120, period = 30, horizon = 5, units = 'days')
df_cv <- performance_metrics(df_cv, rolling_window = 0.1)
df_perform df_perform
## horizon mse rmse mae mape mdape smape
## 1 1 days 0.01590879 0.1261300 0.08366585 0.01153609 0.008260400 0.01164784
## 2 2 days 0.02108555 0.1452087 0.09743137 0.01335100 0.009528242 0.01348219
## 3 3 days 0.02561772 0.1600554 0.11096890 0.01520249 0.010414497 0.01538287
## 4 4 days 0.02569931 0.1603100 0.11884814 0.01620516 0.012524142 0.01630367
## 5 5 days 0.02857261 0.1690344 0.12785619 0.01759062 0.015004419 0.01768091
## coverage
## 1 0.8571429
## 2 0.8571429
## 3 0.8095238
## 4 0.7619048
## 5 0.8095238
performance_metrics()
は各クロスバリデーションのスライスの開始日からの日数ごとにグルーピングし、精度指標を計算しています。
%>%
df_perform ::select(horizon, mae) %>%
dplyr::bind_cols(
dplyr%>%
df_cv ::group_by(cutoff) %>%
dplyr::mutate(date_idx = row_number(),
dplyrmae = abs(y - yhat)) %>%
::ungroup() %>%
dplyr::group_by(date_idx) %>%
dplyr::summarise(my_mae = mean(mae))
dplyr )
## horizon mae date_idx my_mae
## 1 1 days 0.08366585 1 0.08366585
## 2 2 days 0.09743137 2 0.09743137
## 3 3 days 0.11096890 3 0.11096890
## 4 4 days 0.11884814 4 0.11884814
## 5 5 days 0.12785619 5 0.12785619
# rolling_windowの値に応じて、平均か移動平均に変更される
# %>% mutate(ma2 = slider::slide_vec(.x = mean_mae, .f = mean, .before = 2))
場合によっては数値が一致しない場合がありますが、それはperformance_metrics()
のrolling_window
の値に応じて、performance_metrics()
が移動平均を使用するためです。mae()
の中身を見ると、条件分岐で平均か移動平均かを決定してるようです。
# https://github.com/facebook/prophet/blob/17dbb86ab023e451dc40da343def788c9cda745c/R/R/diagnostics.R#L486
#' @keywords internal
mae <- function(df, w) {
ae <- abs(df$y - df$yhat)
if (w < 0) {
return(data.frame(horizon = df$horizon, mae = ae))
}
return(rolling_mean_by_h(x = ae, h = df$horizon, w = w, name = 'mae'))
}
plot_cross_validation_metric()
を使うことで、精度指標を可視化できます。予測期間が短い場合、x軸がHorizon(days)
ではなくHorizon(hours)
となりますが、このグラフが表現していることは、最初の25時間時点のMAEは0.08であり、時間が経過するとともにMAEは悪化していき、120時間時点のMAEは0.12となることを意味しています。
plot_cross_validation_metric(df_cv, metric = "mae")
plot_cross_validation_metric()
は、もちろんmae
以外の指標も可視化できます。
modeltime
パッケージやTidymodels
パッケージの関数を利用すれば、より簡単にパラメタチューニングが行えるのかもしれませんが、現状、勉強不足で知らないので、自分で書いていくことにします。ですが、パラメタを用意して順番にモデルを学習させていくだけです。
ここではchangepoint_prior_scale
とseasonality_prior_scale
を調整することにします。Prophetのドキュメントにも記載されている通り、チューニングしても精度改善が見込めない指標もありますので、それらの指標は後で触れることにします。
<- seq(0.0, 0.5, 0.2)
changepoint_prior_scale 1]] <- changepoint_prior_scale[[1]] + 0.01
changepoint_prior_scale[[
<- seq(0.0, 10.0, 5)
seasonality_prior_scale 1]] <- seasonality_prior_scale[[1]] + 0.1
seasonality_prior_scale[[
<- expand_grid(
param_grid
changepoint_prior_scale,
seasonality_prior_scale
)<- vector(mode = "integer", length = nrow(param_grid))
mean_mae
for (i in 1:nrow(param_grid)) {
<- param_grid$changepoint_prior_scale[[i]]
c <- param_grid$seasonality_prior_scale[[i]]
s
cat('iteration', i, ':', 'changepoint_prior_scale:', c, ' / ', 'seasonality_prior_scale:', s, '\n')
set.seed(1989)
<- prophet(df, changepoint.prior.scale = c, seasonality.prior.scale = s)
m <- cross_validation(m, initial = 120, period = 30, horizon = 5, units = 'days')
df_cv <- performance_metrics(df_cv, metrics = "mae")
df_perform <- mean(df_perform$mae)
mean_mae[[i]] }
## iteration 1 : changepoint_prior_scale: 0.01 / seasonality_prior_scale: 0.1
## iteration 2 : changepoint_prior_scale: 0.01 / seasonality_prior_scale: 5
## iteration 3 : changepoint_prior_scale: 0.01 / seasonality_prior_scale: 10
## iteration 4 : changepoint_prior_scale: 0.2 / seasonality_prior_scale: 0.1
## iteration 5 : changepoint_prior_scale: 0.2 / seasonality_prior_scale: 5
## iteration 6 : changepoint_prior_scale: 0.2 / seasonality_prior_scale: 10
## iteration 7 : changepoint_prior_scale: 0.4 / seasonality_prior_scale: 0.1
## iteration 8 : changepoint_prior_scale: 0.4 / seasonality_prior_scale: 5
## iteration 9 : changepoint_prior_scale: 0.4 / seasonality_prior_scale: 10
<- cbind(param_grid, mean_mae)
df_tuning df_tuning
## changepoint_prior_scale seasonality_prior_scale mean_mae
## 1 0.01 0.1 0.1036654
## 2 0.01 5.0 0.1027140
## 3 0.01 10.0 0.1036441
## 4 0.20 0.1 0.1112593
## 5 0.20 5.0 0.1123704
## 6 0.20 10.0 0.1163200
## 7 0.40 0.1 0.1071377
## 8 0.40 5.0 0.1088326
## 9 0.40 10.0 0.1116318
1番MAEが小さい結果を取り出したいときは、which.min()
が便利です。
which.min(df_tuning$mean_mae), ] df_tuning[
## changepoint_prior_scale seasonality_prior_scale mean_mae
## 2 0.01 5 0.102714
Prophetには、調整を検討したほうが良いパラメタがあります。下記のドキュメントを参考に、調整したほうが良いパラメタに絞って、まとめておきます。ベイズ最適化でパラメタをチューニングしている記事もありましたので、記載しておきいます。
最も影響のあるパラメタです。トレンドの柔軟性、特にトレンドの変化点でトレンドがどの程度変化するかを決定するパラメタです。小さすぎると、トレンドが不十分になり、大きすぎると、トレンドが過剰適合します。最も極端な場合、トレンドが毎年の季節性を捉えてしまう可能性があります。デフォルトの0.05です。0.001から0.5の範囲が妥当なチューニング範囲とのことです。
このパラメタは、季節性の柔軟性を制御するパラメタ。値が大きいと、季節性が大きな変動にフィットし、値が小さいと季節性の大きさが小さくなります。デフォルトは10です。チューニングの妥当な範囲は0.01から10です。
休日効果の柔軟性を制御するパラメタ。Seasonality_prior_scale
と同様に、デフォルトは10です。seasonality_prior_scale
の場合と同様に、0.01から10の範囲で調整できます。
sessionInfo()
## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] ja_JP.UTF-8/ja_JP.UTF-8/ja_JP.UTF-8/C/ja_JP.UTF-8/ja_JP.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] forcats_0.5.1 stringr_1.4.0 dplyr_1.0.7 purrr_0.3.4
## [5] readr_1.4.0 tidyr_1.1.3 tibble_3.1.3 ggplot2_3.3.3
## [9] tidyverse_1.3.0 prophet_1.0 rlang_0.4.10 Rcpp_1.0.6
##
## loaded via a namespace (and not attached):
## [1] httr_1.4.2 jsonlite_1.7.2 modelr_0.1.8
## [4] RcppParallel_5.0.2 StanHeaders_2.21.0-7 assertthat_0.2.1
## [7] highr_0.8 stats4_4.0.3 cellranger_1.1.0
## [10] yaml_2.2.1 pillar_1.6.2 backports_1.2.1
## [13] glue_1.4.2 digest_0.6.27 rvest_0.3.6
## [16] colorspace_2.0-0 htmltools_0.5.1.1 pkgconfig_2.0.3
## [19] rstan_2.21.2 broom_0.7.9 haven_2.3.1
## [22] scales_1.1.1 processx_3.5.2 farver_2.0.3
## [25] generics_0.1.0 ellipsis_0.3.2 withr_2.4.1
## [28] cli_3.0.1 magrittr_2.0.1 crayon_1.4.0
## [31] readxl_1.3.1 evaluate_0.14 ps_1.5.0
## [34] fs_1.5.0 fansi_0.4.2 xml2_1.3.2
## [37] pkgbuild_1.2.0 textshaping_0.3.5 tools_4.0.3
## [40] loo_2.4.1 prettyunits_1.1.1 hms_1.0.0
## [43] lifecycle_1.0.0 matrixStats_0.58.0 extraDistr_1.9.1
## [46] V8_3.4.0 munsell_0.5.0 reprex_1.0.0
## [49] callr_3.7.0 compiler_4.0.3 systemfonts_1.0.2
## [52] grid_4.0.3 rstudioapi_0.13 labeling_0.4.2
## [55] rmarkdown_2.6 gtable_0.3.0 codetools_0.2-16
## [58] inline_0.3.17 DBI_1.1.1 curl_4.3
## [61] R6_2.5.0 gridExtra_2.3 lubridate_1.7.9.2
## [64] knitr_1.33 utf8_1.1.4 ragg_1.1.3
## [67] stringi_1.5.3 parallel_4.0.3 vctrs_0.3.8
## [70] dbplyr_2.1.0 tidyselect_1.1.0 xfun_0.24