UPDATE: 2021-09-15 00:07:40

1 はじめに

ここでは、下記のドキュメントを参考に、prophetパッケージの基本的な使い方をおさらいすることを目的としています。ゆくゆくは外部予測変数を追加したモデルやクロスバリデーション、パラメタチューニングなどなど、モデルを発展させながら使用方法をまとめていきます。

モデルの数理部分は下記のprophetに関する論文やブログ記事を参照願います。非常にわかりやすいです。

2 ライブラリと関数の読み込み

library(prophet)
library(tidyverse)

head_tail <- function(data, n = 5){
  stopifnot(is.data.frame(data))
  head(data, n = n) %>% print()
  cat(paste(rep("-", 100), collapse = ""), "\n")
  tail(data, n = n) %>% print()
}

3 prophetのトレンド

prophetのモデルはトレンドを1つの構成要素としており、線形に成長する場合はgrowth="linear"を指定します。デフォルト設定では、growth="linear"です。

一方で、特定の最大値、最小値が決まっていたり、予想できる場合、growth="logistic"capを指定することで、モデルの予測値の上限の収容力、下限の収容力をコントロールすることができます。

WikipediaのR言語のページのページビュー(対数スケール)が保存されているサンプルデータを利用します。ここでは、8を超えないようにコントロールします(絶対にコントロールできるわけではないです)。基本的にはドメインを知識を利用して値を設定します。また、capはデータフレームの列に加える必要があり、予測する場合のデータフレームにも必要です。そして、固定値である必要もなく、状況に応じて必要であれば、値を可変させることも可能です。

# Githubのprophetのリポジトリからサンプルデータをインポート
base_url <- 'https://raw.githubusercontent.com/facebook/prophet/master/examples/'
data_path <- paste0(base_url, 'example_wp_log_R.csv')
df <- readr::read_csv(data_path) %>%
  dplyr::arrange(ds)

# capの固定値を設定
df_cap <- df %>% 
  dplyr::mutate(cap = 8)

# 先頭と末尾のデータを表示
head_tail(df_cap, n = 5)
## # A tibble: 5 × 3
##   ds             y   cap
##   <date>     <dbl> <dbl>
## 1 2008-01-01  4.80     8
## 2 2008-01-02  5.38     8
## 3 2008-01-03  5.66     8
## 4 2008-01-04  5.60     8
## 5 2008-01-05  5.28     8
## ---------------------------------------------------------------------------------------------------- 
## # A tibble: 5 × 3
##   ds             y   cap
##   <date>     <dbl> <dbl>
## 1 2015-12-27  7.04     8
## 2 2015-12-28  7.56     8
## 3 2015-12-29  7.60     8
## 4 2015-12-30  7.61     8
## 5 2015-12-31  7.24     8

4 モデルの学習と予測

capを格納したデータフレームを利用し、growth=logisticを指定して予測モデルを構築します。

m_logistic <- prophet(df = df_cap, growth = "logistic")
names(m_logistic)
##  [1] "growth"                  "changepoints"           
##  [3] "n.changepoints"          "changepoint.range"      
##  [5] "yearly.seasonality"      "weekly.seasonality"     
##  [7] "daily.seasonality"       "holidays"               
##  [9] "seasonality.mode"        "seasonality.prior.scale"
## [11] "changepoint.prior.scale" "holidays.prior.scale"   
## [13] "mcmc.samples"            "interval.width"         
## [15] "uncertainty.samples"     "specified.changepoints" 
## [17] "start"                   "y.scale"                
## [19] "logistic.floor"          "t.scale"                
## [21] "changepoints.t"          "seasonalities"          
## [23] "extra_regressors"        "country_holidays"       
## [25] "stan.fit"                "params"                 
## [27] "history"                 "history.dates"          
## [29] "train.holiday.names"     "train.component.cols"   
## [31] "component.modes"         "fit.kwargs"

予測で利用するデータフレームにもcapを格納する必要があります。その他は前回の記事の内容と同じです。

future_df_logistic <- make_future_dataframe(m_logistic, periods = 900) %>% 
  dplyr::mutate(cap = 8)
forecast_df_logistic <- predict(object = m_logistic, df = future_df_logistic)
dyplot.prophet(x = m_logistic, fcst = forecast_df_logistic)

予測された結果をみるとわかりますが、capで指定した数値を超えないように調整されています。このモデルでは、growth=logisticを指定して予測モデルを構築しましたが、growth=linearを指定して予測モデルを構築すると、下記の通り、似たような予測が行われます。capを指定していませんが、8を超えないように予測されています。capを指定していないため、予測値の信用区間が発散していく傾向が見て取れます。

おそらく、2013年以降から予測時点までの学習データの傾向から、予測時点は右肩上がりではなく、水平にくらいになると計算された結果だと思われます。今後の傾向が変わらないようであれば、このような場合は、growth=logisticcapの設定はあってもなくても、同じになりそうです。

m_linear <- prophet(df = df, growth = "linear")
future_df_linear <- make_future_dataframe(m_linear, periods = 900)
forecast_df_linear <- predict(object = m_linear, df = future_df_linear)
dyplot.prophet(x = m_linear, fcst = forecast_df_linear)

サンプルデータを右肩上がりの傾向が学習できそうな2013年までに限定すると、8を超えて、9にも届きそうな予測になっています。

m_linear2013 <- prophet(df = df %>% dplyr::filter(ds < "2013-01-01"), growth = "linear")
future_df_linear2013 <- make_future_dataframe(m_linear2013, periods = 900)
forecast_df_linear2013 <- predict(object = m_linear2013, df = future_df_linear2013)
dyplot.prophet(x = m_linear2013, fcst = forecast_df_linear2013)

2013年までにデータを絞り、capgrowth=logisticを指定して予測モデルを構築すると、このようになります。

実際に今日が2013年だとして、明日以降のデータを予測してほしいと言われたとします。データだけ見れば、右肩上がりなので、capgrowth=logisticは不要と思ってしまうかもしれません。Prophetが手軽に高品質に予測できるからと言って、データを可視化してみる、ドメイン知識を使ってモデルを熟慮することなどを怠れば、Prophetといえど役に立たないモデルになりそうです。

m_logistic2013 <- prophet(df = df_cap %>% dplyr::filter(ds < "2013-01-01"), growth = "logistic")
future_df_logistic2013 <- make_future_dataframe(m_logistic2013, periods = 900) %>% 
  dplyr::mutate(cap = 8.0)
forecast_df_logistic2013 <- predict(object = m_logistic2013, df = future_df_logistic2013)
dyplot.prophet(x = m_logistic2013, fcst = forecast_df_logistic2013)

5 フラットトレンド

トレンドの変化があまりなく、強い季節性周期がある時系列データの場合、トレンドをフラット(growth="flat")にすることが役立つ場合があります。

prophet(df = df, growth = "flat")$growth
## [1] "flat"

6 prophetのトレンドを深ぼる

prophetのトレンドについては、下記のProphetに関する論文に内容が記載されています。ここでは、ロジスティックトレンドについてみていきます。

これを参考にすると、ロジスティクトレンドは下記のように表現されています。Cはcarrying capacity(環境収容力)、kはgrowth rate(成長率)、mはoffset parameter(オフセット)と記載されています。

\[ g(t) = \frac{C}{ 1 + exp(-k(t - m)) } \]

さらに、論文の中では、環境収容力は一定ではなく、成長率も一定ではないため、時点によってこれらが変化できるようにしているとあります。一旦ここでは、時点よって変わらない骨格となる関数の挙動を深ぼってみます。

# ロジスティックトレンド
g <- function(t, C = 1, k = 1, m = 0){
  res <- C / (1 + exp(-k * (t - m)))
  return(res)
}

kmを固定し、Cを変化させていきます。Cは環境収容力を表すため、Cで指定した値が上限になるように動いていることがわかります。

t <- seq(-10, 10, 0.1)
C <- c(1, 2, 3, 4, 5, 6)
k <- rep(1, length(C))

pattern <- paste0('C = ', C, ', k = ', k)
df_plt <- tibble::tibble(t)
for (i in 1:length(C)) {
  y <- g(t = t, C = C[[i]], k = k[[i]])
  df_plt <- cbind(df_plt, y)
}
names(df_plt) <- c('t', pattern)

df_plt %>% 
  tidyr::pivot_longer(cols = -t, names_to = 'patterns') %>% 
  dplyr::arrange(patterns, t) %>% 
  ggplot(., aes(t, value, col = patterns, fill = patterns)) + 
  geom_line() + 
  geom_hline(yintercept = 1) + 
  scale_y_continuous(breaks = min(C):max(C)) +
  facet_wrap( ~ patterns) 

Cmを固定し、kを変化させていきます。kは成長率を表すため、基準線と比較するとわかりますが、kが正の値に大きくなると、急激に大きく動いていることがわかります。また、kが負の場合、右肩下がりの変化になることが見てとれます。

t <- seq(-10, 10, 0.1)
k <- c(-2, -1, -0.5, 0.5, 1, 2)
C <- rep(1, length(k))

pattern <- paste0('C = ', C, ', k = ', k)
df_plt <- tibble::tibble(t)
for (i in 1:length(C)) {
  y <- g(t = t, C = C[[i]], k = k[[i]])
  df_plt <- cbind(df_plt, y)
}
names(df_plt) <- c('t', pattern)

df_plt %>% 
  tidyr::pivot_longer(cols = -t, names_to = 'patterns') %>% 
  dplyr::arrange(patterns, t) %>% 
  ggplot(., aes(t, value, col = patterns, fill = patterns)) + 
  geom_line() + 
  geom_vline(xintercept = 0) +
  facet_wrap( ~ patterns) 

最後にCkを固定し、mを変化させていきます。mはオフセットを表すため、同じ基準線の位置でも、出力値が変化していることがわかります。

t <- seq(-10, 10, 0.1)
m <- c(-2, -1, 0, 1, 2, 3)
C <- rep(1, length(m))
k <- rep(1, length(m))

pattern <- paste0('C = ', C, ', k = ', k, ', m = ', m)
df_plt <- tibble::tibble(t)
for (i in 1:length(C)) {
  y <- g(t = t, C = C[[i]], k = k[[i]], m = m[[i]])
  df_plt <- cbind(df_plt, y)
}
names(df_plt) <- c('t', pattern)

df_plt %>% 
  tidyr::pivot_longer(cols = -t, names_to = 'patterns') %>% 
  dplyr::arrange(patterns, t) %>% 
  ggplot(., aes(t, value, col = patterns, fill = patterns)) + 
  geom_line() + 
  geom_vline(xintercept = 0) +
  facet_wrap( ~ patterns) 

7 時点で変化する成長率のトレンド

ここの作業が終わると、ロジスティクトレンドは下記のようになります。

\[ g(t) = \frac{C_{t}}{ 1 + exp(-(k + \boldsymbol{a}(t) \cdot \boldsymbol{\delta^{ \mathrm{T}}}) (t - m)} \]

成長率k\(\boldsymbol{a}(t) \cdot \boldsymbol{\delta^{ \mathrm{T}}}\)を足し込んでいますが、これはイメージとしては、成長率は状況に応じて変わるので、状況に応じて成長率が変わるのであれば、変化させようと表現しています。

そのため、成長率が変わる変化点が\(S_{j, 1...S}\)個あったとし、各時点での成長率を調整するベクトル\(\boldsymbol{\delta}\)を用意します。また、時点\(t\)の成長率は、基本となる\(k\)\(t\)時点までに出現した\(\boldsymbol{\delta}\)の総和として、下記のように表現します。

\[ k + \displaystyle \sum_{j:t > s_{j}} \delta_i \]

そして、総和の計算をしやすいように\(\boldsymbol{a}(t)\)を01のベクトルで表して計算します。これで、時点によって変化する成長率を表現しています。再現コードは下記のブログを参考にさせていただきました。Pythonのコードを交えながら非常にわかりやすく、解説されているブログです。

自分の練習がてら、Rで再現コードを書いていますが、実装が誤っている場合、参考元ブログの誤りではなく、おそらく私の誤りである可能性が高いため、予めお断りさせていただきます。加え、次回以降の記事ではProphetパッケージのトレンドや未来の変化点の計算部分をスクリプトに沿って深ぼっているので、Rのスクリプトを勉強したいということであれば、そっちを見たほうが正確かつスクリプトも綺麗です。

g2 <- function(t, C = 1, k = 1, m = 0,  S, d){
  a <- matrix(0, nrow = length(t), ncol = length(S))
  for (i in 1:length(t)) {
    for (j in 1:length(S)) {
      a[i, j] <- ifelse(S[[j]] < t[[i]], 1, 0) 
    }
  }
  
  y <- C / (1 + exp(-(k + (a %*% d)) * (t - m)))
  
  return(data.frame(t, y))
}

t <- seq(-10 ,10, length.out = 100)
S <- c(-5, 1, 5) # change point time
delta <- c(0.1, 0.3, -0.6) # change rate for growth rate

df2 <- g2(
  t = t, # time
  C = 1, # capacity
  k = 0.1, # growth rate
  m = 0, # offset
  S = S, # change point
  d = delta  # change rate for growth rate
)

ggplot(df2, aes(t, y)) + 
  geom_line() + 
  geom_vline(xintercept = S, col = "red", linetype = "dashed") + 
  scale_x_continuous(breaks = ceiling(seq(-10, 10, 1)))

このままでは、変化点で曲線が滑らかではなく、トレンドの変化が大きい状態です。この状態を調整するために、オフセットを利用します。

8 時点で変化する成長率のトレンド調整

ここの作業が終わると、ロジスティクトレンドは下記のようになります。

\[ g(t) = \frac{C_{t}}{ 1 + exp(-(k + \boldsymbol{a}(t) \cdot \boldsymbol{\delta^{ \mathrm{T}}}) (t - (m + \boldsymbol{a}(t) \cdot \boldsymbol{\gamma^{ \mathrm{T}}}))} \]

まずは、成長率が変わる変化点\(S_{j,1...S}\)個に対し、オフセットを調整するベクトル\(\boldsymbol{\gamma}\)を用意します。ベクトル\(\boldsymbol{\gamma}\)は下記のように定義します。

\[ \gamma_{j} = \left(s_{j} - m - \displaystyle \sum_{l \lt j} \gamma_{l}\right) \cdot \left(1 - \frac{ k + \displaystyle \sum_{l \lt j} \delta_{l} }{k + \displaystyle \sum_{l \le j} \delta_{l}} \right) \]

オフセットを利用することで、折れ線が折れ曲がる前後を調整することで、変化点の結合部分が連続で滑らかになるように調整します。

g3 <- function(t, C = 1, k = 1, m = 0,  S, d){
  a <- matrix(0, nrow = length(t), ncol = length(S))
  for (i in 1:length(t)) {
    for (j in 1:length(S)) {
      a[i, j] <- ifelse(S[[j]] < t[[i]], 1, 0) 
    }
  }
  
  gamma <- vector(mode = "numeric", length = length(S))
  for (j in 1:length(gamma)) {
    gamma[j] = (S[j] - m - sum(gamma[0:(j-1)])) * (1 - ((k + sum(d[0:(j-1)])) / (k + sum(d[0:j]))))
  }

  y <- C / (1 + exp(-(k + (a %*% d)) * (t - (m + (a %*% gamma)))))
  
  return(data.frame(t, y))
}

t <- seq(-10 ,10, length.out = 100)
S <- c(-5, 1, 5) # change point time
delta <- c(0.1, 0.3, -0.6) # change rate for growth rate

df3 <-  g3(
  t = t, # time
  C = 1, # capacity
  k = 0.1, # growth rate
  m = 0, # offset
  S = S, # change point
  d = delta  # change rate for growth rate
)

ggplot(df3, aes(t, y)) + 
  geom_line() + 
  geom_vline(xintercept = S, col = "red", linetype = "dashed") + 
  scale_x_continuous(breaks = seq(-10, 10, 1))

ここでは、トレンドを深ぼるにあたって、変化点を自分で設定していましたが、Prophetではどのように変化点を検知しているのか、次回はその点について見ていきます。

9 セッション情報

sessionInfo()
## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] ja_JP.UTF-8/ja_JP.UTF-8/ja_JP.UTF-8/C/ja_JP.UTF-8/ja_JP.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] forcats_0.5.1   stringr_1.4.0   dplyr_1.0.7     purrr_0.3.4    
##  [5] readr_1.4.0     tidyr_1.1.3     tibble_3.1.3    ggplot2_3.3.3  
##  [9] tidyverse_1.3.0 prophet_1.0     rlang_0.4.10    Rcpp_1.0.6     
## 
## loaded via a namespace (and not attached):
##  [1] matrixStats_0.58.0   fs_1.5.0             xts_0.12.1          
##  [4] lubridate_1.7.9.2    httr_1.4.2           rstan_2.21.2        
##  [7] tools_4.0.3          backports_1.2.1      utf8_1.1.4          
## [10] R6_2.5.0             DBI_1.1.1            colorspace_2.0-0    
## [13] withr_2.4.1          tidyselect_1.1.0     gridExtra_2.3       
## [16] prettyunits_1.1.1    processx_3.5.2       curl_4.3            
## [19] compiler_4.0.3       textshaping_0.3.5    cli_3.0.1           
## [22] rvest_0.3.6          xml2_1.3.2           labeling_0.4.2      
## [25] scales_1.1.1         dygraphs_1.1.1.6     callr_3.7.0         
## [28] systemfonts_1.0.2    digest_0.6.27        StanHeaders_2.21.0-7
## [31] rmarkdown_2.6        extraDistr_1.9.1     pkgconfig_2.0.3     
## [34] htmltools_0.5.1.1    highr_0.8            dbplyr_2.1.0        
## [37] htmlwidgets_1.5.3    readxl_1.3.1         rstudioapi_0.13     
## [40] generics_0.1.0       farver_2.0.3         zoo_1.8-8           
## [43] jsonlite_1.7.2       inline_0.3.17        magrittr_2.0.1      
## [46] loo_2.4.1            munsell_0.5.0        fansi_0.4.2         
## [49] lifecycle_1.0.0      stringi_1.5.3        yaml_2.2.1          
## [52] pkgbuild_1.2.0       grid_4.0.3           parallel_4.0.3      
## [55] crayon_1.4.0         lattice_0.20-41      haven_2.3.1         
## [58] hms_1.0.0            knitr_1.33           ps_1.5.0            
## [61] pillar_1.6.2         codetools_0.2-16     stats4_4.0.3        
## [64] reprex_1.0.0         glue_1.4.2           evaluate_0.14       
## [67] V8_3.4.0             RcppParallel_5.0.2   modelr_0.1.8        
## [70] vctrs_0.3.8          cellranger_1.1.0     gtable_0.3.0        
## [73] assertthat_0.2.1     xfun_0.24            broom_0.7.9         
## [76] ragg_1.1.3           ellipsis_0.3.2