UPDATE: 2022-12-21 16:09:48

はじめに

tidymodelsパッケージの使い方をいくつかのノートに分けてまとめている。tidymodelsパッケージは、統計モデルや機械学習モデルを構築するために必要なパッケージをコレクションしているパッケージで、非常に色んなパッケージがある。ここでは、今回はyardstickというパッケージの使い方をまとめていく。モデルの数理的な側面や機械学習の用語などは、このノートでは扱わない。

下記の公式ドキュメントやtidymodelsパッケージに関する書籍を参考にしている。

yardstickパッケージの目的

yardstickパッケージは、tidymodelsパッケージを利用したモデリングにおいて(実際はtidymodelsに限らない)、モデルの予測性能を評価するための関数がまとめられているパッケージ。

公式ドキュメントは下記の通り。

また、下記がわかりやすかったので参考にさせていただいた。

yardstickパッケージの実行例

yardstickパッケージの基本的な利用方法を確認していく。ざっくりとどのような関数があるのかを確認すると、基本的なものからマニアックなものまで幅広く用意されていそうである。

library(tidymodels)
library(tidyverse)

# 評価のための関数だけではない
tibble(function_name = ls("package:yardstick")) %>% 
  filter(!str_detect(function_name, "vec")) %>% 
  print(n = 100)
## # A tibble: 61 × 1
##    function_name              
##    <chr>                      
##  1 accuracy                   
##  2 average_precision          
##  3 bal_accuracy               
##  4 ccc                        
##  5 classification_cost        
##  6 conf_mat                   
##  7 detection_prevalence       
##  8 dots_to_estimate           
##  9 f_meas                     
## 10 finalize_estimator         
## 11 finalize_estimator_internal
## 12 gain_capture               
## 13 gain_curve                 
## 14 get_weights                
## 15 hpc_cv                     
## 16 huber_loss                 
## 17 huber_loss_pseudo          
## 18 iic                        
## 19 j_index                    
## 20 kap                        
## 21 lift_curve                 
## 22 mae                        
## 23 mape                       
## 24 mase                       
## 25 mcc                        
## 26 metric_set                 
## 27 metric_summarizer          
## 28 metric_tweak               
## 29 metrics                    
## 30 mn_log_loss                
## 31 mpe                        
## 32 msd                        
## 33 new_class_metric           
## 34 new_numeric_metric         
## 35 new_prob_metric            
## 36 npv                        
## 37 pathology                  
## 38 poisson_log_loss           
## 39 ppv                        
## 40 pr_auc                     
## 41 pr_curve                   
## 42 precision                  
## 43 recall                     
## 44 rmse                       
## 45 roc_auc                    
## 46 roc_aunp                   
## 47 roc_aunu                   
## 48 roc_curve                  
## 49 rpd                        
## 50 rpiq                       
## 51 rsq                        
## 52 rsq_trad                   
## 53 sens                       
## 54 sensitivity                
## 55 smape                      
## 56 solubility_test            
## 57 spec                       
## 58 specificity                
## 59 tidy                       
## 60 two_class_example          
## 61 validate_estimator

まずは、2クラス分類の予測モデルを評価するための関数をまとめておく。パッケージに付属しているtwo_class_exampleデータや例を参考にする。分類といえば混同行列なので、conf_mat関数で混同行列を作成する。表側に予測値、表頭が観測値という形式で出力される。世間的によく見る混同行列とは反対かもしれない。

two_class_example %>% 
  conf_mat(
    truth = truth,
    estimate = predicted,
    dnn = c("Pred", "Truth")
  ) 
##         Truth
## Pred     Class1 Class2
##   Class1    227     50
##   Class2     31    192

特定の指標で計算したい場合は、指標にあわせた関数が用意されているので、それを利用する。accuracy関数は因子型である必要がある。

list(
  two_class_example %>% accuracy(truth = truth, estimate = predicted),
  sum(227, 192)/sum(227, 31, 50, 192)
)
## [[1]]
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.838
## 
## [[2]]
## [1] 0.838

precisionは、陽性と予測したもののうち、実際に陽性であるものの割合を表す指標で、precision関数で計算できる。

my_precision <- 227/sum(227, 50)
list(
  two_class_example %>% precision(truth = truth, estimate = predicted),
  my_precision
)
## [[1]]
## # A tibble: 1 × 3
##   .metric   .estimator .estimate
##   <chr>     <chr>          <dbl>
## 1 precision binary         0.819
## 
## [[2]]
## [1] 0.8194946

recallは、True Positive Rate(TPR)とも呼ばれるやつで、実際に陽性であるもののうち、正しく陽性と予測できたものの割合を表す指標で、recall関数で計算できる。ROC曲線の縦軸で使用される。

my_recall <- 227/sum(227, 31)

list(
  two_class_example %>% recall(truth = truth, estimate = predicted),
  my_recall
)
## [[1]]
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 recall  binary         0.880
## 
## [[2]]
## [1] 0.879845

f_measは、recallprecisionの調和平均で、f_meas関数で計算できる。

list(
  two_class_example %>% f_meas(truth = truth, estimate = predicted),
  (2 * my_precision * my_recall)/(my_precision + my_recall)
)
## [[1]]
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 f_meas  binary         0.849
## 
## [[2]]
## [1] 0.8485981

specificityは、True Negative Rate(TNR)、特異度とも呼ばれるやつで、実際に陰性であるもののうち、正しく陰性と予測できたものの割合を表す指標で、recall関数で計算できる。

my_specificity <- 192/sum(50, 192)
list(
  two_class_example %>% specificity(truth = truth, estimate = predicted),
  my_specificity
)
## [[1]]
## # A tibble: 1 × 3
##   .metric     .estimator .estimate
##   <chr>       <chr>          <dbl>
## 1 specificity binary         0.793
## 
## [[2]]
## [1] 0.7933884

LogLossは、 accuracyに確率を組み込んだような指標で、予測を正解確率、不正解確率を含めて評価している。小さいほどモデルの性能がよい。 mn_log_loss関数で計算できるが、estimateには、因子型のレベルが低いもの方の確率を渡す。

# class1を1、class2を0に変換
y <- 2 - as.numeric(two_class_example$truth)
# class1のときは、two_class_example$Class1を使い、class2のときは1-two_class_example$Class1に変換
p <- ifelse(y == 1, two_class_example$Class1, 1- two_class_example$Class1)
my_logloss <- -1*mean(log(p))

list(
  use_prob_col = levels(two_class_example$truth)[[1]], 
  two_class_example %>% mn_log_loss(truth = truth, estimate = Class1),
  my_logloss
)
## $use_prob_col
## [1] "Class1"
## 
## [[2]]
## # A tibble: 1 × 3
##   .metric     .estimator .estimate
##   <chr>       <chr>          <dbl>
## 1 mn_log_loss binary         0.328
## 
## [[3]]
## [1] 0.3283096

AUC(Area Under Curve)は、曲線の下側の面積の大きさで分類予測を評価する指標で、大きいほどモデルの性能がよい。 roc_auc関数で計算できるが、estimateには、因子型のレベルが低いもの方の確率を渡す。

list(
  two_class_example %>% roc_auc(truth = truth, estimate = Class1)
)
## [[1]]
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.939

ここからは回帰問題で使用する評価指標をまとめておく。組み込みのOrangeデータをサンプルデータに変換する。

set.seed(1989)
regress_sample <- tibble(truth = Orange$age, pred = rnorm(nrow(Orange), 0, 50) + Orange$age)
regress_sample
## # A tibble: 35 × 2
##    truth  pred
##    <dbl> <dbl>
##  1   118  173.
##  2   484  540.
##  3   664  573.
##  4  1004  994.
##  5  1231 1200.
##  6  1372 1355.
##  7  1582 1596.
##  8   118  145.
##  9   484  491.
## 10   664  629.
## # … with 25 more rows

rmseは平均二乗平方根誤差と呼ばれるもので、rmse関数で計算できる。

list(
  regress_sample %>% rmse(truth = truth, estimate = pred),
  sqrt(mean((regress_sample$truth - regress_sample$pred)^2))
)
## [[1]]
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard        58.3
## 
## [[2]]
## [1] 58.34786

maeは平均絶対誤差と呼ばれるもので、mae関数で計算できる。

list(
  regress_sample %>% mae(truth = truth, estimate = pred),
  mean(abs(regress_sample$truth - regress_sample$pred))
)
## [[1]]
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 mae     standard        48.9
## 
## [[2]]
## [1] 48.93553

mapeは平均絶対パーセント誤差と呼ばれているもので、mape関数で計算できる。この例だと平均して約13%前後の誤差があることになる。

list(
  regress_sample %>% mape(truth = truth, estimate = pred),
  mean(abs(regress_sample$truth - regress_sample$pred)/regress_sample$truth) * 100
)
## [[1]]
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 mape    standard        13.5
## 
## [[2]]
## [1] 13.51992

便利な関数metric_setがある。これは評価指標をまとめて出力できる関数。回帰でも、

regression_metric_set <- metric_set(rmse, mae, mape)
regress_sample %>% 
  regression_metric_set(truth = truth, estimate = pred)
## # A tibble: 3 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard        58.3
## 2 mae     standard        48.9
## 3 mape    standard        13.5

分類でも利用可能。

classification_metric_set <- metric_set(accuracy, precision, recall, f_meas)
two_class_example %>% 
  classification_metric_set(truth = truth, estimate = predicted)
## # A tibble: 4 × 3
##   .metric   .estimator .estimate
##   <chr>     <chr>          <dbl>
## 1 accuracy  binary         0.838
## 2 precision binary         0.819
## 3 recall    binary         0.880
## 4 f_meas    binary         0.849