UPDATE: 2021-09-15 00:10:00

1 はじめに

ここでは、下記のドキュメントを参考に、prophetパッケージの基本的な使い方をおさらいすることを目的としています。ゆくゆくは外部予測変数を追加したモデルやクロスバリデーション、パラメタチューニングなどなど、モデルを発展させながら使用方法をまとめていきます。

モデルの数理部分は下記のprophetに関する論文やブログ記事を参照願います。非常にわかりやすいです。

2 ライブラリと関数の読み込み

library(prophet)
library(tidyverse)

head_tail <- function(data, n = 5){
  stopifnot(is.data.frame(data))
  head(data, n = n) %>% print()
  cat(paste(rep("-", 100), collapse = ""), "\n")
  tail(data, n = n) %>% print()
}

cumpaste <- function(x, sep = "+") {
  Reduce(f = function(x1, x2){paste(x1, x2, sep = sep)},
         x, 
         accumulate = TRUE)
}

3 Prophetの周期変動

今回は周期変動の扱いについておさらいしていきます。周期変動とは、繰り返される変動のことで、例えば、週5日の勤務は「毎週」繰り返され時系列に影響を与えたり、季節による商品の売れ行きは「毎年」繰り返される時系列として影響を及ぼします。 Prophetでは、これらの周期的に繰り返される影響を捉え、未来を予測するために、下記のフーリエ級数を使用して、周期的な影響をモデルに組みます。例えば、\(2 \pi\)単位で1周するので、1年周期の場合は\(P=365.25\)、1週間周期の場合は\(P=7\)、24時間周期の場合は\(P=24\)として、周期効果を表現します。

\[ s(t) = \sum_{n=1}^{N} \left(a_{n} \cos\left(\frac{2 \pi n t}{P} \right) + b_{n} \sin \left( \frac{2 \pi n t}{P}\right) \right) \]

これらの係数を最適化するため、合計\(2N\)のパラメタをベクトルとして表現し、

\[ \boldsymbol{ \beta } =\left[ \begin{array}{c} a_1 \\ b_1 \\ a_2 \\ b_2 \\ \vdots \\ a_n \\ b_n \end{array} \right] = [ a_1, b_1, a_2, b_2, \ldots, a_n, b_n ]^{ \mathrm{ T } } \]

三角関数の部分も同じくベクトルとして抜き出します。

\[ X(t) = \left[ \cos \left(\frac{2 \pi (1) t}{365.25} \right), \ldots, \sin \left(\frac{2 \pi (n) t}{365.25} \right) \right] \]

これらを組み合わせ、周期要因を下記のとおりに表現してモデルに組み込みます。フーリエ級数の各係数には\(\beta \sim Normal(0, \sigma^2)\)として事前分布を設定しているそうです。

\[ s(t) = X(t) \boldsymbol{ \beta } \]

4 フーリエ級数のおさらい

Prophetで周期変動を組み込んだモデルを構築する前に、フーリエ級数で周期変動を捉える方法のおさらいをしておきます。フーリエ級数とは下記の三角関数の式のことで、

\[ y(t) = \frac{a_{0}}{2}\sum \left(a_{n} \cos\left(n t \right) + b_{n} \sin \left( n t \right) \right) \]

この\(a_{0},a_{1},b_{1}, \ldots, a_{n},b_{n}\)のフーリエ係数を決めることにより、様々なグラフを表現することができきます。例えば、区間\([a, b]\)上で、データが\(2m+1\)個ある場合は、下記の式で近似できます。

\[ y(t) = \frac{a_{0}}{2}\sum_{n=1}^{N} \left(a_{n} \cos\left(\frac{2 \pi n t}{b - a} \right) + b_{n} \sin \left( \frac{2 \pi n t}{b - a}\right) \right) \]

例えば、下記のように1周期が24時間のようなデータであれば、区間[0, 24]で25個のデータがあれば、\(2 m + 1 = 25\)より、\(N = 12\)で近似できます。

max <- 24
set.seed(1989)
df <- tibble(
  x = seq(0, max, 1),
  y = sin((2*pi*x)/24) + rnorm(n=x,0,0.2),
) 

df %>% 
  ggplot() + 
  geom_point(aes(x, y)) + 
  geom_vline(xintercept = c(0,6,12,18,24), col = "gray", linetype = "dashed") + 
  scale_x_continuous(breaks = seq(0,max,1)) + 
  theme_bw()

フーリエ係数を計算してみます。

n <- 1/25
fit <- lm(
  y ~
    cos(2 * pi * n * x)  + sin(2 * pi * n * x) +
    cos(4 * pi * n * x)  + sin(4 * pi * n * x) +
    cos(6 * pi * n * x)  + sin(6 * pi * n * x) +
    cos(8 * pi * n * x)  + sin(8 * pi * n * x) +
    cos(10 * pi * n * x) + sin(10 * pi * n * x) +
    cos(12 * pi * n * x) + sin(12 * pi * n * x),
  data = df
)
summary(fit)
## 
## Call:
## lm(formula = y ~ cos(2 * pi * n * x) + sin(2 * pi * n * x) + 
##     cos(4 * pi * n * x) + sin(4 * pi * n * x) + cos(6 * pi * 
##     n * x) + sin(6 * pi * n * x) + cos(8 * pi * n * x) + sin(8 * 
##     pi * n * x) + cos(10 * pi * n * x) + sin(10 * pi * n * x) + 
##     cos(12 * pi * n * x) + sin(12 * pi * n * x), data = df)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -0.32007 -0.17309 -0.01641  0.18276  0.37208 
## 
## Coefficients:
##                      Estimate Std. Error t value Pr(>|t|)    
## (Intercept)          -0.01172    0.05728  -0.205    0.841    
## cos(2 * pi * n * x)   0.07412    0.08100   0.915    0.378    
## sin(2 * pi * n * x)   1.01371    0.08100  12.515 3.02e-08 ***
## cos(4 * pi * n * x)   0.02416    0.08100   0.298    0.771    
## sin(4 * pi * n * x)  -0.10522    0.08100  -1.299    0.218    
## cos(6 * pi * n * x)   0.01361    0.08100   0.168    0.869    
## sin(6 * pi * n * x)   0.01602    0.08100   0.198    0.847    
## cos(8 * pi * n * x)   0.07613    0.08100   0.940    0.366    
## sin(8 * pi * n * x)  -0.02091    0.08100  -0.258    0.801    
## cos(10 * pi * n * x)  0.07352    0.08100   0.908    0.382    
## sin(10 * pi * n * x)  0.07039    0.08100   0.869    0.402    
## cos(12 * pi * n * x)  0.08290    0.08100   1.023    0.326    
## sin(12 * pi * n * x) -0.03173    0.08100  -0.392    0.702    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.2864 on 12 degrees of freedom
## Multiple R-squared:  0.9314, Adjusted R-squared:  0.8629 
## F-statistic: 13.59 on 12 and 12 DF,  p-value: 3.547e-05

さきほど計算したフーリエ係数を使って予測値を計算して、プロットしてみると、周期的な変動を上手く捉えられていそうです。

# predictの中身はこんな感じ
# pred <- function(x) {
#   res <-
#     fit$coefficients[[1]] +
#     fit$coefficients[[2]] * cos(2 * pi * n * x)   + fit$coefficients[[3]] * sin(2 * pi * n * x) +
#     fit$coefficients[[4]] * cos(4 * pi * n * x)   + fit$coefficients[[5]] * sin(4 * pi * n * x) +
#     fit$coefficients[[6]] * cos(6 * pi * n * x)   + fit$coefficients[[7]] * sin(6 * pi * n * x) +
#     fit$coefficients[[8]] * cos(8 * pi * n * x)   + fit$coefficients[[9]] * sin(8 * pi * n * x) +
#     fit$coefficients[[10]] * cos(10 * pi * n * x) + fit$coefficients[[11]] * sin(10 * pi * n * x) +
#     fit$coefficients[[12]] * cos(12 * pi * n * x) + fit$coefficients[[13]] * sin(12 * pi * n * x)
#   return(res)
# }

df %>% 
  dplyr::mutate(pred = predict(fit, newdata = .)) %>% 
  ggplot() + 
  geom_point(aes(x, y)) + 
  geom_line(aes(x, pred), col = "#0072B2") +  
  geom_vline(xintercept = c(0,6,12,18,24), col = "gray", linetype = "dashed") + 
  scale_x_continuous(breaks = seq(0, max, 1)) +
  theme_bw()

フーリエ係数を変化させてみるとわかりますが、大きくなればなるほど過学習していきます。このこともあって、Prophetの論文では、年次および週次の季節性については、年次は\(N=10\)、週次は\(N=3\)で機能すると書かれています。

n <- 12

# cos(*) + sin(*)のベースを作成
res <- vector(mode = "character", length = n)
for (i in 1:n){
  res[[i]] <- paste0(
    "cos(", i*2, "*pi*(1/25)*x)+sin(", i*2, "*pi*(1/25)*x)"
    )
}

# 累積の要領で計算式を連結そて回帰式を作成
# イメージ: f1, f1+f2, f1+f2+f3,... 
fs <- paste("y ~", cumpaste(res, sep = "+"))
df_plt <- df 
for (i in 1:n){
  fit <- lm(as.formula(fs[[i]]), data = df_plt)
  pred <- predict(fit, newdata = df_plt)
  df_plt <- cbind(df_plt, pred)
}
names(df_plt) <- c("x", "y", paste0("f=",2*1:n))

df_plt %>% 
  tidyr::pivot_longer(cols = starts_with("f"), names_to = 'patterns') %>% 
  dplyr::mutate(patterns = fct_relevel(patterns, c("f=2","f=4","f=6","f=8","f=10","f=12"))) %>% 
  ggplot(.) + 
  geom_point(aes(x, y)) +  
  geom_line(aes(x, value), col = "#0072B2") +  
  geom_vline(xintercept = c(0,6,12,18,24), col = "gray", linetype = "dashed") + 
  scale_x_continuous(breaks = seq(0, max, 3)) +
  facet_wrap( ~ patterns) +
  ggtitle("Fourier series for Periodic effects") + 
  theme_bw()

5 Prophetのadd_seasonality関数

Prophetのadd_seasonality関数の説明をするために、ドキュメントで使用されているWikipediaのページビューデータの期間を2年分に絞って動かしていきます。

# Githubのprophetのリポジトリからサンプルデータをインポート
base_url <- 'https://raw.githubusercontent.com/facebook/prophet/master/examples/'
data_path <- paste0(base_url, 'example_wp_log_peyton_manning.csv')
df <- readr::read_csv(data_path) %>%
  dplyr::filter(ds >= "2011-01-01" & ds <= "2012-12-31")

# 先頭と末尾のデータを表示
head_tail(df, n = 5)
## # A tibble: 5 × 2
##   ds             y
##   <date>     <dbl>
## 1 2011-01-01  9.01
## 2 2011-01-02  9.40
## 3 2011-01-03  9.99
## 4 2011-01-04  9.06
## 5 2011-01-05  8.97
## ---------------------------------------------------------------------------------------------------- 
## # A tibble: 5 × 2
##   ds             y
##   <date>     <dbl>
## 1 2012-12-27  8.97
## 2 2012-12-28  8.68
## 3 2012-12-29  8.53
## 4 2012-12-30  9.49
## 5 2012-12-31 10.1

prophet()のデフォルト設定では、yearly.seasonality="auto"weekly.seasonality="auto"daily.seasonality="auto"となっており、周期性を自動でモデルに組み込んでくれます。周期を自分で設定したい場合は、yearly.seasonality=365のように引数にフーリエ級数を設定するか、add_seasonality()を利用できます。四半期(period=91.3125)や日次(period=24)なども追加可能です。

add_seasonality()を使用するためには、prophet()にデータ渡さず、インスタンス化だけ行います。そのあとにadd_seasonality()を使用して、1周期の日数であるperiodと、フーリエ級数であるfourier.order、あとはラベルであるnameを渡します。name="monthly"と設定したからといって、月次の周期の設定ができるわけではなく、これはただのラベルです。

add_seasonality()の設定方法を見ればわかりますが、ビジネスの要件で発生する特別な周期性であっても、簡単にモデルに追加し、周期をモデルに組み込むことが可能です。

# カスタムで周期を追加する場合、prophet()でデータを持たないインスタンスを作成
# dfのデフォルト引数はNULLなので、記述する必要はない
m <- prophet(df = NULL,
             yearly.seasonality = FALSE,
             weekly.seasonality = FALSE,
             daily.seasonality = FALSE)

# 周期はperiodで設定するため、nameはただのラベル
m <- add_seasonality(m = m, name = "yearly", period = 365.25, fourier.order = 10)
m <- add_seasonality(m = m, name = "monthly", period = 30.5, fourier.order = 5)
m <- add_seasonality(m = m, name = "weekly", period = 7, fourier.order = 3)
# ここでモデルの計算を行う
m <- fit.prophet(m = m, df = df)

設定した周期成分は、prophet_plot_components()で簡単に可視化することも可能性です。可視化された周期を見ると、「週次では月曜日に」「年次では1月、9月に」「月次では10日前後」に数字が高くなることがわかります。

future_df <- prophet::make_future_dataframe(m = m, periods = 30)
forecast_df <- predict(object = m, fcst = future_df)
prophet_plot_components(m = m, fcst = forecast_df)

## 増加していく周期性(Multiplicative Seasonality)

Prophetは、周期ごとに大きくなっていくような乗法的な周期性もモデルに組み込むことができます。下記のような乗法的な季節周期をもつ時系列データに対して、デフォルトの設定でモデリングしても、うまく予測ができません。 乗法的な周期性のイメージは下記のとおりです。トレンドに何らかの係数をかけていくことで表現できます。

n <- 12 * 10
w <- 1:n
x <- seq(1, 100, length.out = n)
tibble::tibble(
  ds = seq(from = as.Date("2010-01-01"), by = "1 month", length.out = n),
  y = w * sin(x) + 2 * x + rnorm(n = n, 0, 10)) %>%
  ggplot(aes(ds, y)) +
  geom_line(size = 1, col = "#749FC6") +
  scale_y_continuous(labels = scales::comma) +
  scale_x_date(date_breaks = '12 month', date_labels = "%Y-%m") +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 30, hjust = 1)) +
  ggtitle("Multiplicative Seasonality")

Prophetでは実際、下記のようなモデルを組んで乗法的な周期性を表現してます。

\[ y_{t} = g(t) * (1 + s(t) + h(t) + \beta x(t) )+ \epsilon_{t} \]

ここでは、サンプルデータとして航空旅客数のこの時系列データAirPassengersを利用して確認していきます。

n <- length(AirPassengers)
df <- tibble::tibble(y = AirPassengers)
df$ds <- seq(from = as.Date("1949-01-01"), by = "1 month", length.out = n)

df %>% 
  ggplot(aes(ds, y)) +
  geom_line(size = 1, col = "#749FC6") + 
  scale_y_continuous(labels = scales::comma) + 
  scale_x_date(date_breaks = '12 month', date_labels = "%Y-%m") + 
  theme_bw() + 
  theme(axis.text.x = element_text(angle = 30, hjust = 1)) +
  ggtitle("Multiplicative Seasonality")

まずはデフォルト設定でモデルを学習させてみると上にスパイクする箇所において、少し予測が小さくなっているようです。

m1 <- prophet(df)
future_df1 <- make_future_dataframe(m1, 12, freq = "month")
fore_df1 <- predict(m1, future_df1)
plot(m1, fore_df1) + ggtitle('seasonality.mode = "additive"')

あとで利用するので、prophet_plot_components()で各要素の影響を可視化しておきます。

prophet_plot_components(m1, fore_df1)

次は、乗法的な周期性を設定してモデルを学習させてみます。Prophetでは、特定の要素に対して、乗法的にモデルに組み込んだり、部分的に加法的にするなども柔軟に変更できます。

先程のモデルとの違いは、上にスパイクする箇所も、予測が上手く行っています。このように、トレンドとともに大きくなるような周期性を持つデータには、乗法的周期性を考慮したモデルで対応することでうまくいきます。

m2 <- prophet(df, seasonality.mode = "multiplicative")
future_df2 <- make_future_dataframe(m2, 12, freq = "month")
fore_df2 <- predict(m2, future_df2)
plot(m2, fore_df2) + ggtitle('seasonality.mode = "multiplicative"')

先程と同じく、prophet_plot_components()で各要素の影響を可視化しておきます。デフォルト設定のモデルと比べると、年周期の部分はパーセント表記になっています。各要素にトレンドが掛け合わされているため、パーセント表記にしているようです。

prophet_plot_components(m2, fore_df2)

解釈としては、6月、7月、8月あたりは、トレンドに対して平均して25%くらい大きくなり、3月くらいは平均してトレンドの−50%くらい少なくなるということだと思われます。計算のイメージは下記のとおりです。

fore_df2 %>% 
  dplyr::mutate(yearly_val = trend * yearly,
         my_yhat = trend + yearly_val,
         check = near(yhat, my_yhat)) %>% 
  dplyr::select(ds, trend, yearly, yearly_val, yhat, my_yhat, check) %>% 
  head_tail()
##           ds    trend      yearly yearly_val      yhat   my_yhat check
## 1 1949-01-01 113.5373 -0.08966038 -10.179796 103.35748 103.35748  TRUE
## 2 1949-02-01 115.2933 -0.13337690 -15.377462  99.91583  99.91583  TRUE
## 3 1949-03-01 116.8794 -0.02123801  -2.482285 114.39708 114.39708  TRUE
## 4 1949-04-01 118.6354 -0.02261999  -2.683531 115.95185 115.95185  TRUE
## 5 1949-05-01 120.3348 -0.01188864  -1.430616 118.90414 118.90414  TRUE
## ---------------------------------------------------------------------------------------------------- 
##             ds    trend      yearly yearly_val     yhat  my_yhat check
## 152 1961-08-01 520.4201  0.24888761  129.52612 649.9462 649.9462  TRUE
## 153 1961-09-01 524.0041  0.05567740   29.17518 553.1792 553.1792  TRUE
## 154 1961-10-01 527.4724 -0.07277258  -38.38553 489.0868 489.0868  TRUE
## 155 1961-11-01 531.0563 -0.20051412 -106.48429 424.5720 424.5720  TRUE
## 156 1961-12-01 534.5246 -0.11397004  -60.91979 473.6048 473.6048  TRUE

6 セッション情報

sessionInfo()
## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] ja_JP.UTF-8/ja_JP.UTF-8/ja_JP.UTF-8/C/ja_JP.UTF-8/ja_JP.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] forcats_0.5.1   stringr_1.4.0   dplyr_1.0.7     purrr_0.3.4    
##  [5] readr_1.4.0     tidyr_1.1.3     tibble_3.1.3    ggplot2_3.3.3  
##  [9] tidyverse_1.3.0 prophet_1.0     rlang_0.4.10    Rcpp_1.0.6     
## 
## loaded via a namespace (and not attached):
##  [1] httr_1.4.2           jsonlite_1.7.2       modelr_0.1.8        
##  [4] RcppParallel_5.0.2   StanHeaders_2.21.0-7 assertthat_0.2.1    
##  [7] highr_0.8            stats4_4.0.3         cellranger_1.1.0    
## [10] yaml_2.2.1           pillar_1.6.2         backports_1.2.1     
## [13] glue_1.4.2           digest_0.6.27        rvest_0.3.6         
## [16] colorspace_2.0-0     htmltools_0.5.1.1    pkgconfig_2.0.3     
## [19] rstan_2.21.2         broom_0.7.9          haven_2.3.1         
## [22] scales_1.1.1         processx_3.5.2       generics_0.1.0      
## [25] farver_2.0.3         ellipsis_0.3.2       withr_2.4.1         
## [28] cli_3.0.1            magrittr_2.0.1       crayon_1.4.0        
## [31] readxl_1.3.1         evaluate_0.14        ps_1.5.0            
## [34] fs_1.5.0             fansi_0.4.2          xml2_1.3.2          
## [37] pkgbuild_1.2.0       textshaping_0.3.5    loo_2.4.1           
## [40] tools_4.0.3          prettyunits_1.1.1    hms_1.0.0           
## [43] matrixStats_0.58.0   lifecycle_1.0.0      extraDistr_1.9.1    
## [46] V8_3.4.0             munsell_0.5.0        reprex_1.0.0        
## [49] callr_3.7.0          compiler_4.0.3       systemfonts_1.0.2   
## [52] grid_4.0.3           rstudioapi_0.13      labeling_0.4.2      
## [55] rmarkdown_2.6        gtable_0.3.0         codetools_0.2-16    
## [58] inline_0.3.17        DBI_1.1.1            curl_4.3            
## [61] R6_2.5.0             gridExtra_2.3        lubridate_1.7.9.2   
## [64] knitr_1.33           utf8_1.1.4           ragg_1.1.3          
## [67] stringi_1.5.3        parallel_4.0.3       vctrs_0.3.8         
## [70] dbplyr_2.1.0         tidyselect_1.1.0     xfun_0.24