UPDATE: 2025-08-03 18:25:54.783912
As the long-winded title says, this is a personal note on how to compare multiple models built with prophet() from the prophet package, using cross-validation based on the rolling origin method.
Before explaining the rolling origin method, let's start with some sample data.
n <- 24
set.seed(1989)
df <- tibble::tibble(ds = seq(from = as.Date("2010-01-01"), by = "1 month", length.out = n),
                     y = rnorm(n, 100, 1),
                     id = 1:n)
df
# A tibble: 24 x 3
ds y id
<date> <dbl> <int>
1 2010-01-01 101. 1
2 2010-02-01 101. 2
3 2010-03-01 98.2 3
4 2010-04-01 99.8 4
5 2010-05-01 99.4 5
6 2010-06-01 99.7 6
7 2010-07-01 100. 7
8 2010-08-01 101. 8
9 2010-09-01 100. 9
10 2010-10-01 99.3 10
# … with 14 more rows
The rolling origin method is a cross-validation technique for time series data: it evaluates model performance while sliding the training and test windows forward. rolling_origin() from the rsample package does the splitting for us.
The idea is illustrated in the image below; let's implement each pattern from top to bottom.
df_rolling_origin1 <- rsample::rolling_origin(
  data = df,
  initial = 12,
  assess = 6,
  skip = 0,
  cumulative = FALSE
)
df_rolling_origin1
# Rolling origin forecast resampling
# A tibble: 7 x 2
splits id
<list> <chr>
1 <split [12/6]> Slice1
2 <split [12/6]> Slice2
3 <split [12/6]> Slice3
4 <split [12/6]> Slice4
5 <split [12/6]> Slice5
6 <split [12/6]> Slice6
7 <split [12/6]> Slice7
In detail, each slice looks like this.
n <- nrow(df_rolling_origin1)
res <- vector(mode = "list", length = n)
# Show the first/last row of the analysis (training) and assessment (test) set for each slice
for (i in 1:n) {
  r1 <- df_rolling_origin1$splits[[i]] %>% rsample::analysis() %>% head(1)
  r2 <- df_rolling_origin1$splits[[i]] %>% rsample::analysis() %>% tail(1)
  r3 <- df_rolling_origin1$splits[[i]] %>% rsample::assessment() %>% head(1)
  r4 <- df_rolling_origin1$splits[[i]] %>% rsample::assessment() %>% tail(1)
  res[[i]] <- rbind(r1, r2, r3, r4)
}
names(res) <- paste0("slice", 1:n)
res
$slice1
# A tibble: 4 x 3
ds y id
<date> <dbl> <int>
1 2010-01-01 101. 1
2 2010-12-01 101. 12
3 2011-01-01 99.7 13
4 2011-06-01 98.6 18
$slice2
# A tibble: 4 x 3
ds y id
<date> <dbl> <int>
1 2010-02-01 101. 2
2 2011-01-01 99.7 13
3 2011-02-01 102. 14
4 2011-07-01 101. 19
$slice3
# A tibble: 4 x 3
ds y id
<date> <dbl> <int>
1 2010-03-01 98.2 3
2 2011-02-01 102. 14
3 2011-03-01 97.6 15
4 2011-08-01 98.7 20
$slice4
# A tibble: 4 x 3
ds y id
<date> <dbl> <int>
1 2010-04-01 99.8 4
2 2011-03-01 97.6 15
3 2011-04-01 100. 16
4 2011-09-01 99.5 21
$slice5
# A tibble: 4 x 3
ds y id
<date> <dbl> <int>
1 2010-05-01 99.4 5
2 2011-04-01 100. 16
3 2011-05-01 100. 17
4 2011-10-01 102. 22
$slice6
# A tibble: 4 x 3
ds y id
<date> <dbl> <int>
1 2010-06-01 99.7 6
2 2011-05-01 100. 17
3 2011-06-01 98.6 18
4 2011-11-01 98.5 23
$slice7
# A tibble: 4 x 3
ds y id
<date> <dbl> <int>
1 2010-07-01 100. 7
2 2011-06-01 98.6 18
3 2011-07-01 101. 19
4 2011-12-01 97.7 24
With cumulative = TRUE, the training data accumulates split by split from its starting point.
df_rolling_origin2 <- rsample::rolling_origin(
  data = df,
  initial = 12,
  assess = 6,
  skip = 0,
  cumulative = TRUE
)
df_rolling_origin2
# Rolling origin forecast resampling
# A tibble: 7 x 2
splits id
<list> <chr>
1 <split [12/6]> Slice1
2 <split [13/6]> Slice2
3 <split [14/6]> Slice3
4 <split [15/6]> Slice4
5 <split [16/6]> Slice5
6 <split [17/6]> Slice6
7 <split [18/6]> Slice7
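As a quick sanity check (my own addition, not in the original memo), the number of training rows per split can be confirmed directly; it should grow by one month per slice.
# Count the analysis (training) rows in each split; with cumulative = TRUE this gives 12, 13, ..., 18
purrr::map_int(df_rolling_origin2$splits, ~ nrow(rsample::analysis(.x)))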
Setting skip = 2 adjusts how far the window slides. With skip = 2 it slides by 3, which is actually quite natural when you think about it: even with skip = 0 the training and test windows still shift by 1, so skip = 0 corresponds to a step of 1, and skip = 2 therefore corresponds to a step of 3.
df_rolling_origin3 <- rsample::rolling_origin(
  data = df,
  initial = 12,
  assess = 6,
  skip = 2,
  cumulative = FALSE
)
df_rolling_origin3
# Rolling origin forecast resampling
# A tibble: 3 x 2
splits id
<list> <chr>
1 <split [12/6]> Slice1
2 <split [12/6]> Slice2
3 <split [12/6]> Slice3
n <- nrow(df_rolling_origin3)
res <- vector(mode = "list", length = n)
for (i in 1:n) {
  r1 <- df_rolling_origin3$splits[[i]] %>% rsample::analysis() %>% head(1)
  r2 <- df_rolling_origin3$splits[[i]] %>% rsample::analysis() %>% tail(1)
  r3 <- df_rolling_origin3$splits[[i]] %>% rsample::assessment() %>% head(1)
  r4 <- df_rolling_origin3$splits[[i]] %>% rsample::assessment() %>% tail(1)
  res[[i]] <- rbind(r1, r2, r3, r4)
}
names(res) <- paste0("slice", 1:n)
res
$slice1
# A tibble: 4 x 3
ds y id
<date> <dbl> <int>
1 2010-01-01 101. 1
2 2010-12-01 101. 12
3 2011-01-01 99.7 13
4 2011-06-01 98.6 18
$slice2
# A tibble: 4 x 3
ds y id
<date> <dbl> <int>
1 2010-04-01 99.8 4
2 2011-03-01 97.6 15
3 2011-04-01 100. 16
4 2011-09-01 99.5 21
$slice3
# A tibble: 4 x 3
ds y id
<date> <dbl> <int>
1 2010-07-01 100. 7
2 2011-06-01 98.6 18
3 2011-07-01 101. 19
4 2011-12-01 97.7 24
The bottom pattern in the image can be reproduced by writing the following.
df_rolling_origin4 <- rsample::rolling_origin(
  data = df,
  initial = 12,
  assess = 6,
  skip = 2,
  cumulative = TRUE
)
df_rolling_origin4
# Rolling origin forecast resampling
# A tibble: 3 x 2
splits id
<list> <chr>
1 <split [12/6]> Slice1
2 <split [15/6]> Slice2
3 <split [18/6]> Slice3
Now let's cross-validate a single model based on the rolling origin method. To keep training time reasonable, I use 24 months of data here and only 2 splits.
n <- 24
set.seed(1989)
df <- tibble::tibble(ds = seq(from = as.Date("2010-01-01"), by = "1 month", length.out = n),
                     y = rnorm(n, 100, 1))
df_rolling_origin <- rsample::rolling_origin(
  data = df,
  initial = 12,
  assess = 11,
  skip = 0,
  cumulative = FALSE
)
df_rolling_origin
# Rolling origin forecast resampling
# A tibble: 2 x 2
splits id
<list> <chr>
1 <split [12/11]> Slice1
2 <split [12/11]> Slice2
n <- nrow(df_rolling_origin)
res <- vector(mode = "list", length = n)
for (i in 1:n) {
  r1 <- df_rolling_origin$splits[[i]] %>% rsample::analysis() %>% head(1)
  r2 <- df_rolling_origin$splits[[i]] %>% rsample::analysis() %>% tail(1)
  r3 <- df_rolling_origin$splits[[i]] %>% rsample::assessment() %>% head(1)
  r4 <- df_rolling_origin$splits[[i]] %>% rsample::assessment() %>% tail(1)
  res[[i]] <- rbind(r1, r2, r3, r4)
}
names(res) <- paste0("slice", 1:n)
res
$slice1
# A tibble: 4 x 2
ds y
<date> <dbl>
1 2010-01-01 101.
2 2010-12-01 101.
3 2011-01-01 99.7
4 2011-11-01 98.5
$slice2
# A tibble: 4 x 2
ds y
<date> <dbl>
1 2010-02-01 101.
2 2011-01-01 99.7
3 2011-02-01 102.
4 2011-12-01 97.7
In the end, if you write it like this, the rolling origin cross-validation result for the model prophet::prophet(df = x, growth = "linear", n.changepoints = 1) is displayed. Messages about disabled seasonalities are printed, but they are not a problem here. mean_rmse is the average RMSE over the two splits. This alone would hardly count as a memo, though, so let's break the pipeline down step by step.
df_rolling_origin %>%
  dplyr::mutate(train = purrr::map(.x = splits, .f = function(x){rsample::analysis(x)}),
                test = purrr::map(.x = splits, .f = function(x){rsample::assessment(x)}),
                model = purrr::map(.x = train, .f = function(x){prophet::prophet(df = x, growth = "linear", n.changepoints = 1)}),
                pred = purrr::map2(.x = model, .y = test, .f = function(x, y){predict(object = x, df = y)})
  ) %>%
  tidyr::unnest(c(pred, test), names_repair = "minimal") %>%
  dplyr::select("id", "ds", "y", "yhat") %>%
  dplyr::group_by(id) %>%
  yardstick::rmse(y, yhat) %>%
  dplyr::summarise(mean_rmse = mean(.estimate))
Disabling yearly seasonality. Run prophet with yearly.seasonality=TRUE to override this.
Disabling weekly seasonality. Run prophet with weekly.seasonality=TRUE to override this.
Disabling daily seasonality. Run prophet with daily.seasonality=TRUE to override this.
Disabling yearly seasonality. Run prophet with yearly.seasonality=TRUE to override this.
Disabling weekly seasonality. Run prophet with weekly.seasonality=TRUE to override this.
Disabling daily seasonality. Run prophet with daily.seasonality=TRUE to override this.
# A tibble: 1 x 1
mean_rmse
<dbl>
1 2.71
First, the pipeline takes the pre-split data and, for each slice, stores the training data and the test data in separate list columns.
df_rolling_origin %>%
  dplyr::mutate(train = purrr::map(.x = splits, .f = function(x){rsample::analysis(x)}),
                test = purrr::map(.x = splits, .f = function(x){rsample::assessment(x)}))
# Rolling origin forecast resampling
# A tibble: 2 x 4
splits id train test
<list> <chr> <list> <list>
1 <split [12/11]> Slice1 <tibble [12 × 2]> <tibble [11 × 2]>
2 <split [12/11]> Slice2 <tibble [12 × 2]> <tibble [11 × 2]>
Unnesting the training data and the test data gives the following. Because the splits come from the rolling origin method, you can confirm that the months are shifted.
# Unnest the training data
df_rolling_origin %>%
  dplyr::mutate(train = purrr::map(.x = splits, .f = function(x){rsample::analysis(x)}),
                test = purrr::map(.x = splits, .f = function(x){rsample::assessment(x)})) %>%
  tidyr::unnest(train, names_repair = "minimal") %>%
  print(n = 100)
# A tibble: 24 x 5
splits id ds y test
<list> <chr> <date> <dbl> <list>
1 <split [12/11]> Slice1 2010-01-01 101. <tibble [11 × 2]>
2 <split [12/11]> Slice1 2010-02-01 101. <tibble [11 × 2]>
3 <split [12/11]> Slice1 2010-03-01 98.2 <tibble [11 × 2]>
4 <split [12/11]> Slice1 2010-04-01 99.8 <tibble [11 × 2]>
5 <split [12/11]> Slice1 2010-05-01 99.4 <tibble [11 × 2]>
6 <split [12/11]> Slice1 2010-06-01 99.7 <tibble [11 × 2]>
7 <split [12/11]> Slice1 2010-07-01 100. <tibble [11 × 2]>
8 <split [12/11]> Slice1 2010-08-01 101. <tibble [11 × 2]>
9 <split [12/11]> Slice1 2010-09-01 100. <tibble [11 × 2]>
10 <split [12/11]> Slice1 2010-10-01 99.3 <tibble [11 × 2]>
11 <split [12/11]> Slice1 2010-11-01 101. <tibble [11 × 2]>
12 <split [12/11]> Slice1 2010-12-01 101. <tibble [11 × 2]>
13 <split [12/11]> Slice2 2010-02-01 101. <tibble [11 × 2]>
14 <split [12/11]> Slice2 2010-03-01 98.2 <tibble [11 × 2]>
15 <split [12/11]> Slice2 2010-04-01 99.8 <tibble [11 × 2]>
16 <split [12/11]> Slice2 2010-05-01 99.4 <tibble [11 × 2]>
17 <split [12/11]> Slice2 2010-06-01 99.7 <tibble [11 × 2]>
18 <split [12/11]> Slice2 2010-07-01 100. <tibble [11 × 2]>
19 <split [12/11]> Slice2 2010-08-01 101. <tibble [11 × 2]>
20 <split [12/11]> Slice2 2010-09-01 100. <tibble [11 × 2]>
21 <split [12/11]> Slice2 2010-10-01 99.3 <tibble [11 × 2]>
22 <split [12/11]> Slice2 2010-11-01 101. <tibble [11 × 2]>
23 <split [12/11]> Slice2 2010-12-01 101. <tibble [11 × 2]>
24 <split [12/11]> Slice2 2011-01-01 99.7 <tibble [11 × 2]>
# Unnest the test data
df_rolling_origin %>%
  dplyr::mutate(train = purrr::map(.x = splits, .f = function(x){rsample::analysis(x)}),
                test = purrr::map(.x = splits, .f = function(x){rsample::assessment(x)})) %>%
  tidyr::unnest(test, names_repair = "minimal") %>%
  print(n = 100)
# A tibble: 22 x 5
splits id train ds y
<list> <chr> <list> <date> <dbl>
1 <split [12/11]> Slice1 <tibble [12 × 2]> 2011-01-01 99.7
2 <split [12/11]> Slice1 <tibble [12 × 2]> 2011-02-01 102.
3 <split [12/11]> Slice1 <tibble [12 × 2]> 2011-03-01 97.6
4 <split [12/11]> Slice1 <tibble [12 × 2]> 2011-04-01 100.
5 <split [12/11]> Slice1 <tibble [12 × 2]> 2011-05-01 100.
6 <split [12/11]> Slice1 <tibble [12 × 2]> 2011-06-01 98.6
7 <split [12/11]> Slice1 <tibble [12 × 2]> 2011-07-01 101.
8 <split [12/11]> Slice1 <tibble [12 × 2]> 2011-08-01 98.7
9 <split [12/11]> Slice1 <tibble [12 × 2]> 2011-09-01 99.5
10 <split [12/11]> Slice1 <tibble [12 × 2]> 2011-10-01 102.
11 <split [12/11]> Slice1 <tibble [12 × 2]> 2011-11-01 98.5
12 <split [12/11]> Slice2 <tibble [12 × 2]> 2011-02-01 102.
13 <split [12/11]> Slice2 <tibble [12 × 2]> 2011-03-01 97.6
14 <split [12/11]> Slice2 <tibble [12 × 2]> 2011-04-01 100.
15 <split [12/11]> Slice2 <tibble [12 × 2]> 2011-05-01 100.
16 <split [12/11]> Slice2 <tibble [12 × 2]> 2011-06-01 98.6
17 <split [12/11]> Slice2 <tibble [12 × 2]> 2011-07-01 101.
18 <split [12/11]> Slice2 <tibble [12 × 2]> 2011-08-01 98.7
19 <split [12/11]> Slice2 <tibble [12 × 2]> 2011-09-01 99.5
20 <split [12/11]> Slice2 <tibble [12 × 2]> 2011-10-01 102.
21 <split [12/11]> Slice2 <tibble [12 × 2]> 2011-11-01 98.5
22 <split [12/11]> Slice2 <tibble [12 × 2]> 2011-12-01 97.7
Next, fit the model from this state: the train data is passed to prophet() for training.
df_rolling_origin %>%
  dplyr::mutate(train = purrr::map(.x = splits, .f = function(x){rsample::analysis(x)}),
                test = purrr::map(.x = splits, .f = function(x){rsample::assessment(x)}),
                model = purrr::map(.x = train, .f = function(x){prophet::prophet(df = x, growth = "linear", n.changepoints = 1)}))
Disabling yearly seasonality. Run prophet with yearly.seasonality=TRUE to override this.
Disabling weekly seasonality. Run prophet with weekly.seasonality=TRUE to override this.
Disabling daily seasonality. Run prophet with daily.seasonality=TRUE to override this.
Disabling yearly seasonality. Run prophet with yearly.seasonality=TRUE to override this.
Disabling weekly seasonality. Run prophet with weekly.seasonality=TRUE to override this.
Disabling daily seasonality. Run prophet with daily.seasonality=TRUE to override this.
# Rolling origin forecast resampling
# A tibble: 2 x 5
splits id train test model
<list> <chr> <list> <list> <list>
1 <split [12/11]> Slice1 <tibble [12 × 2]> <tibble [11 × 2]> <prophet [32]>
2 <split [12/11]> Slice2 <tibble [12 × 2]> <tibble [11 × 2]> <prophet [32]>
Then the fitted model is used to predict over the test data period.
df_rolling_origin %>%
  dplyr::mutate(train = purrr::map(.x = splits, .f = function(x){rsample::analysis(x)}),
                test = purrr::map(.x = splits, .f = function(x){rsample::assessment(x)}),
                model = purrr::map(.x = train, .f = function(x){prophet::prophet(df = x, growth = "linear", n.changepoints = 1)}),
                pred = purrr::map2(.x = model, .y = test, .f = function(x, y){predict(object = x, df = y)})
  )
Disabling yearly seasonality. Run prophet with yearly.seasonality=TRUE to override this.
Disabling weekly seasonality. Run prophet with weekly.seasonality=TRUE to override this.
Disabling daily seasonality. Run prophet with daily.seasonality=TRUE to override this.
Disabling yearly seasonality. Run prophet with yearly.seasonality=TRUE to override this.
Disabling weekly seasonality. Run prophet with weekly.seasonality=TRUE to override this.
Disabling daily seasonality. Run prophet with daily.seasonality=TRUE to override this.
# Rolling origin forecast resampling
# A tibble: 2 x 6
splits id train test model pred
<list> <chr> <list> <list> <list> <list>
1 <split [12/11]> Slice1 <tibble [12 × 2]> <tibble [11 × 2]> <prophet [32]> <tibble [11 × 16]>
2 <split [12/11]> Slice2 <tibble [12 × 2]> <tibble [11 × 2]> <prophet [32]> <tibble [11 × 16]>
Unnesting pred gives the following: a prediction is returned for each test observation. names_repair = "minimal" is used because, when unnest()ed, both pred and test contain a date column named ds; it keeps the names unchanged so that select() simply picks the leftmost ds. Admittedly, this is a risky approach.
df_rolling_origin %>%
  dplyr::mutate(train = purrr::map(.x = splits, .f = function(x){rsample::analysis(x)}),
                test = purrr::map(.x = splits, .f = function(x){rsample::assessment(x)}),
                model = purrr::map(.x = train, .f = function(x){prophet::prophet(df = x, growth = "linear", n.changepoints = 1)}),
                pred = purrr::map2(.x = model, .y = test, .f = function(x, y){predict(object = x, df = y)})
  ) %>%
  tidyr::unnest(c(pred, test), names_repair = "minimal") %>%
  dplyr::select("id", "ds", "y", "yhat") %>%
  print(n = 100)
Disabling yearly seasonality. Run prophet with yearly.seasonality=TRUE to override this.
Disabling weekly seasonality. Run prophet with weekly.seasonality=TRUE to override this.
Disabling daily seasonality. Run prophet with daily.seasonality=TRUE to override this.
Disabling yearly seasonality. Run prophet with yearly.seasonality=TRUE to override this.
Disabling weekly seasonality. Run prophet with weekly.seasonality=TRUE to override this.
Disabling daily seasonality. Run prophet with daily.seasonality=TRUE to override this.
# A tibble: 22 x 4
id ds y yhat
<chr> <date> <dbl> <dbl>
1 Slice1 2011-01-01 99.7 101.
2 Slice1 2011-02-01 102. 102.
3 Slice1 2011-03-01 97.6 102.
4 Slice1 2011-04-01 100. 102.
5 Slice1 2011-05-01 100. 102.
6 Slice1 2011-06-01 98.6 103.
7 Slice1 2011-07-01 101. 103.
8 Slice1 2011-08-01 98.7 103.
9 Slice1 2011-09-01 99.5 104.
10 Slice1 2011-10-01 102. 104.
11 Slice1 2011-11-01 98.5 104.
12 Slice2 2011-02-01 102. 101.
13 Slice2 2011-03-01 97.6 101.
14 Slice2 2011-04-01 100. 101.
15 Slice2 2011-05-01 100. 101.
16 Slice2 2011-06-01 98.6 101.
17 Slice2 2011-07-01 101. 101.
18 Slice2 2011-08-01 98.7 101.
19 Slice2 2011-09-01 99.5 101.
20 Slice2 2011-10-01 102. 101.
21 Slice2 2011-11-01 98.5 101.
22 Slice2 2011-12-01 97.7 101.
All that remains is to compute the RMSE for each Slice and finally average them.
df_rolling_origin %>%
  dplyr::mutate(train = purrr::map(.x = splits, .f = function(x){rsample::analysis(x)}),
                test = purrr::map(.x = splits, .f = function(x){rsample::assessment(x)}),
                model = purrr::map(.x = train, .f = function(x){prophet::prophet(df = x, growth = "linear", n.changepoints = 1)}),
                pred = purrr::map2(.x = model, .y = test, .f = function(x, y){predict(object = x, df = y)})
  ) %>%
  tidyr::unnest(c(pred, test), names_repair = "minimal") %>%
  dplyr::select("id", "ds", "y", "yhat") %>%
  dplyr::group_by(id) %>%
  yardstick::rmse(y, yhat) %>%
  print(n = 100)
# A tibble: 2 x 4
id .metric .estimator .estimate
<chr> <chr> <chr> <dbl>
1 Slice1 rmse standard 3.30
2 Slice2 rmse standard 2.13
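By the way, regarding the "risky" names_repair = "minimal" point above, here is a sketch of a safer variant (my own assumption, not the original approach): keep only yhat from the prediction and attach it to the test rows, so unnest() never sees duplicate ds columns.
df_rolling_origin %>%
  dplyr::mutate(train = purrr::map(splits, rsample::analysis),
                test  = purrr::map(splits, rsample::assessment),
                model = purrr::map(train, function(x) prophet::prophet(df = x, growth = "linear", n.changepoints = 1)),
                pred  = purrr::map2(model, test, function(x, y) {
                  # attach only yhat to the test rows, which are already ordered by ds
                  fcst <- predict(object = x, df = y)
                  dplyr::mutate(y, yhat = fcst$yhat)
                })) %>%
  tidyr::unnest(pred) %>%
  dplyr::select(id, ds, y, yhat)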
With that, let's cross-validate multiple models based on the rolling origin method. The plan is to wrap the single-model cross-validation pipeline in purrr::map() and turn it into a function. So first, store the candidate models in a list. Here I prepare models that shift the number of changepoints step by step.
models <- list(
  p1 = function(data) {
    model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 1)},
  p2 = function(data) {
    model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 2)},
  p3 = function(data) {
    model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 3)},
  p4 = function(data) {
    model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 4)},
  p5 = function(data) {
    model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 5)},
  p6 = function(data) {
    model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 6)},
  p7 = function(data) {
    model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 7)},
  p8 = function(data) {
    model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 8)}
)
models
$p1
function(data) {
model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 1)}
$p2
function(data) {
model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 2)}
$p3
function(data) {
model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 3)}
$p4
function(data) {
model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 4)}
$p5
function(data) {
model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 5)}
$p6
function(data) {
model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 6)}
$p7
function(data) {
model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 7)}
$p8
function(data) {
model <- prophet::prophet(df = data, growth = "linear", n.changepoints = 8)}
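As an aside, the same list of model functions can also be built programmatically with purrr; this is my own sketch, where force() makes sure each closure captures its own changepoint count.
# Build p1 ... p8 as functions of data, each fitting prophet with a different n.changepoints
models <- purrr::map(
  purrr::set_names(1:8, paste0("p", 1:8)),
  function(k) {
    force(k)
    function(data) prophet::prophet(df = data, growth = "linear", n.changepoints = k)
  }
)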
Now let's write a rough function. You pass in the original data, and the splitting for cross-validation also happens inside the function. The part that changed is fit = purrr::map(.x = train, .f = function(x){model(x)}), which makes it possible to pass in a model function.
multi_prophets <- function(data, list_models, set_initial, set_assess, set_skip, set_cumulative){
  df_rolling_origin <- rsample::rolling_origin(
    data = data,
    initial = set_initial,
    assess = set_assess,
    skip = set_skip,
    cumulative = set_cumulative)
  purrr::map(.x = list_models, .f = function(model){
    df_rolling_origin %>%
      dplyr::mutate(train = purrr::map(.x = splits, .f = function(x){rsample::analysis(x)}),
                    test = purrr::map(.x = splits, .f = function(x){rsample::assessment(x)}),
                    fit = purrr::map(.x = train, .f = function(x){model(x)}),
                    pred = purrr::map2(.x = fit, .y = test, .f = function(x, y){predict(object = x, df = y)})) %>%
      tidyr::unnest(c(pred, test), names_repair = "minimal") %>%
      dplyr::select("id", "ds", "y", "yhat") %>%
      dplyr::group_by(id) %>%
      yardstick::rmse(y, yhat) %>%
      dplyr::summarise(mean_rmse = mean(.estimate))
  })
}
Run it.
res <- multi_prophets(
  data = df,
  list_models = models,
  set_initial = 12,
  set_assess = 11,
  set_skip = 0,
  set_cumulative = FALSE
)
res
$p1
# A tibble: 1 x 1
mean_rmse
<dbl>
1 2.71
$p2
# A tibble: 1 x 1
mean_rmse
<dbl>
1 2.51
$p3
# A tibble: 1 x 1
mean_rmse
<dbl>
1 2.45
$p4
# A tibble: 1 x 1
mean_rmse
<dbl>
1 2.61
$p5
# A tibble: 1 x 1
mean_rmse
<dbl>
1 2.62
$p6
# A tibble: 1 x 1
mean_rmse
<dbl>
1 2.46
$p7
# A tibble: 1 x 1
mean_rmse
<dbl>
1 2.61
$p8
# A tibble: 1 x 1
mean_rmse
<dbl>
1 2.62
This is hard to read, so let's summarize it.
res %>%
  purrr::reduce(.x = ., .f = dplyr::bind_rows) %>%
  dplyr::mutate(model = names(res)) %>%
  dplyr::select(model, mean_rmse) %>%
  dplyr::arrange(mean_rmse)
# A tibble: 8 x 2
model mean_rmse
<chr> <dbl>
1 p3 2.45
2 p6 2.46
3 p2 2.51
4 p7 2.61
5 p4 2.61
6 p5 2.62
7 p8 2.62
8 p1 2.71
That's it. Now that I have actually written it out, it feels somewhat clunky to use, but with minor modifications the same approach works beyond prophet(), for state space models or ARIMA-family models as well, so I can't really complain. I just learned that there is a handy tidymodels-style machine learning package for time series called modeltime (Introducing Modeltime: Tidy Time Series Forecasting using Tidymodels), so I should give it a try next time.
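For the record, here is a rough sketch (my own assumption, not part of the original memo) of swapping prophet() for forecast::auto.arima() in the same pipeline; it assumes the forecast package is installed, and only the fit and predict steps change.
# Fit auto.arima on the monthly training series and forecast over the test horizon
arima_fit <- function(data) forecast::auto.arima(stats::ts(data$y, frequency = 12))
df_rolling_origin %>%
  dplyr::mutate(train = purrr::map(splits, rsample::analysis),
                test  = purrr::map(splits, rsample::assessment),
                fit   = purrr::map(train, arima_fit),
                pred  = purrr::map2(fit, test, function(x, y) {
                  dplyr::mutate(y, yhat = as.numeric(forecast::forecast(x, h = nrow(y))$mean))
                })) %>%
  tidyr::unnest(pred) %>%
  dplyr::group_by(id) %>%
  yardstick::rmse(y, yhat) %>%
  dplyr::summarise(mean_rmse = mean(.estimate))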
As a bonus: {prophet} has a handy function called cross_validation(). It runs rolling origin CV roughly like this.
cross_validation
function (model, horizon, units, period = NULL, initial = NULL)
{
    # (snipped: parts not needed here have been cut, below as well)
    cutoffs <- generate_cutoffs(df, horizon.dt, initial.dt, period.dt)
    predicts <- data.frame()
    for (i in 1:length(cutoffs)) {
        # Take the i-th cutoff value
        cutoff <- cutoffs[i]
        # Copy the model
        m <- prophet_copy(model, cutoff)
        # Take the data up to and including the cutoff
        history.c <- dplyr::filter(df, ds <= cutoff)
        # Fit the model
        fit.args <- c(list(m = m, df = history.c), model$fit.kwargs)
        m <- do.call(fit.prophet, fit.args)
        # Take the data after the cutoff, up to cutoff + horizon days
        df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon.dt)
        columns <- "ds"
        future <- df.predict[columns]
        yhat <- stats::predict(m, future)
        df.c <- dplyr::inner_join(df.predict, yhat[predict_columns], by = "ds")
        df.c <- df.c[c(predict_columns, "y")]
        df.c <- dplyr::select(df.c, y, predict_columns)
        df.c$cutoff <- cutoff
        # Append the predictions
        predicts <- rbind(predicts, df.c)
        # The loop runs once for each cutoff in the vector
    }
    return(predicts)
}
<bytecode: 0x7fb79bf01650>
<environment: namespace:prophet>
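A minimal usage sketch on the toy monthly df from above (my own addition; the parameter values are purely illustrative, and everything is specified in days because that is what generate_cutoffs() works with):
# Fit once on the full toy data, then let prophet slide the cutoffs itself
m <- prophet::prophet(df, growth = "linear", n.changepoints = 1)
df_cv <- prophet::cross_validation(model = m,
                                   initial = 365,
                                   period = 90,
                                   horizon = 180,
                                   units = "days")
# Summarize the CV results with prophet's built-in helper
prophet::performance_metrics(df_cv)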
In time series work, models are often evaluated with a metric called MAPE. Say you have one week of forecasts: for each day, MAPE computes how far the prediction deviates from the actual value in percentage terms (|y - yhat| / y), then sums those values and divides by the number of days. For example, if MAPE is 1%, then as long as it is not being pulled by a large outlier on some particular day, you can interpret it as the forecast being off by roughly 1% of that day's actual value on a typical day.
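A tiny numeric sketch of that calculation, with made-up values:
y    <- c(100, 102, 98, 101, 99, 103, 100)   # one week of actual values (made up)
yhat <- c(101, 101, 99, 100, 100, 102, 101)  # forecasts for the same days (made up)
# sum the daily |y - yhat| / y values and divide by the number of days (x100 for percent)
mean(abs(y - yhat) / y) * 100
# yardstick::mape() computes the same quantity, also in percent
yardstick::mape(tibble::tibble(y = y, yhat = yhat), truth = y, estimate = yhat)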
A memo from when I was playing around on EC2. It is Prophet-related, so I'll leave the note here.
# https://datetime360.com/ja/utc-jst-time/
# JST:1:00 = UTC:16:00
# 0 15 * * * /usr/bin/Rscript "/home/ruser/pred_weight.R" >> /home/ruser/utils/pred_exec_err.log 2>&1
print('------Logging Start------')
Sys.time()
as.POSIXlt(Sys.time(), format = "%Y-%m-%d %H:%M:%S", tz = "Asia/Tokyo")
library(googlesheets4)
library(lubridate)
library(dplyr)
library(yardstick)
library(rstan)
library(prophet)
library(chatr) # remotes::install_github("SugiAki1989/chatr")
# Setting credentials for the Chatwork API & Google Spreadsheet API
# You should create a SERVICE ACCOUNT in GCP,
# share the Google Spreadsheet with the SERVICE ACCOUNT,
# and enable the Spreadsheet API
chatr_setup(config_file_path = "<path_to_config.yml>")
gs4_auth(
  scopes = 'https://www.googleapis.com/auth/spreadsheets',
  path = "<path_to_config.json>"
)
chatr::chatr(code = FALSE, "Start forecasting")
print('------Logging Start------')
# Read weight data
sheet_id <- "**************************"
google_spread_sheet <- googlesheets4::read_sheet(
  ss = sheet_id,
  range = "sheet1") %>%
  dplyr::select(-unit)
# to check
# google_spread_sheet %>% tail()
# Modify date format
# pre (character): July 5, 2020 at 10:44PM
# post (date): 2020-07-05
df <- google_spread_sheet %>%
  dplyr::mutate(
    date0 = stringr::str_replace(date, pattern = '[:space:]at.*', ''),
    date1 = stringr::str_replace(date0, pattern = ',', '')
  ) %>%
  tidyr::separate(date1, c("mm", "dd", "yyyy")) %>%
  dplyr::mutate(date_mod = lubridate::ymd(paste0(yyyy, "/", mm, "/", dd))) %>%
  dplyr::select(date_mod, weight)
# Make is_holiday flag
# wday() returns 1:Sun, 6:Fri, 7:Sat
df2model <- df %>%
  dplyr::mutate(
    date_mod_wday = lubridate::wday(date_mod),
    is_holiday = dplyr::if_else(date_mod_wday %in% c(1, 6, 7), 1, 0)
  ) %>%
  dplyr::select(ds = date_mod, y = weight, is_holiday)
# Fitting Model
m <- prophet::prophet(weekly.seasonality = TRUE)
m <- prophet::add_regressor(m, "is_holiday")
m <- prophet::fit.prophet(m, df2model)
# Evaluate Fitted model
# Ex:
# - initial: 89 means 2020/10/01 (forecast start date)
# - horizon: 7 means a 1-week forecast interval
# - period : 7 means a 1-week rolling origin slide
df_cv <- prophet::cross_validation(model = m,
                                   initial = nrow(df2model)/2,
                                   period = 7,
                                   horizon = 7,
                                   units = "days")
cvdt_mst <- df_cv %>%
  dplyr::group_by(cutoff) %>%
  dplyr::summarise(cv_start_dt = min(ds),
                   cv_end_dt = max(ds)) %>%
  dplyr::mutate(cv_start_dt = lubridate::ymd(cv_start_dt),
                cv_end_dt = lubridate::ymd(cv_end_dt))
res_cv <- df_cv %>%
  dplyr::group_by(cutoff) %>%
  yardstick::mape(truth = y, estimate = yhat) %>%
  dplyr::left_join(cvdt_mst, by = "cutoff") %>%
  dplyr::mutate(cutoff = lubridate::ymd(cutoff)) %>%
  dplyr::select(cutoff, cv_start_dt, cv_end_dt, .metric, .estimate)
res_cv_all <- res_cv %>%
  dplyr::summarise(mape_avg = mean(.estimate))
# Predict value from model
future <- prophet::make_future_dataframe(m, periods = 7) %>%
  dplyr::mutate(is_holiday = dplyr::if_else(lubridate::wday(ds) %in% c(1, 6, 7), 1, 0))
forecast <- predict(m, future)
# Make post data by chatr
post <- forecast %>%
  dplyr::select("ds", "yhat", "yhat_lower", "yhat_upper") %>%
  dplyr::mutate(is_fat = dplyr::if_else(yhat_upper >= 70.0, "Fat!", "Safe!")) %>%
  tail(7)
# library(glue)
# from <- as.character(Sys.Date())
# to <- as.character(Sys.Date() + 7)
# info <- glue::glue("The forecasted weight for {from} ~ {to} is as follows.")
# Do post to chatwork
chatr::chatr(code = TRUE,
             post,
             res_cv,
             res_cv_all)
print('------Logging End------')
Sys.time()
as.POSIXlt(Sys.time(), format = "%Y-%m-%d %H:%M:%S", tz = "Asia/Tokyo")