UPDATE: 2021-09-15 00:09:05
ここでは、下記のドキュメントを参考に、prophet
パッケージの基本的な使い方をおさらいすることを目的としています。ゆくゆくは外部予測変数を追加したモデルやクロスバリデーション、パラメタチューニングなどなど、モデルを発展させながら使用方法をまとめていきます。
モデルの数理部分は下記のprophet
に関する論文やブログ記事を参照願います。非常にわかりやすいです。
library(prophet)
library(tidyverse)
library(patchwork)
<- function(data, n = 5){
head_tail stopifnot(is.data.frame(data))
head(data, n = n) %>% print()
cat(paste(rep("-", 100), collapse = ""), "\n")
tail(data, n = n) %>% print()
}
Prophetではどのように変化点を検知しているのか、今回はその点について見ていきます。デフォルト設定では、自動的に変化点を検知するようになっていますが、データに過学習しているとき、変化を捉えられていないときなど、新製品の発売日やイベントなどの既知の日付を使用して設定することで変化点を調整することが可能です。
\[ g(t) = \frac{C_{t}}{ 1 + exp(-(k + \boldsymbol{a}(t) \cdot \boldsymbol{\delta^{ \mathrm{T}}}) (t - (m + \boldsymbol{a}(t) \cdot \boldsymbol{\gamma^{ \mathrm{T}}}))} \]
論文で説明されている内容を読むと、Prophetの変化点の自動検知は、大量の潜在的な変化点を可能な範囲で検知することから始まり、上記の式の変化率\(\delta_{j}\)に、スパース事前分布(ラプラス分布)を設定することで行われると説明されています。数年の時系列データであれば、1か月に1回程度で変化点を検知するとも書かれています。
\[ \delta_{j} \sim Laplace(0, \tau) \]
パラメタ\(\tau\) を調整することで、モデルの柔軟性を直接制御できるようになります。ラプラス分布は事前分布を用いることで、L1正則化と同様に機能し、パラメタ\(\delta\)が0になるように制約がかかります。これはラプラス分布を可視化するとわかりよいので、ラプラス分布を可視化してみます。ラプラス分布は\(\mu\)と\(\tau\)をパラメタに持ちます(\(\tau\)は\(\sigma\)と表記されていたりします)。
\[ f(x | \mu, \tau) = \frac{1}{2 \tau} exp\left(- \frac{ |x - \mu| }{ \tau } \right) \]
Rには、デフォルトでは関数が用意されていないので、確率密度関数を書いて可視化することにします。
<- function(x, mu, tau){
drlaplace <- 1/(2 * tau) * exp(-1 * (abs(x - mu)) / tau)
f return(f)
}
<- seq(-10, 10, 0.1)
x <- 0
mu <- c(0.1, 1:5)
tau
<- paste0('mu = ', mu, ', tau = ', tau)
pattern <- tibble::tibble(x)
df_plt for (i in 1:length(tau)) {
<- drlaplace(x, mu, tau[[i]])
y <- cbind(df_plt, y)
df_plt
}names(df_plt) <- c('x', pattern)
%>%
df_plt ::pivot_longer(cols = -x, names_to = 'patterns') %>%
tidyr::arrange(patterns, x) %>%
dplyrggplot(., aes(x, value, col = patterns, fill = patterns)) +
geom_line() +
facet_wrap( ~ patterns, scales = "free")
\(\tau\)が小さくなるほど、ほとんど0しか返さない確率密度関数であるため、それがスパースな分布と呼ばれる所以になっています。つまり、Prophetでは多めに潜在的な変化を検出するものの、変化率\(\delta_{j}\)にラプラス分布を設定していると、多くの変化点の変更率\(\delta_{j}\)はほとんど0になることが期待されます。
Prophet()
では、changepoint_prior_scale
で設定することができます。ドキュメントにも記載されていますが、値を大きくすると、多くの変化点が許容され、値が小さいと変化点が少なくなります。
デフォルト設定でProphetの変化点の自動検知機能を動かしていきます。ドキュメントで使用されているWikipediaのページビューデータの期間を2年分に絞って動かしていきます。
# Githubのprophetのリポジトリからサンプルデータをインポート
<- 'https://raw.githubusercontent.com/facebook/prophet/master/examples/'
base_url <- paste0(base_url, 'example_wp_log_peyton_manning.csv')
data_path <- readr::read_csv(data_path) %>%
df ::filter(ds >= "2011-01-01" & ds <= "2012-12-31")
dplyr
# 先頭と末尾のデータを表示
head_tail(df, n = 5)
## # A tibble: 5 × 2
## ds y
## <date> <dbl>
## 1 2011-01-01 9.01
## 2 2011-01-02 9.40
## 3 2011-01-03 9.99
## 4 2011-01-04 9.06
## 5 2011-01-05 8.97
## ----------------------------------------------------------------------------------------------------
## # A tibble: 5 × 2
## ds y
## <date> <dbl>
## 1 2012-12-27 8.97
## 2 2012-12-28 8.68
## 3 2012-12-29 8.53
## 4 2012-12-30 9.49
## 5 2012-12-31 10.1
changepoint_prior_scale
を変化させながら複数のモデルを構築するとわかり良いですが、changepoint_prior_scale
が大きくなるにつれて(ラプラス分布の裾が広がる)、多くの変化点が許容され、値が小さくなるにつれて(ラプラス分布の裾が狭くなる)、変化点が少なくなります。変化点はadd_changepoints_to_plot()
を利用することで可視化できます。
# changepoint_prior_scaleを変化させてモデルを作成
<- c(0.05, 0.1, 0.5, 1.0)
scale_vec <- length(scale_vec)
n_vec <- vector(mode = "list", length = n_vec)
plt_list
for (i in 1:n_vec) {
<- prophet(df = df, changepoint.prior.scale = scale_vec[[i]])
m <- make_future_dataframe(m, periods = 30)
future_df <- predict(m, future_df)
forecast_df <- paste0("changepoint.prior.scale = ", scale_vec[[i]])
plt_title <- plot(m, forecast_df) + add_changepoints_to_plot(m) +
plt_list[[i]] labs(title = plt_title)
}
1]] + plt_list[[2]] + plt_list[[3]] + plt_list[[4]] plt_list[[
実際の観測値に対して、変化点を捉える方法はわかりましたが、Prophetでは、未来の変化点をどのように扱っているでしょうか。その点も論文に記載されており、
論文の内容をざっくりとまとめると、モデルが外挿されて予測が行われる場合、予測トレンドの不確実性をモデルで推定するとのこと。トレンドの生成モデルは、\(T\)個時系列データに\(S\)個の変化点があり、それぞれに変化率\(\delta_{j} \sim Laplace(0, \tau)\)があるというもので、\(\tau\)をデータから推測される分散に置き換えることにより、過去の変化率を模倣する将来の変化率をシミュレートします。
ラプラス分布の分散\(\tau\)をベイズ推定を行った事後分布から得るか、\(\lambda = \frac{1}{S} \sum_{j=1}^{S} | \delta_{j}|\)を最尤法で推定します。この\(\lambda\)は過去の変化率\(\delta_{j}\)の絶対値の平均です。
将来の変化点は、変化点の平均頻度が過去のデータ内の頻度と一致するように、下記に従ってランダムにサンプリングされます。
\[ \begin{eqnarray} \forall j \gt T, = \begin{cases} \delta_{j} = 0 \ w.p. \frac{T-S}{T}, \\ \delta_{j} \sim Laplace(0, \lambda)\ w.p. \frac{S}{T} \end{cases} \end{eqnarray} \]
意味合いとしては、\(T\)よりも未来の変化点\(\delta_{j}\)は、確率\(\frac{T-S}{T}\)で\(\delta_{j}\)が0か、確率\(\frac{S}{T}\)で\(\delta_{j}\)がラプラス分布に従う乱数によって生成されるかが決められ、ラプラス分布の乱数の場合、平均0、分散\(\lambda\)で変化率\(\delta_{j}\)となる乱数が生成されることなります。
下記の再現コードは、前回同様、下記のブログを参考にさせていただきました。Pythonのコードを交えながら非常にわかりやすく、解説されているブログです。
自分の練習がてら、Rで再現コードを書いていますが、実装が誤っている場合、参考元ブログの誤りではなく、おそらく私の誤りである可能性が高いため、予めお断りさせていただきます。加え、次回以降の記事ではProphetパッケージのトレンドや未来の変化点の計算部分をスクリプトに沿って深ぼっているので、Rのスクリプトを勉強したいということであれば、そっちを見たほうが正確かつスクリプトも綺麗です。
<- function(t, C = 1, k = 1, m = 0, S, d){
g3 <- matrix(0, nrow = length(t), ncol = length(S))
a for (i in 1:length(t)) {
for (j in 1:length(S)) {
<- ifelse(S[[j]] < t[[i]], 1, 0)
a[i, j]
}
}
<- vector(mode = "numeric", length = length(S))
gamma for (j in 1:length(gamma)) {
= (S[j] - m - sum(gamma[0:(j-1)])) * (1 - ((k + sum(d[0:(j-1)])) / (k + sum(d[0:j]))))
gamma[j]
}
<- C / (1 + exp(-(k + (a %*% d)) * (t - (m + (a %*% gamma)))))
y
return(
list(data = data.frame(t, y),
gamma = gamma)
)
}
<- seq(1 ,100, length.out = 100)
t <- c(20, 60, 80) # change point time
S <- c(-0.03, 0.01, 0.02) # change rate for growth rate
delta
<- g3(
d t = t, # time
C = 1, # capacity
k = 0.01, # growth rate
m = 0, # offset
S = S, # change point
d = delta # change rate for growth rate
)
<- length(S) / length(t)
freq <- mean(abs(delta))
mu_delta
# Base様のブログの乱数で実行
# set.seed(5)
# occurrence <- rbinom(n = length(t), p = freq, size = 1)
<- c(0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
occurrence 0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0)
<- which(occurrence == 1) + (length(t))
generated_s # Base様のブログの乱数で実行
# generated_delta <- extraDistr::rlaplace(n = length(generated_s), mu = 0, sigma = mu_delta)
<- c(0.00053214, 0.00810678, -0.03104065)
generated_delta
# 上記のパラメタを使って、未来の変化点をシュミレーション
<- g3(
f t = seq(max(t), length.out = length(t)), # time
C = 1, # capacity
k = 0.01, # growth rate
m = d$gamma[length(d$gamma)],
S = generated_s, # change point
d = generated_delta # change rate for growth rate
)
$data$label <- "pre"
d$data$label <- "post"
fggplot(rbind(d$data, f$data), aes(t, y, col = label)) +
geom_line() +
geom_vline(xintercept = c(S, generated_s), col = "black", linetype = "dashed") +
scale_color_manual(values = c("#8BADCC", "#E56A73")) +
scale_y_continuous(breaks = seq(0.3, 0.7, 0.05), limits = c(0.3, 0.7)) +
scale_x_continuous(breaks = seq(0, 200, 10))
ここまでのおさらいで、周期性や祝日効果、外部予測変数を除いて、Prophetがトレンドをどのように計算し、どのように変化点を見つけ、未来の予測において、どのように振る舞うのか、概要が把握できました。次回はpredict()
を実行した際に、どのような計算が行われていくのかを見ていきます。
sessionInfo()
## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] ja_JP.UTF-8/ja_JP.UTF-8/ja_JP.UTF-8/C/ja_JP.UTF-8/ja_JP.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] patchwork_1.1.1 forcats_0.5.1 stringr_1.4.0 dplyr_1.0.7
## [5] purrr_0.3.4 readr_1.4.0 tidyr_1.1.3 tibble_3.1.3
## [9] ggplot2_3.3.3 tidyverse_1.3.0 prophet_1.0 rlang_0.4.10
## [13] Rcpp_1.0.6
##
## loaded via a namespace (and not attached):
## [1] httr_1.4.2 jsonlite_1.7.2 modelr_0.1.8
## [4] StanHeaders_2.21.0-7 RcppParallel_5.0.2 assertthat_0.2.1
## [7] highr_0.8 stats4_4.0.3 cellranger_1.1.0
## [10] yaml_2.2.1 pillar_1.6.2 backports_1.2.1
## [13] glue_1.4.2 digest_0.6.27 rvest_0.3.6
## [16] colorspace_2.0-0 htmltools_0.5.1.1 pkgconfig_2.0.3
## [19] rstan_2.21.2 broom_0.7.9 haven_2.3.1
## [22] scales_1.1.1 processx_3.5.2 generics_0.1.0
## [25] farver_2.0.3 ellipsis_0.3.2 withr_2.4.1
## [28] cli_3.0.1 magrittr_2.0.1 crayon_1.4.0
## [31] readxl_1.3.1 evaluate_0.14 ps_1.5.0
## [34] fs_1.5.0 fansi_0.4.2 xml2_1.3.2
## [37] pkgbuild_1.2.0 textshaping_0.3.5 loo_2.4.1
## [40] tools_4.0.3 prettyunits_1.1.1 hms_1.0.0
## [43] matrixStats_0.58.0 lifecycle_1.0.0 extraDistr_1.9.1
## [46] V8_3.4.0 munsell_0.5.0 reprex_1.0.0
## [49] callr_3.7.0 compiler_4.0.3 systemfonts_1.0.2
## [52] grid_4.0.3 rstudioapi_0.13 labeling_0.4.2
## [55] rmarkdown_2.6 gtable_0.3.0 codetools_0.2-16
## [58] inline_0.3.17 DBI_1.1.1 curl_4.3
## [61] R6_2.5.0 gridExtra_2.3 lubridate_1.7.9.2
## [64] knitr_1.33 utf8_1.1.4 ragg_1.1.3
## [67] stringi_1.5.3 parallel_4.0.3 vctrs_0.3.8
## [70] dbplyr_2.1.0 tidyselect_1.1.0 xfun_0.24