UPDATE: 2021-09-15 00:09:29

1 はじめに

ここでは、下記のドキュメントを参考に、prophetパッケージの基本的な使い方をおさらいすることを目的としています。ゆくゆくは外部予測変数を追加したモデルやクロスバリデーション、パラメタチューニングなどなど、モデルを発展させながら使用方法をまとめていきます。

モデルの数理部分は下記のprophetに関する論文やブログ記事を参照願います。非常にわかりやすいです。

2 ライブラリと関数の読み込み

library(prophet)
library(tidyverse)

time_diff <- function(ds1, ds2, units = "days") {
  return(as.numeric(difftime(ds1, ds2, units = units)))
}

d <- tibble::tibble(
  ds = seq(as.Date("2020-01-01"), as.Date("2020-01-20"), by = "day"),
  y = c(10,13,14,20,24,19,12,10,13,14,16,24,25,26,22,21,16,15,18,25),
  cap = 30
)

# このあとの説明で使用するのでモデルを学習させておきます。
set.seed(1989)
m <- prophet(df = d, growth = "logistic", changepoint.prior.scale = 2)
future_df <- prophet::make_future_dataframe(m, periods = 5) %>% 
  dplyr::mutate(cap = 30)
forecast_df <- predict(m, future_df)
plot(m, forecast_df) + add_changepoints_to_plot(m) + 
  scale_x_datetime(date_breaks = '1 day', labels = scales::date_format("%Y-%m-%d")) + 
  theme(axis.text.x = element_text(angle = 30, hjust = 1))

3 prophetのpredict関数

ここまでのおさらいで、周期性や祝日効果、外部予測変数を除いて、Prophetがトレンドをどのように計算し、どのように変化点を見つけ、未来の予測において、どのように振る舞うのか、概要が把握できました。

ここではpredict()を実行した際に、どのような計算が行われていくのか、ソースコードをたどっていき、より理解を深めます。ここでは、季節周期や祝日効果、外部予測変数の部分は触れません。また、ロジスティクトレンドの場合をここでは扱うことにします。

Prophetのソースコードは下記より参照できます。

まずはpredict()が予測値を返すまでの流れをおさえていきます。

prophetクラスのpredict()は、学習したモデルと未来の日付が入ったデータフレームを受け取ると、setup_dataframe()でモデルに格納されている情報を使って、これ以降で必要な計算のための情報をデータフレームに付与していきます。

そして、predict_trend()が呼ばれ、この関数の中で、piecewise_logistic()が更に呼び出され、トレンドの計算が行われます。

その後に、predict_uncertainty()が呼び出され、未来のトレンドの計算の話に移ります。この関数の中では、sample_posterior_predictive()が呼び出され、MCMCまたはMAP推定でパラメタをシュミレーションするために、sample_model()が呼び出されます。この関数の中で、sample_predictive_trend()が呼び出され、未来の変化点の計算が行われ、未来のトレンドが返されます。

ここまで計算が進むと、少しデータフレームの情報を整えて、計算した値を元に予測値が計算されます。

#' @export
predict.prophet <- function(object, df = NULL, ...) {
  if (is.null(object$history)) {
    stop("Model must be fit before predictions can be made.")
  }
  if (is.null(df)) {
    df <- object$history
  } else {
    if (nrow(df) == 0) {
      stop("Dataframe has no rows.")
    }
    out <- setup_dataframe(object, df)
    df <- out$df
  }

  df$trend <- predict_trend(object, df)
  seasonal.components <- predict_seasonal_components(object, df)
  if (object$uncertainty.samples) {
    intervals <- predict_uncertainty(object, df)
  } else {
    intervals <- NULL
    }

  # Drop columns except ds, cap, floor, and trend
  cols <- c('ds', 'trend')
  if ('cap' %in% colnames(df)) {
    cols <- c(cols, 'cap')
  }
  if (object$logistic.floor) {
    cols <- c(cols, 'floor')
  }
  df <- df[cols]
  df <- dplyr::bind_cols(df, seasonal.components, intervals)
  df$yhat <- df$trend * (1 + df$multiplicative_terms) + df$additive_terms
  return(df)
}

ここでは主にトレンドの計算の部分、将来の変化点を計算する部分の実装見ていくことにします。

4 piecewise_logistic関数

この関数でロジスティクトレンドの計算を行っています。下記の数式に対応する部分です。

\[ g(t) = \frac{C_{t}}{ 1 + exp(-(k + \boldsymbol{a}(t) \cdot \boldsymbol{\delta^{ \mathrm{T}}}) (t - (m + \boldsymbol{a}(t) \cdot \boldsymbol{\gamma^{ \mathrm{T}}}))} \] \[ \gamma_{j} = \left(s_{j} - m - \displaystyle \sum_{l \lt j} \gamma_{l}\right) \cdot \left(1 - \frac{ k + \displaystyle \sum_{l \lt j} \delta_{l} }{k + \displaystyle \sum_{l \le j} \delta_{l}} \right) \]

#' @keywords internal
piecewise_logistic <- function(t, cap, deltas, k, m, changepoint.ts) {
  # Compute offset changes
  k.cum <- c(k, cumsum(deltas) + k)
  gammas <- rep(0, length(changepoint.ts))
  for (i in 1:length(changepoint.ts)) {
    gammas[i] <- ((changepoint.ts[i] - m - sum(gammas))
                  * (1 - k.cum[i] / k.cum[i + 1]))
  }
  # Get cumulative rate and offset at each t
  k_t <- rep(k, length(t))
  m_t <- rep(m, length(t))
  for (s in 1:length(changepoint.ts)) {
    indx <- t >= changepoint.ts[s]
    k_t[indx] <- k_t[indx] + deltas[s]
    m_t[indx] <- m_t[indx] + gammas[s]
    
 # 著書が追加したブロック。下記のようにすれば、値の変化が確認できる。
 # -------------------------------------
 # dd <- data.frame(t = 1:20, indx, k_t = round(k_t,10)) %>% print(dd)
 #-------------------------------------
 
  }
  y <- cap / (1 + exp(-k_t * (t - m_t)))
  return(y)
}

# 下記の設定でpiecewise_logistic()を実行できる
t <- time_diff(d$ds, m$start, "secs") / m$t.scale
cap <- 30
k <- mean(m$params$k, na.rm = TRUE)
deltas <- colMeans(m$params$delta, na.rm = TRUE)
param.m <- mean(m$params$m, na.rm = TRUE)
changepoint.ts <- m$changepoints.t
piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
##  [1] 15.63167 15.68140 15.73111 15.78079 15.83026 15.87972 15.92916 15.97867
##  [9] 16.84364 17.69652 18.53175 19.34442 20.13017 20.88529 21.60665 20.75587
## [17] 19.85922 18.92201 17.95074 16.95303

モデルに格納されているdeltaの値を使えば、下記のProphetのドキュメントにあるような変化率のグラフを作成することもできます。

df_delta <- tibble::tibble(
  changepoints= m$changepoints,
  changepoints_t = m$changepoints.t,
  delta = colMeans(m$params$delta),
  cumsum_delta = cumsum(colMeans(m$params$delta))
)

d %>% 
  dplyr::left_join(df_delta, by = c("ds" = "changepoints")) %>% 
  ggplot() + 
  geom_bar(aes(ds, delta), stat = "identity", fill = "#0072B2") + 
  geom_line(aes(ds, y)) +
  geom_hline(yintercept = 0) +
  scale_y_continuous(
    name = "y",
    limits = c(-30, 30),
    sec.axis = sec_axis(~ . * (1/100), name = "Rate change")
  ) +
  scale_x_datetime(date_breaks = '1 day', labels = scales::date_format("%Y-%m-%d")) + 
  xlab("Potential Changepoint") + 
  theme_bw() + 
  theme(axis.text.x = element_text(angle = 30, hjust = 1))

5 sample_predictive_trend関数

この関数が呼び出される前にMCMC、MAP推定のいずれかを行うか、加えてシュミレーション回数などが決定されており、そのiterationを受け取りながら計算が行われます。関数を分割して、どのような計算が行われていくのかたどっていきます。

まずは、モデルからパラメタの抽出が行われています。#>の部分は、変数に格納されている値を意味します。計算のイメージをしやすくするために、表示しているだけなので、乱数シードの固定も行っていません。

#' @keywords internal
sample_predictive_trend <- function(model, df, iteration) {
  k <- model$params$k[iteration]
  #> [1] 0.1262284
  
  param.m <- model$params$m[iteration]
  #> [1] -0.6676147
  
  deltas <- model$params$delta[iteration,]
  #>  [1] -5.557431e-07 -6.823703e-05 -4.602641e-04 -1.988231e-07 -1.101329e-07  2.463856e-04  2.085527e+00  4.540816e-04  3.364070e-06
  #> [10] -2.674883e-08 -1.236256e-07  4.914982e-08 -2.130730e-04 -4.809416e+00 -3.102756e-04

tsetup_dataframe()というインターナル関数によって付与されている値で、日付や時刻を秒数に換算し、学習期間のはじめを0、学習期間の終わりを1に調整している値が格納されています。その最大値をTに格納して、ポアソン過程を使って、未来の変化点changepoint.ts.newの計算を行っているようです。

  t <- df$t
  #>  [1] 0.00000000 0.05263158 0.10526316 0.15789474 0.21052632 0.26315789 0.31578947 0.36842105 0.42105263 0.47368421 0.52631579 0.57894737 0.63157895
  #> [14] 0.68421053 0.73684211 0.78947368 0.84210526 0.89473684 0.94736842 1.00000000 1.05263158 1.10526316 1.15789474 1.21052632 1.26315789
  
  T <- max(t)
  #> [1] 1.263158

下記の部分で下記の数式にあたる部分の計算を行っています。

\[ \begin{eqnarray} \forall j \gt T = \begin{cases} \delta_{j} = 0 \ w.p. \frac{T-S}{T}, \\ \delta_{j} \sim Laplace(0, \lambda)\ w.p. \frac{S}{T} \end{cases} \end{eqnarray} \]

下記では、\(T\)よりも未来の変化点\(\delta_{j}\)は、確率\(\frac{T-S}{T}\)\(\delta_{j}\)が0か、確率\(\frac{S}{T}\)\(\delta_{j}\)がラプラス分布に従う乱数によって生成されるかが決められ、ラプラス分布の乱数の場合、平均0、分散\(\lambda\)で変化率\(\delta_{j}\)となる乱数が生成されることなります。

  # New changepoints from a Poisson process with rate S on [1, T]
  if (T > 1) {
    S <- length(model$changepoints.t)
    #> [1] 15
    
    n.changes <- stats::rpois(1, S * (T - 1))
    #> [1] 3
  } else {
    n.changes <- 0
  }
  if (n.changes > 0) {
    changepoint.ts.new <- 1 + stats::runif(n.changes) * (T - 1)
    #> [1] 1.220517 1.008240 1.128195

    changepoint.ts.new <- sort(changepoint.ts.new)
    #> [1] 1.008240 1.128195 1.220517

  } else {
    changepoint.ts.new <- c()
  }

ここからは、deltasの平均をラプラス分布の分散として設定し、未来の変化率deltasを乱数で発生させています。その後は、changepoints.tdeltaの過去と未来をつなぎ合わせたベクトルを作っています。

  # Get the empirical scale of the deltas, plus epsilon to avoid NaNs.
  lambda <- mean(abs(c(deltas))) + 1e-8
  #> [1] 0.45978

  # Sample deltas
  deltas.new <- extraDistr::rlaplace(n.changes, mu = 0, sigma = lambda)
  #> [1] -0.0167830  0.2276684  0.2669462

  # Combine with changepoints from the history
  changepoint.ts <- c(model$changepoints.t, changepoint.ts.new)
  #>  [1] 0.05263158 0.10526316 0.15789474 0.21052632 0.26315789 0.31578947 0.36842105 0.42105263 0.47368421 0.52631579 0.57894737 0.63157895 0.68421053
  #> [14] 0.73684211 0.78947368 1.00824019 1.12819463 1.22051720

  deltas <- c(deltas, deltas.new)
  #>  [1] -5.557431e-07 -6.823703e-05 -4.602641e-04 -1.988231e-07 -1.101329e-07  2.463856e-04  2.085527e+00  4.540816e-04  3.364070e-06 -2.674883e-08
  #> [11] -1.236256e-07  4.914982e-08 -2.130730e-04 -4.809416e+00 -3.102756e-04 -1.678300e-02  2.276684e-01  2.669462e-01

未来の変化点と変化率の計算が終わっているので、トレンドを計算するためにpiecewise_logistic()が呼び出され、トレンドの予測部分の計算が完了します。

  # Get the corresponding trend
  if (model$growth == 'linear') {
    trend <- piecewise_linear(t, deltas, k, param.m, changepoint.ts)
  } else if (model$growth == 'flat') {
    trend <- flat_trend(t, param.m)
  } else if (model$growth == 'logistic') {
  
    cap <- df$cap_scaled
    #>  [1] 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846
    #> [17] 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846 1.153846
    
    trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
    #>  [1] 0.6012179 0.6031307 0.6050428 0.6069533 0.6088562 0.6107584 0.6126599 0.6145642 0.6478324 0.6806353 0.7127596 0.7440162 0.7742375 0.8032803 0.8310250
    #> [16] 0.7983027 0.7638160 0.7277696 0.6904131 0.6520397 0.6127640 0.5731120 0.5354359 0.4995907 0.4675052
  }
  return(trend * model$y.scale + df$floor)
  #>  [1] 15.63167 15.68140 15.73111 15.78079 15.83026 15.87972 15.92916 15.97867 16.84364 17.69652 18.53175 19.34442 20.13017 20.88529 21.60665 20.75587
  #> [17] 19.85922 18.92201 17.95074 16.95303 15.93186 14.90091 13.92133 12.98936 12.15513
}

これ以降は、piecewise_logistic()sample_predictive_trend()が実行される以前の部分で使用されている関数をいくつかピックアップしておきます。

6 setup_dataframe関数

この関数の目的は、学習済みモデルから予測値を計算するために必要な情報を、データフレームに付与することです。

#' @keywords internal
setup_dataframe <- function(m, df, initialize_scales = FALSE) {
  if (exists('y', where=df)) {
    df$y <- as.numeric(df$y)
    if (any(is.infinite(df$y))) {
      stop("Found infinity in column y.")
    }
  }
  df$ds <- set_date(df$ds)
  if (anyNA(df$ds)) {
    stop(paste('Unable to parse date format in column ds. Convert to date ',
               'format (%Y-%m-%d or %Y-%m-%d %H:%M:%S) and check that there',
               'are no NAs.'))
  }
  for (name in names(m$extra_regressors)) {
    if (!(name %in% colnames(df))) {
      stop('Regressor "', name, '" missing from dataframe')
    }
    df[[name]] <- as.numeric(df[[name]])
    if (anyNA(df[[name]])) {
      stop('Found NaN in column ', name)
    }
  }
  for (name in names(m$seasonalities)) {
    condition.name = m$seasonalities[[name]]$condition.name
    if (!is.null(condition.name)) {
      if (!(condition.name %in% colnames(df))) {
        stop('Condition "', name, '" missing from dataframe')
      }
      if(!all(df[[condition.name]] %in% c(FALSE,TRUE))) {
        stop('Found non-boolean in column ', name)
      }
      df[[condition.name]] <- as.logical(df[[condition.name]])
    }
  }

  df <- df %>%
    dplyr::arrange(ds)

  m <- initialize_scales_fn(m, initialize_scales, df)

  if (m$logistic.floor) {
    if (!('floor' %in% colnames(df))) {
      stop("Expected column 'floor'.")
    }
  } else {
    df$floor <- 0
  }

  if (m$growth == 'logistic') {
    if (!(exists('cap', where=df))) {
      stop('Capacities must be supplied for logistic growth.')
    }
    if (any(df$cap <= df$floor)) {
      stop('cap must be greater than floor (which defaults to 0).')
    }
    df <- df %>%
      dplyr::mutate(cap_scaled = (cap - floor) / m$y.scale)
  }

  df$t <- time_diff(df$ds, m$start, "secs") / m$t.scale
  if (exists('y', where=df)) {
    df$y_scaled <- (df$y - df$floor) / m$y.scale
  }

  for (name in names(m$extra_regressors)) {
    props <- m$extra_regressors[[name]]
    df[[name]] <- (df[[name]] - props$mu) / props$std
  }
  return(list("m" = m, "df" = df))
}

途中辺りで出てくるcap_scaledは、capの値をyの最大値を利用して、スケールを調整した値を計算しています。

# > d$cap
# [1] 30
# > d$floor
# [1] 0
# > m$y.scale
# [1] 26
# > max(d$y)
# [1] 26
# (cap - floor) / m$y.scale
(30 - 0) / 26
## [1] 1.153846

先程sample_predictive_trend()の部分で出てきていたtは、学習の最初の時点から最後の時点までをtime_diff()で秒数に換算し、それらの値を最後の時点の秒数で割り戻すことで、最初の時点が0、最後の時点が1になるようにスケールを調整された値が格納されています。

次回はProphetが周期変動をどのように扱っているのか見ていきます。

7 セッション情報

sessionInfo()
## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] ja_JP.UTF-8/ja_JP.UTF-8/ja_JP.UTF-8/C/ja_JP.UTF-8/ja_JP.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] forcats_0.5.1   stringr_1.4.0   dplyr_1.0.7     purrr_0.3.4    
##  [5] readr_1.4.0     tidyr_1.1.3     tibble_3.1.3    ggplot2_3.3.3  
##  [9] tidyverse_1.3.0 prophet_1.0     rlang_0.4.10    Rcpp_1.0.6     
## 
## loaded via a namespace (and not attached):
##  [1] httr_1.4.2           jsonlite_1.7.2       modelr_0.1.8        
##  [4] RcppParallel_5.0.2   StanHeaders_2.21.0-7 assertthat_0.2.1    
##  [7] highr_0.8            stats4_4.0.3         cellranger_1.1.0    
## [10] yaml_2.2.1           pillar_1.6.2         backports_1.2.1     
## [13] glue_1.4.2           digest_0.6.27        rvest_0.3.6         
## [16] colorspace_2.0-0     htmltools_0.5.1.1    pkgconfig_2.0.3     
## [19] rstan_2.21.2         broom_0.7.9          haven_2.3.1         
## [22] scales_1.1.1         processx_3.5.2       farver_2.0.3        
## [25] generics_0.1.0       ellipsis_0.3.2       withr_2.4.1         
## [28] cli_3.0.1            magrittr_2.0.1       crayon_1.4.0        
## [31] readxl_1.3.1         evaluate_0.14        ps_1.5.0            
## [34] fs_1.5.0             fansi_0.4.2          xml2_1.3.2          
## [37] pkgbuild_1.2.0       textshaping_0.3.5    tools_4.0.3         
## [40] loo_2.4.1            prettyunits_1.1.1    hms_1.0.0           
## [43] lifecycle_1.0.0      matrixStats_0.58.0   extraDistr_1.9.1    
## [46] V8_3.4.0             munsell_0.5.0        reprex_1.0.0        
## [49] callr_3.7.0          compiler_4.0.3       systemfonts_1.0.2   
## [52] grid_4.0.3           rstudioapi_0.13      labeling_0.4.2      
## [55] rmarkdown_2.6        gtable_0.3.0         codetools_0.2-16    
## [58] inline_0.3.17        DBI_1.1.1            curl_4.3            
## [61] R6_2.5.0             gridExtra_2.3        lubridate_1.7.9.2   
## [64] knitr_1.33           utf8_1.1.4           ragg_1.1.3          
## [67] stringi_1.5.3        parallel_4.0.3       vctrs_0.3.8         
## [70] dbplyr_2.1.0         tidyselect_1.1.0     xfun_0.24