UPDATE: 2022-12-16 23:28:17
tidymodels
パッケージの使い方をいくつかのノートに分けてまとめている。tidymodels
パッケージは、統計モデルや機械学習モデルを構築するために必要なパッケージをコレクションしているパッケージで、非常に色んなパッケージがある。ここでは、今回はrsample
というパッケージの使い方をまとめていく。モデルの数理的な側面や機械学習の用語などは、このノートでは扱わない。
下記の公式ドキュメントやtidymodels
パッケージに関する書籍を参考にしている。
rsample
パッケージの目的rsample
パッケージは、分析に必要なに様々はタイプのリサンプリングデータを作成する関数を提供しているパッケージ。例えば、ブートストラップ法でリサンプリングされたデータセットからサンプリング分布を推定する際に利用したり、機械学習のモデルパフォーマンスの評価のためのデータセットを作成できる。
rsample
パッケージの実行例まずは必要なパッケージとデータを読み込む。利用するデータはmodeldata
パッケージに含まれるcredit_data
。ここで読み込んでいるデータは、のちのちの予測の際に使用するテストデータとモデル学習用に分割されたもののうち、モデル学習用の方を読み込んでいる。
library(tidymodels)
library(tidyverse)
<- read_csv("https://raw.githubusercontent.com/SugiAki1989/statistical_note/main/note_TidyModels00/df_past.csv")
df_past dim(df_past)
## [1] 4008 14
データを分割する基本的な関数はinitial_split
関数。この関数で初期分割を行い、訓練データdf_train
と評価データdf_test
に分割する。
set.seed(1989)
<- df_past %>%
df_initial initial_split(prop = 0.8, strata = "Status")
df_initial
## <Training/Testing/Total>
## <3206/802/4008>
訓練データdf_train
と評価データdf_test
に分割する際は、training
関数で訓練データを抽出でき、testing
関数で評価データを抽出できる。
<- df_initial %>% training()
df_train <- df_initial %>% testing()
df_test
list(dim(df_train), dim(df_test))
## [[1]]
## [1] 3206 14
##
## [[2]]
## [1] 802 14
K分割クロスバリデーションを行うことが一般的なので、さきほどの訓練データdf_train
をK分割する。K分割する関数はvfold_cv
関数。
ここでは5分割しており、1つ目は[2564/642]
となっているがtidymodels
パッケージの世界では、左の[2564/
がanalysis
と呼ばれ、右側の/642]
がassessment
と呼ばれる。
set.seed(1989)
<- df_train %>% vfold_cv(v = 5, strata = "Status")
df_train_kfoldspilit df_train_kfoldspilit
## # 5-fold cross-validation using stratification
## # A tibble: 5 × 2
## splits id
## <list> <chr>
## 1 <split [2564/642]> Fold1
## 2 <split [2564/642]> Fold2
## 3 <split [2565/641]> Fold3
## 4 <split [2565/641]> Fold4
## 5 <split [2566/640]> Fold5
1つ目のフォールドデータの分析用データ(Analysis
)を取り出すには、pluck
関数とanalysis
関数を利用することで抽出できる。
%>%
df_train_kfoldspilit pluck("splits", 1) %>%
analysis()
## # A tibble: 2,564 × 14
## Status Senior…¹ Home Time Age Marital Records Job Expen…² Income Assets
## <chr> <dbl> <chr> <dbl> <dbl> <chr> <chr> <chr> <dbl> <dbl> <dbl>
## 1 bad 0 pare… 48 41 married no part… 90 80 0
## 2 bad 0 igno… 48 36 married no part… 45 130 750
## 3 bad 2 rent 60 25 single no fixed 46 107 0
## 4 bad 3 owner 24 23 married no fixed 75 85 5000
## 5 bad 0 owner 36 23 single no part… 45 122 2500
## 6 bad 1 rent 54 36 married no fixed 70 99 0
## 7 bad 5 rent 48 31 single no fixed 44 90 0
## 8 bad 2 owner 60 43 married no part… 75 71 3000
## 9 bad 2 rent 36 27 separa… no fixed 48 128 0
## 10 bad 4 pare… 42 27 single no fixed 35 70 0
## # … with 2,554 more rows, 3 more variables: Debt <dbl>, Amount <dbl>,
## # Price <dbl>, and abbreviated variable names ¹Seniority, ²Expenses
同じく、1つ目のフォールドデータの評価用データ(Assess
)を取り出すには、pluck
関数とassessment
関数を利用することで抽出できる。
%>%
df_train_kfoldspilit pluck("splits", 1) %>%
assessment()
## # A tibble: 642 × 14
## Status Senior…¹ Home Time Age Marital Records Job Expen…² Income Assets
## <chr> <dbl> <chr> <dbl> <dbl> <chr> <chr> <chr> <dbl> <dbl> <dbl>
## 1 bad 0 other 18 21 single yes part… 35 50 0
## 2 bad 1 rent 48 29 married yes free… 85 100 0
## 3 bad 0 rent 60 25 single no part… 40 50 0
## 4 bad 0 rent 36 29 married no part… 78 180 0
## 5 bad 2 other 48 23 single no fixed 35 140 0
## 6 bad 0 owner 36 39 single no free… 35 NA 4000
## 7 bad 3 owner 36 40 married yes free… 35 200 10000
## 8 bad 1 pare… 48 28 single no fixed 35 83 0
## 9 bad 2 rent 60 28 married no fixed 74 138 0
## 10 bad 0 pare… 60 21 single no part… 35 86 0
## # … with 632 more rows, 3 more variables: Debt <dbl>, Amount <dbl>,
## # Price <dbl>, and abbreviated variable names ¹Seniority, ²Expenses
毎回、pluck
関数とanalysis
関数、assessment
関数を使ってデータを取り出すのは面倒なので、これをデータフレームに取り出して格納しておく。
set.seed(1989)
<- df_train %>%
df_train_stratified_kfoldspilits vfold_cv(v = 5, strata = "Status") %>%
mutate(
analysis = map(.x = splits, .f = function(x){analysis(x)}),
assessment = map(.x = splits, .f = function(x){assessment(x)})
) df_train_stratified_kfoldspilits
## # 5-fold cross-validation using stratification
## # A tibble: 5 × 4
## splits id analysis assessment
## <list> <chr> <list> <list>
## 1 <split [2564/642]> Fold1 <tibble [2,564 × 14]> <tibble [642 × 14]>
## 2 <split [2564/642]> Fold2 <tibble [2,564 × 14]> <tibble [642 × 14]>
## 3 <split [2565/641]> Fold3 <tibble [2,565 × 14]> <tibble [641 × 14]>
## 4 <split [2565/641]> Fold4 <tibble [2,565 × 14]> <tibble [641 × 14]>
## 5 <split [2566/640]> Fold5 <tibble [2,566 × 14]> <tibble [640 × 14]>
ホールドアウト法でデータを分割する際は、validation_split(prop = 0.8)
関数を利用すればデータを分割できる。
rsample
パッケージは時系列のデータ分割にも対応している。時系列のダミーデータを作成する。
set.seed(1989)
<- seq(as.Date("2022-01-01"), as.Date("2022-12-31"), by = "day")
dt <- rnorm(length(dt), 0, 1)
x <- tibble(id = 1:length(dt), dt, x)
df_timeseries df_timeseries
## # A tibble: 365 × 3
## id dt x
## <int> <date> <dbl>
## 1 1 2022-01-01 1.10
## 2 2 2022-01-02 1.12
## 3 3 2022-01-03 -1.82
## 4 4 2022-01-04 -0.194
## 5 5 2022-01-05 -0.613
## 6 6 2022-01-06 -0.346
## 7 7 2022-01-07 0.278
## 8 8 2022-01-08 0.535
## 9 9 2022-01-09 0.143
## 10 10 2022-01-10 -0.694
## # … with 355 more rows
使う関数はrolling_origin
関数を使用する。initial = 290
としているので、1から290番目のレコードまでが1つ目のフォールドの分析用データ(Analysis
)となっている。assess = 30
としているので、291から320番目のレコードまでが1つ目のフォールドの評価用データ(Assess
)となっている。また、skip = 9
としているので、各フォールドの間隔は10日ごとになっている。なぜ9なのかは疑問に思うところだが、ドキュメントによると0始まりなので、その影響。
When skip = 0, the resampling data sets will increment by one position.
%>%
df_timeseries rolling_origin(
initial = 290, # anaysisデータのレコード数
assess = 30, # assessmentデータのレコード数
skip = 9, # 各フォールドのanaysisデータの間隔 10 − 1 = 9
cumulative = FALSE # anaysisデータを累積するかどうか
%>%
) mutate(
analysis = map(.x = splits, .f = function(x){analysis(x)}),
assessment = map(.x = splits, .f = function(x){assessment(x)}),
analysis_min = map_int(.x = analysis, .f = function(x){x %>% summarise(min(id)) %>% pull()}),
analysis_max = map_int(.x = analysis, .f = function(x){x %>% summarise(max(id)) %>% pull()}),
assessment_min = map_int(.x = assessment, .f = function(x){x %>% summarise(min(id)) %>% pull()}),
assessment_max = map_int(.x = assessment, .f = function(x){x %>% summarise(max(id)) %>% pull()})
)
## # Rolling origin forecast resampling
## # A tibble: 5 × 8
## splits id analysis assessment analysis…¹ analy…² asses…³ asses…⁴
## <list> <chr> <list> <list> <int> <int> <int> <int>
## 1 <split [290/30]> Slice1 <tibble> <tibble> 1 290 291 320
## 2 <split [290/30]> Slice2 <tibble> <tibble> 11 300 301 330
## 3 <split [290/30]> Slice3 <tibble> <tibble> 21 310 311 340
## 4 <split [290/30]> Slice4 <tibble> <tibble> 31 320 321 350
## 5 <split [290/30]> Slice5 <tibble> <tibble> 41 330 331 360
## # … with abbreviated variable names ¹analysis_min, ²analysis_max,
## # ³assessment_min, ⁴assessment_max
cumulative = TRUE
とすれば開始位置から累積されていくので、各フォールドのanalysis_min
は1
となる。
%>%
df_timeseries rolling_origin(
initial = 290, # anaysisデータのレコード数
assess = 30, # assessmentデータのレコード数
skip = 9, # 各フォールドのanaysisデータの間隔 10 − 1 = 9
cumulative = TRUE # anaysisデータを累積するかどうか
%>%
) mutate(
analysis = map(.x = splits, .f = function(x){analysis(x)}),
assessment = map(.x = splits, .f = function(x){assessment(x)}),
analysis_min = map_int(.x = analysis, .f = function(x){x %>% summarise(min(id)) %>% pull()}),
analysis_max = map_int(.x = analysis, .f = function(x){x %>% summarise(max(id)) %>% pull()}),
assessment_min = map_int(.x = assessment, .f = function(x){x %>% summarise(min(id)) %>% pull()}),
assessment_max = map_int(.x = assessment, .f = function(x){x %>% summarise(max(id)) %>% pull()})
)
## # Rolling origin forecast resampling
## # A tibble: 5 × 8
## splits id analysis assessment analysis…¹ analy…² asses…³ asses…⁴
## <list> <chr> <list> <list> <int> <int> <int> <int>
## 1 <split [290/30]> Slice1 <tibble> <tibble> 1 290 291 320
## 2 <split [300/30]> Slice2 <tibble> <tibble> 1 300 301 330
## 3 <split [310/30]> Slice3 <tibble> <tibble> 1 310 311 340
## 4 <split [320/30]> Slice4 <tibble> <tibble> 1 320 321 350
## 5 <split [330/30]> Slice5 <tibble> <tibble> 1 330 331 360
## # … with abbreviated variable names ¹analysis_min, ²analysis_max,
## # ³assessment_min, ⁴assessment_max