UPDATE: 2021-09-18 04:55:48

1 はじめに

ここでは、下記のドキュメントを参考に、prophetパッケージの基本的な使い方をおさらいすることを目的としています。ゆくゆくは外部予測変数を追加したモデルやクロスバリデーション、パラメタチューニングなどなど、モデルを発展させながら使用方法をまとめていきます。

モデルの数理部分は下記のprophetに関する論文やブログ記事を参照願います。非常にわかりやすいです。

2 ライブラリと関数の読み込み

library(prophet)
library(tidyverse)

head_tail <- function(data, n = 5){
  stopifnot(is.data.frame(data))
  head(data, n = n) %>% print()
  cat(paste(rep("-", 100), collapse = ""), "\n")
  tail(data, n = n) %>% print()
}

3 モデル検証

今回はProphetのモデル検証についてまとめていきます。Prophetには時系列クロスバリデーションを行うための関数が予め用意されているので、ここではそのcross_validation()を使うことにします。

cross_validation()は、下記の手順で行われます。

例えば、下記のように指定した場合、730日分のデータで学習し、そこから365日分の予測を行います。そして、カットオフポイントを180日ごとにずらして、再度予測を行います。これを繰り返すことでクロスバリデーションを行います。

cross_validation(m, initial = 730, period = 180, horizon = 365, units = 'days')

実際にcross_validation()を使ってみます。cross_validation()の説明のために、わかりやすいサンプルデータを作って使うことにします。

df <- tibble(
  ds = seq(as.Date("2021-01-01"), by = "day", length.out = 300),
  y = runif(300)
)
head_tail(df)
## # A tibble: 5 × 2
##   ds             y
##   <date>     <dbl>
## 1 2021-01-01 0.861
## 2 2021-01-02 0.395
## 3 2021-01-03 0.599
## 4 2021-01-04 0.969
## 5 2021-01-05 0.775
## ---------------------------------------------------------------------------------------------------- 
## # A tibble: 5 × 2
##   ds             y
##   <date>     <dbl>
## 1 2021-10-23 0.463
## 2 2021-10-24 0.681
## 3 2021-10-25 0.301
## 4 2021-10-26 0.347
## 5 2021-10-27 0.558

cross_validation()を使うためには、まずはモデルを学習させる必要があるので、学習されたモデルを作ってきます。

m <- prophet(df = df)

このサンプルデータは2021-10-27まであるということを念頭に、1日目(2021-01-01)から200日目(2021-07-29)までのデータで学習します。そして、201日目(2021-07-30)から20日分(horizon)の予測を行います。予測が終わったら、次に10日分カットオフをずらして(period)して、211日目(2021-08-09)から20日分(horizon)の予測を行います。予測が終わったら、先程同様、10日分カットオフをずらして(period)して、221日目(2021-08-19)から20日分(horizon)の予測を行います。

df_cv <- cross_validation(m, initial = 200, period = 10, horizon = 20, units = 'days')
head_tail(df_cv)
## # A tibble: 5 × 6
##       y ds                   yhat yhat_lower yhat_upper cutoff             
##   <dbl> <dttm>              <dbl>      <dbl>      <dbl> <dttm>             
## 1 0.226 2021-07-30 00:00:00 0.422    0.0540       0.835 2021-07-29 00:00:00
## 2 0.177 2021-07-31 00:00:00 0.484    0.101        0.869 2021-07-29 00:00:00
## 3 0.693 2021-08-01 00:00:00 0.527    0.150        0.917 2021-07-29 00:00:00
## 4 0.717 2021-08-02 00:00:00 0.393    0.00344      0.758 2021-07-29 00:00:00
## 5 0.810 2021-08-03 00:00:00 0.518    0.140        0.888 2021-07-29 00:00:00
## ---------------------------------------------------------------------------------------------------- 
## # A tibble: 5 × 6
##       y ds                   yhat yhat_lower yhat_upper cutoff             
##   <dbl> <dttm>              <dbl>      <dbl>      <dbl> <dttm>             
## 1 0.463 2021-10-23 00:00:00 0.454     0.0976      0.828 2021-10-07 00:00:00
## 2 0.681 2021-10-24 00:00:00 0.535     0.139       0.905 2021-10-07 00:00:00
## 3 0.301 2021-10-25 00:00:00 0.441     0.0636      0.801 2021-10-07 00:00:00
## 4 0.347 2021-10-26 00:00:00 0.487     0.113       0.884 2021-10-07 00:00:00
## 5 0.558 2021-10-27 00:00:00 0.493     0.0834      0.880 2021-10-07 00:00:00

言葉ではわかりにくいかもしれないのですが、下記の結果を見たほうが早いかもしれません。

df_cv %>% 
  dplyr::group_by(cutoff) %>% 
  dplyr::summarise(
    start = min(ds),
    end = max(ds),
    cnt = n()
    ) %>% 
  dplyr::mutate(
    lag_cutoff = lag(cutoff),
    diff = difftime(cutoff, lag_cutoff)
    )
## # A tibble: 8 × 6
##   cutoff              start               end                   cnt
##   <dttm>              <dttm>              <dttm>              <int>
## 1 2021-07-29 00:00:00 2021-07-30 00:00:00 2021-08-18 00:00:00    20
## 2 2021-08-08 00:00:00 2021-08-09 00:00:00 2021-08-28 00:00:00    20
## 3 2021-08-18 00:00:00 2021-08-19 00:00:00 2021-09-07 00:00:00    20
## 4 2021-08-28 00:00:00 2021-08-29 00:00:00 2021-09-17 00:00:00    20
## 5 2021-09-07 00:00:00 2021-09-08 00:00:00 2021-09-27 00:00:00    20
## 6 2021-09-17 00:00:00 2021-09-18 00:00:00 2021-10-07 00:00:00    20
## 7 2021-09-27 00:00:00 2021-09-28 00:00:00 2021-10-17 00:00:00    20
## 8 2021-10-07 00:00:00 2021-10-08 00:00:00 2021-10-27 00:00:00    20
## # … with 2 more variables: lag_cutoff <dttm>, diff <drtn>

performance_metrics()を使用することで、クロスバリデーションの結果を一般的な精度指標と共に表示してくれます。ただの乱数を振っているサンプルデータでは、精度指標のイメージがつきにくいかもしれないので、データを変更しておきます。

# Githubのprophetのリポジトリからサンプルデータをインポート
base_url <- 'https://raw.githubusercontent.com/facebook/prophet/master/examples/'
data_path <- paste0(base_url, 'example_wp_log_R.csv')
df <- readr::read_csv(data_path) %>% dplyr::arrange(ds) %>% dplyr::filter(ds >= "2011-01-01" & ds <= "2012-12-31")
m <- prophet(df)
df_cv <- cross_validation(m, initial = 120, period = 30, horizon = 5, units = 'days')
df_perform <- performance_metrics(df_cv, rolling_window = 0.1)
df_perform 
##   horizon        mse      rmse        mae       mape       mdape      smape
## 1  1 days 0.01590879 0.1261300 0.08366585 0.01153609 0.008260400 0.01164784
## 2  2 days 0.02108555 0.1452087 0.09743137 0.01335100 0.009528242 0.01348219
## 3  3 days 0.02561772 0.1600554 0.11096890 0.01520249 0.010414497 0.01538287
## 4  4 days 0.02569931 0.1603100 0.11884814 0.01620516 0.012524142 0.01630367
## 5  5 days 0.02857261 0.1690344 0.12785619 0.01759062 0.015004419 0.01768091
##    coverage
## 1 0.8571429
## 2 0.8571429
## 3 0.8095238
## 4 0.7619048
## 5 0.8095238

performance_metrics()は各クロスバリデーションのスライスの開始日からの日数ごとにグルーピングし、精度指標を計算しています。

df_perform %>%
  dplyr::select(horizon, mae) %>% 
  dplyr::bind_cols(
  df_cv %>%
    dplyr::group_by(cutoff) %>%
    dplyr::mutate(date_idx = row_number(),
                  mae = abs(y - yhat)) %>%
    dplyr::ungroup() %>%
    dplyr::group_by(date_idx) %>%
    dplyr::summarise(my_mae = mean(mae))
  )
##   horizon        mae date_idx     my_mae
## 1  1 days 0.08366585        1 0.08366585
## 2  2 days 0.09743137        2 0.09743137
## 3  3 days 0.11096890        3 0.11096890
## 4  4 days 0.11884814        4 0.11884814
## 5  5 days 0.12785619        5 0.12785619
  # rolling_windowの値に応じて、平均か移動平均に変更される
  # %>% mutate(ma2 = slider::slide_vec(.x = mean_mae, .f = mean, .before = 2))

場合によっては数値が一致しない場合がありますが、それはperformance_metrics()rolling_windowの値に応じて、performance_metrics()が移動平均を使用するためです。mae()の中身を見ると、条件分岐で平均か移動平均かを決定してるようです。

# https://github.com/facebook/prophet/blob/17dbb86ab023e451dc40da343def788c9cda745c/R/R/diagnostics.R#L486
#' @keywords internal
mae <- function(df, w) {
  ae <- abs(df$y - df$yhat)
  if (w < 0) {
    return(data.frame(horizon = df$horizon, mae = ae))
  }
  return(rolling_mean_by_h(x = ae, h = df$horizon, w = w, name = 'mae'))
}

plot_cross_validation_metric()を使うことで、精度指標を可視化できます。予測期間が短い場合、x軸がHorizon(days)ではなくHorizon(hours)となりますが、このグラフが表現していることは、最初の25時間時点のMAEは0.08であり、時間が経過するとともにMAEは悪化していき、120時間時点のMAEは0.12となることを意味しています。

plot_cross_validation_metric(df_cv, metric = "mae")

plot_cross_validation_metric()は、もちろんmae以外の指標も可視化できます。

4 パラメタチューニング

modeltimeパッケージやTidymodelsパッケージの関数を利用すれば、より簡単にパラメタチューニングが行えるのかもしれませんが、現状、勉強不足で知らないので、自分で書いていくことにします。ですが、パラメタを用意して順番にモデルを学習させていくだけです。

ここではchangepoint_prior_scaleseasonality_prior_scaleを調整することにします。Prophetのドキュメントにも記載されている通り、チューニングしても精度改善が見込めない指標もありますので、それらの指標は後で触れることにします。

changepoint_prior_scale <- seq(0.0, 0.5, 0.2)
changepoint_prior_scale[[1]] <- changepoint_prior_scale[[1]] + 0.01

seasonality_prior_scale <- seq(0.0, 10.0, 5)
seasonality_prior_scale[[1]] <- seasonality_prior_scale[[1]] + 0.1

param_grid <- expand_grid(
  changepoint_prior_scale, 
  seasonality_prior_scale
)
mean_mae <- vector(mode = "integer", length = nrow(param_grid))

for (i in 1:nrow(param_grid)) {
  c <- param_grid$changepoint_prior_scale[[i]]
  s <- param_grid$seasonality_prior_scale[[i]]
  
  cat('iteration', i, ':', 'changepoint_prior_scale:', c, ' / ', 'seasonality_prior_scale:', s, '\n')
  
  set.seed(1989)
  m <- prophet(df, changepoint.prior.scale = c, seasonality.prior.scale = s)
  df_cv <- cross_validation(m, initial = 120, period = 30, horizon = 5, units = 'days')
  df_perform <- performance_metrics(df_cv, metrics = "mae")
  mean_mae[[i]] <- mean(df_perform$mae)
}
## iteration 1 : changepoint_prior_scale: 0.01  /  seasonality_prior_scale: 0.1 
## iteration 2 : changepoint_prior_scale: 0.01  /  seasonality_prior_scale: 5 
## iteration 3 : changepoint_prior_scale: 0.01  /  seasonality_prior_scale: 10 
## iteration 4 : changepoint_prior_scale: 0.2  /  seasonality_prior_scale: 0.1 
## iteration 5 : changepoint_prior_scale: 0.2  /  seasonality_prior_scale: 5 
## iteration 6 : changepoint_prior_scale: 0.2  /  seasonality_prior_scale: 10 
## iteration 7 : changepoint_prior_scale: 0.4  /  seasonality_prior_scale: 0.1 
## iteration 8 : changepoint_prior_scale: 0.4  /  seasonality_prior_scale: 5 
## iteration 9 : changepoint_prior_scale: 0.4  /  seasonality_prior_scale: 10
df_tuning <- cbind(param_grid, mean_mae)
df_tuning
##   changepoint_prior_scale seasonality_prior_scale  mean_mae
## 1                    0.01                     0.1 0.1036654
## 2                    0.01                     5.0 0.1027140
## 3                    0.01                    10.0 0.1036441
## 4                    0.20                     0.1 0.1112593
## 5                    0.20                     5.0 0.1123704
## 6                    0.20                    10.0 0.1163200
## 7                    0.40                     0.1 0.1071377
## 8                    0.40                     5.0 0.1088326
## 9                    0.40                    10.0 0.1116318

1番MAEが小さい結果を取り出したいときは、which.min()が便利です。

df_tuning[which.min(df_tuning$mean_mae), ]
##   changepoint_prior_scale seasonality_prior_scale mean_mae
## 2                    0.01                       5 0.102714

Prophetには、調整を検討したほうが良いパラメタがあります。下記のドキュメントを参考に、調整したほうが良いパラメタに絞って、まとめておきます。ベイズ最適化でパラメタをチューニングしている記事もありましたので、記載しておきいます。

4.1 changepoint_prior_scale

最も影響のあるパラメタです。トレンドの柔軟性、特にトレンドの変化点でトレンドがどの程度変化するかを決定するパラメタです。小さすぎると、トレンドが不十分になり、大きすぎると、トレンドが過剰適合します。最も極端な場合、トレンドが毎年の季節性を捉えてしまう可能性があります。デフォルトの0.05です。0.001から0.5の範囲が妥当なチューニング範囲とのことです。

4.2 seasonality_prior_scale

このパラメタは、季節性の柔軟性を制御するパラメタ。値が大きいと、季節性が大きな変動にフィットし、値が小さいと季節性の大きさが小さくなります。デフォルトは10です。チューニングの妥当な範囲は0.01から10です。

4.3 holidays_prior_scale

休日効果の柔軟性を制御するパラメタ。Seasonality_prior_scaleと同様に、デフォルトは10です。seasonality_prior_scaleの場合と同様に、0.01から10の範囲で調整できます。

5 セッション情報

sessionInfo()
## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] ja_JP.UTF-8/ja_JP.UTF-8/ja_JP.UTF-8/C/ja_JP.UTF-8/ja_JP.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] forcats_0.5.1   stringr_1.4.0   dplyr_1.0.7     purrr_0.3.4    
##  [5] readr_1.4.0     tidyr_1.1.3     tibble_3.1.3    ggplot2_3.3.3  
##  [9] tidyverse_1.3.0 prophet_1.0     rlang_0.4.10    Rcpp_1.0.6     
## 
## loaded via a namespace (and not attached):
##  [1] httr_1.4.2           jsonlite_1.7.2       modelr_0.1.8        
##  [4] RcppParallel_5.0.2   StanHeaders_2.21.0-7 assertthat_0.2.1    
##  [7] highr_0.8            stats4_4.0.3         cellranger_1.1.0    
## [10] yaml_2.2.1           pillar_1.6.2         backports_1.2.1     
## [13] glue_1.4.2           digest_0.6.27        rvest_0.3.6         
## [16] colorspace_2.0-0     htmltools_0.5.1.1    pkgconfig_2.0.3     
## [19] rstan_2.21.2         broom_0.7.9          haven_2.3.1         
## [22] scales_1.1.1         processx_3.5.2       farver_2.0.3        
## [25] generics_0.1.0       ellipsis_0.3.2       withr_2.4.1         
## [28] cli_3.0.1            magrittr_2.0.1       crayon_1.4.0        
## [31] readxl_1.3.1         evaluate_0.14        ps_1.5.0            
## [34] fs_1.5.0             fansi_0.4.2          xml2_1.3.2          
## [37] pkgbuild_1.2.0       textshaping_0.3.5    tools_4.0.3         
## [40] loo_2.4.1            prettyunits_1.1.1    hms_1.0.0           
## [43] lifecycle_1.0.0      matrixStats_0.58.0   extraDistr_1.9.1    
## [46] V8_3.4.0             munsell_0.5.0        reprex_1.0.0        
## [49] callr_3.7.0          compiler_4.0.3       systemfonts_1.0.2   
## [52] grid_4.0.3           rstudioapi_0.13      labeling_0.4.2      
## [55] rmarkdown_2.6        gtable_0.3.0         codetools_0.2-16    
## [58] inline_0.3.17        DBI_1.1.1            curl_4.3            
## [61] R6_2.5.0             gridExtra_2.3        lubridate_1.7.9.2   
## [64] knitr_1.33           utf8_1.1.4           ragg_1.1.3          
## [67] stringi_1.5.3        parallel_4.0.3       vctrs_0.3.8         
## [70] dbplyr_2.1.0         tidyselect_1.1.0     xfun_0.24