UPDATE: 2022-12-16 23:28:17

はじめに

tidymodelsパッケージの使い方をいくつかのノートに分けてまとめている。tidymodelsパッケージは、統計モデルや機械学習モデルを構築するために必要なパッケージをコレクションしているパッケージで、非常に色んなパッケージがある。ここでは、今回はrsampleというパッケージの使い方をまとめていく。モデルの数理的な側面や機械学習の用語などは、このノートでは扱わない。

下記の公式ドキュメントやtidymodelsパッケージに関する書籍を参考にしている。

rsampleパッケージの目的

rsampleパッケージは、分析に必要なに様々はタイプのリサンプリングデータを作成する関数を提供しているパッケージ。例えば、ブートストラップ法でリサンプリングされたデータセットからサンプリング分布を推定する際に利用したり、機械学習のモデルパフォーマンスの評価のためのデータセットを作成できる。

rsampleパッケージの実行例

まずは必要なパッケージとデータを読み込む。利用するデータはmodeldataパッケージに含まれるcredit_data。ここで読み込んでいるデータは、のちのちの予測の際に使用するテストデータとモデル学習用に分割されたもののうち、モデル学習用の方を読み込んでいる。

library(tidymodels)
library(tidyverse)

df_past <- read_csv("https://raw.githubusercontent.com/SugiAki1989/statistical_note/main/note_TidyModels00/df_past.csv")
dim(df_past)
## [1] 4008   14

データを分割する基本的な関数はinitial_split関数。この関数で初期分割を行い、訓練データdf_trainと評価データdf_testに分割する。

set.seed(1989)
df_initial <- df_past %>% 
  initial_split(prop = 0.8, strata = "Status")

df_initial
## <Training/Testing/Total>
## <3206/802/4008>

訓練データdf_trainと評価データdf_testに分割する際は、training関数で訓練データを抽出でき、testing関数で評価データを抽出できる。

df_train <- df_initial %>% training()
df_test <- df_initial %>% testing()

list(dim(df_train), dim(df_test))
## [[1]]
## [1] 3206   14
## 
## [[2]]
## [1] 802  14

K分割クロスバリデーションを行うことが一般的なので、さきほどの訓練データdf_trainをK分割する。K分割する関数はvfold_cv関数。

ここでは5分割しており、1つ目は[2564/642]となっているがtidymodelsパッケージの世界では、左の[2564/analysisと呼ばれ、右側の/642]assessmentと呼ばれる。

set.seed(1989)
df_train_kfoldspilit <- df_train %>% vfold_cv(v = 5, strata = "Status")
df_train_kfoldspilit
## #  5-fold cross-validation using stratification 
## # A tibble: 5 × 2
##   splits             id   
##   <list>             <chr>
## 1 <split [2564/642]> Fold1
## 2 <split [2564/642]> Fold2
## 3 <split [2565/641]> Fold3
## 4 <split [2565/641]> Fold4
## 5 <split [2566/640]> Fold5

1つ目のフォールドデータの分析用データ(Analysis)を取り出すには、pluck関数とanalysis関数を利用することで抽出できる。

df_train_kfoldspilit %>% 
  pluck("splits", 1) %>% 
  analysis()
## # A tibble: 2,564 × 14
##    Status Senior…¹ Home   Time   Age Marital Records Job   Expen…² Income Assets
##    <chr>     <dbl> <chr> <dbl> <dbl> <chr>   <chr>   <chr>   <dbl>  <dbl>  <dbl>
##  1 bad           0 pare…    48    41 married no      part…      90     80      0
##  2 bad           0 igno…    48    36 married no      part…      45    130    750
##  3 bad           2 rent     60    25 single  no      fixed      46    107      0
##  4 bad           3 owner    24    23 married no      fixed      75     85   5000
##  5 bad           0 owner    36    23 single  no      part…      45    122   2500
##  6 bad           1 rent     54    36 married no      fixed      70     99      0
##  7 bad           5 rent     48    31 single  no      fixed      44     90      0
##  8 bad           2 owner    60    43 married no      part…      75     71   3000
##  9 bad           2 rent     36    27 separa… no      fixed      48    128      0
## 10 bad           4 pare…    42    27 single  no      fixed      35     70      0
## # … with 2,554 more rows, 3 more variables: Debt <dbl>, Amount <dbl>,
## #   Price <dbl>, and abbreviated variable names ¹​Seniority, ²​Expenses

同じく、1つ目のフォールドデータの評価用データ(Assess)を取り出すには、pluck関数とassessment関数を利用することで抽出できる。

df_train_kfoldspilit %>% 
  pluck("splits", 1) %>% 
  assessment()
## # A tibble: 642 × 14
##    Status Senior…¹ Home   Time   Age Marital Records Job   Expen…² Income Assets
##    <chr>     <dbl> <chr> <dbl> <dbl> <chr>   <chr>   <chr>   <dbl>  <dbl>  <dbl>
##  1 bad           0 other    18    21 single  yes     part…      35     50      0
##  2 bad           1 rent     48    29 married yes     free…      85    100      0
##  3 bad           0 rent     60    25 single  no      part…      40     50      0
##  4 bad           0 rent     36    29 married no      part…      78    180      0
##  5 bad           2 other    48    23 single  no      fixed      35    140      0
##  6 bad           0 owner    36    39 single  no      free…      35     NA   4000
##  7 bad           3 owner    36    40 married yes     free…      35    200  10000
##  8 bad           1 pare…    48    28 single  no      fixed      35     83      0
##  9 bad           2 rent     60    28 married no      fixed      74    138      0
## 10 bad           0 pare…    60    21 single  no      part…      35     86      0
## # … with 632 more rows, 3 more variables: Debt <dbl>, Amount <dbl>,
## #   Price <dbl>, and abbreviated variable names ¹​Seniority, ²​Expenses

毎回、pluck関数とanalysis関数、assessment関数を使ってデータを取り出すのは面倒なので、これをデータフレームに取り出して格納しておく。

set.seed(1989)
df_train_stratified_kfoldspilits <- df_train %>% 
  vfold_cv(v = 5, strata = "Status") %>% 
  mutate(
    analysis = map(.x = splits, .f = function(x){analysis(x)}),
    assessment = map(.x = splits, .f = function(x){assessment(x)})
    )
df_train_stratified_kfoldspilits
## #  5-fold cross-validation using stratification 
## # A tibble: 5 × 4
##   splits             id    analysis              assessment         
##   <list>             <chr> <list>                <list>             
## 1 <split [2564/642]> Fold1 <tibble [2,564 × 14]> <tibble [642 × 14]>
## 2 <split [2564/642]> Fold2 <tibble [2,564 × 14]> <tibble [642 × 14]>
## 3 <split [2565/641]> Fold3 <tibble [2,565 × 14]> <tibble [641 × 14]>
## 4 <split [2565/641]> Fold4 <tibble [2,565 × 14]> <tibble [641 × 14]>
## 5 <split [2566/640]> Fold5 <tibble [2,566 × 14]> <tibble [640 × 14]>

ホールドアウト法でデータを分割する際は、validation_split(prop = 0.8)関数を利用すればデータを分割できる。

rsampleパッケージは時系列のデータ分割にも対応している。時系列のダミーデータを作成する。

set.seed(1989)
dt <- seq(as.Date("2022-01-01"), as.Date("2022-12-31"), by = "day")
x <- rnorm(length(dt), 0, 1)
df_timeseries <- tibble(id = 1:length(dt), dt, x)
df_timeseries
## # A tibble: 365 × 3
##       id dt              x
##    <int> <date>      <dbl>
##  1     1 2022-01-01  1.10 
##  2     2 2022-01-02  1.12 
##  3     3 2022-01-03 -1.82 
##  4     4 2022-01-04 -0.194
##  5     5 2022-01-05 -0.613
##  6     6 2022-01-06 -0.346
##  7     7 2022-01-07  0.278
##  8     8 2022-01-08  0.535
##  9     9 2022-01-09  0.143
## 10    10 2022-01-10 -0.694
## # … with 355 more rows

使う関数はrolling_origin関数を使用する。initial = 290としているので、1から290番目のレコードまでが1つ目のフォールドの分析用データ(Analysis)となっている。assess = 30としているので、291から320番目のレコードまでが1つ目のフォールドの評価用データ(Assess)となっている。また、skip = 9としているので、各フォールドの間隔は10日ごとになっている。なぜ9なのかは疑問に思うところだが、ドキュメントによると0始まりなので、その影響。

When skip = 0, the resampling data sets will increment by one position.

df_timeseries %>% 
  rolling_origin(
    initial = 290,      # anaysisデータのレコード数
    assess = 30,       # assessmentデータのレコード数
    skip = 9,          # 各フォールドのanaysisデータの間隔 10 − 1 = 9
    cumulative = FALSE  # anaysisデータを累積するかどうか
  ) %>% 
  mutate(
    analysis = map(.x = splits, .f = function(x){analysis(x)}),
    assessment = map(.x = splits, .f = function(x){assessment(x)}),
    analysis_min    = map_int(.x = analysis,    .f = function(x){x %>% summarise(min(id)) %>% pull()}),
    analysis_max    = map_int(.x = analysis,    .f = function(x){x %>% summarise(max(id)) %>% pull()}),
    assessment_min  = map_int(.x = assessment,  .f = function(x){x %>% summarise(min(id)) %>% pull()}),
    assessment_max  = map_int(.x = assessment,  .f = function(x){x %>% summarise(max(id)) %>% pull()})
  ) 
## # Rolling origin forecast resampling 
## # A tibble: 5 × 8
##   splits           id     analysis assessment analysis…¹ analy…² asses…³ asses…⁴
##   <list>           <chr>  <list>   <list>          <int>   <int>   <int>   <int>
## 1 <split [290/30]> Slice1 <tibble> <tibble>            1     290     291     320
## 2 <split [290/30]> Slice2 <tibble> <tibble>           11     300     301     330
## 3 <split [290/30]> Slice3 <tibble> <tibble>           21     310     311     340
## 4 <split [290/30]> Slice4 <tibble> <tibble>           31     320     321     350
## 5 <split [290/30]> Slice5 <tibble> <tibble>           41     330     331     360
## # … with abbreviated variable names ¹​analysis_min, ²​analysis_max,
## #   ³​assessment_min, ⁴​assessment_max

cumulative = TRUEとすれば開始位置から累積されていくので、各フォールドのanalysis_min1となる。

df_timeseries %>% 
  rolling_origin(
    initial = 290,      # anaysisデータのレコード数
    assess = 30,       # assessmentデータのレコード数
    skip = 9,           # 各フォールドのanaysisデータの間隔 10 − 1 = 9
    cumulative = TRUE   # anaysisデータを累積するかどうか
  ) %>% 
  mutate(
    analysis = map(.x = splits, .f = function(x){analysis(x)}),
    assessment = map(.x = splits, .f = function(x){assessment(x)}),
    analysis_min    = map_int(.x = analysis,    .f = function(x){x %>% summarise(min(id)) %>% pull()}),
    analysis_max    = map_int(.x = analysis,    .f = function(x){x %>% summarise(max(id)) %>% pull()}),
    assessment_min  = map_int(.x = assessment,  .f = function(x){x %>% summarise(min(id)) %>% pull()}),
    assessment_max  = map_int(.x = assessment,  .f = function(x){x %>% summarise(max(id)) %>% pull()})
  )
## # Rolling origin forecast resampling 
## # A tibble: 5 × 8
##   splits           id     analysis assessment analysis…¹ analy…² asses…³ asses…⁴
##   <list>           <chr>  <list>   <list>          <int>   <int>   <int>   <int>
## 1 <split [290/30]> Slice1 <tibble> <tibble>            1     290     291     320
## 2 <split [300/30]> Slice2 <tibble> <tibble>            1     300     301     330
## 3 <split [310/30]> Slice3 <tibble> <tibble>            1     310     311     340
## 4 <split [320/30]> Slice4 <tibble> <tibble>            1     320     321     350
## 5 <split [330/30]> Slice5 <tibble> <tibble>            1     330     331     360
## # … with abbreviated variable names ¹​analysis_min, ²​analysis_max,
## #   ³​assessment_min, ⁴​assessment_max