UPDATE: 2022-12-25 16:39:37

はじめに

tidymodelsパッケージの使い方をいくつかのノートに分けてまとめている。tidymodelsパッケージは、統計モデルや機械学習モデルを構築するために必要なパッケージをコレクションしているパッケージで、非常に色んなパッケージがある。ここでは、今回はDALEXパッケージについてまとめていく。モデルの数理的な側面や機械学習の用語などは、このノートでは扱わない。

下記の公式ドキュメントやtidymodelsパッケージに関する書籍を参考にしている。

説明可能性には、大域的(データレベル)な説明と局所的(インスタンスレベル)な説明が出てくるが、今回は双方を扱う。

DALEXパッケージの目的

DALEXパッケージは、DrWhyという取り組みの中に存在する説明可能なAI(XAI)用のツールのコレクションの1つパッケージ。この取り組みでは、機械学習のモデルが下記の観点から検討されることを望んでいる。

  • Effective: パフォーマンスの高いモデルなのか。
  • Transparent: モデルの予測に関する解釈可能性と説明可能性があるか。
  • Fair: 人種差別を促すモデルではないか。
  • Secure: ハッキングされない安全性があるか。
  • Confidential: 機密情報としてモデルが管理されているか。
  • Reproducible: 再現可能かどうか

モデルの中で実際に何が起こっているかを理解することや、モデルが出力する予測値がどのように生成されたのかを理解することをこの取組では目指しており、そのパッケージの1つとして、DALEXパッケージがある。DALEXパッケージについては下記の公式サイトおよび書籍を参考にした。

数理的な側面についてはここでは扱わないが、ここで紹介する機械学習を解釈する手法については下記の書籍がわかりやすい。

モデルの作成

DALEXパッケージを利用するためには、モデルが必要になるので、モデル作成を行っておく。動けば良いモデルであって、役に立つモデルではない点は注意。ここではタイタニックのデータを利用する。これは、多くの人によっては分析がなされているため、生存するために必要な特徴量が明らかになっているため、DALEXパッケージの関数の出力を理解しやすくするため。モデルの説明は下記の通り。

  • Survived: 0=死亡、1=生存
  • Pclass: 旅客クラス(1=1等, 2=2等, 3=3等)
  • Sex: 性別(male=男性, female=女性)
  • Age: 年齢
  • Sibsp: 同乗兄弟,配偶者数
  • Parch: 同乗親,子供数
  • Fare: 旅客運賃
  • Cabin: 客室番号
  • Embarked: 出港地(C=Cherbourg, Q=Queenstown, S=Southampton)

生存のために重要な特徴は下記の通り。

  • 性別(Sex): 男性よりも女性や子供の方が生存率が高い。つまり、男性は生存しにくい。
  • 年齢(Age): 年齢が若いほど生存率が高い。つまり、年齢が高いと生存しにくい。
  • 乗船クラス(Pclass): 1等に比べて2,3等の乗客の生存率が低い。つまり、低クラス(=2,3)だと生存しにくい。
  • 家族の有無(SibSp, Parch):家族がいる乗客の生存率が高い。つまり、独り身だとだと生存しにくい。

モデル作成を行なう。

library(tidymodels)
library(tidyverse)
library(DALEX)
library(DALEXtra)
library(lime)
library(localModel)
library(vip)
library(patchwork)

df <- read_csv("https://raw.githubusercontent.com/ogrisel/parallel_ml_tutorial/master/notebooks/titanic_train.csv") %>% 
  select(-Name, -PassengerId, -Ticket)

# rsample
set.seed(1989)
df_initial <- df %>% initial_split(prop = 0.8, strata = "Survived")
df_train <- df_initial %>% training()
df_test <- df_initial %>% testing()

set.seed(1989)
df_train_stratified_splits <- 
  vfold_cv(df_train, v = 5, strata = "Survived")

# recipes
recipe <- recipe(Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked, data = df_train) %>%
  step_impute_median(Age, Fare) %>% 
  step_impute_mode(Embarked) %>% 
  step_mutate_at(Survived, Pclass, Sex, Embarked, fn = factor) %>%
  step_mutate(Travelers = SibSp + Parch + 1) %>% 
  step_rm(SibSp, Parch) %>% 
  step_dummy(all_nominal_predictors(), one_hot = FALSE)
  # %>% step_normalize(all_numeric_predictors()) 

# parsnip
model <- rand_forest(mtry = tune(), trees = tune(), min_n = tune()) %>%
  set_engine("ranger", importance = "impurity") %>%
  set_mode("classification")

# workflows
workflow <- workflow() %>% 
  add_recipe(recipe) %>% 
  add_model(model)

# tune / dials
set.seed(1989)
hyper_parameter_grid <- workflow %>% 
  extract_parameter_set_dials() %>% 
  update(
    mtry = mtry(range = c(4, 8)),
    trees = trees(range = c(500, 3000)),
    min_n = min_n(range = c(50, 100)),
  ) %>% 
  grid_latin_hypercube(size = 3)

# tune / dials
workflow_tuned <- 
  workflow %>% 
  tune_grid(
    resamples = df_train_stratified_splits,
    grid = hyper_parameter_grid,
    metrics = metric_set(accuracy),
    control = control_resamples(
      extract = extract_model, 
      save_pred = TRUE         
    )
  )

# workflow_tuned %>% 
#   collect_metrics()
better_paramters <- workflow_tuned %>% 
  select_best(metric = "accuracy")

better_workflow <- workflow %>% 
  finalize_workflow(parameters = better_paramters)

set.seed(1989)
model_trained_better_workflow <- 
  better_workflow %>% 
  fit(df_train)

model_trained_better_workflow
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## • step_impute_median()
## • step_impute_mode()
## • step_mutate_at()
## • step_mutate()
## • step_rm()
## • step_dummy()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Ranger result
## 
## Call:
##  ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~4L,      x), num.trees = ~636L, min.node.size = min_rows(~64L, x),      importance = ~"impurity", num.threads = 1, verbose = FALSE,      seed = sample.int(10^5, 1), probability = TRUE) 
## 
## Type:                             Probability estimation 
## Number of trees:                  636 
## Sample size:                      712 
## Number of independent variables:  8 
## Mtry:                             4 
## Target node size:                 64 
## Variable importance mode:         impurity 
## Splitrule:                        gini 
## OOB prediction error (Brier s.):  0.1290546

あとで前処理済みデータが必要になるので、ここで作成しておく。

df_test_baked <- recipe %>%
  prep() %>%
  bake(df_test)

DALECXパッケージを使う準備として、Models’ explainersを作る必要がある。これはパッケージが異なるとモデルの内部構造やオブジェクトの中身が変わってしまうため、統一したインターフェースを提供するオブジェクトに変換する必要がある。これがexplainerの役割で、explain関数で作成できる。引数は下記の通り。

  • model: 変換するモデルを指定
  • data: モデルが適用されるデータフレーム。データは、目的変数は含まない。
  • y: dataで与えられたデータに対応する説明変数
  • predict_function: 予測スコアを返す関数。デフォルトはpredict関数が使用されるが、エラーを引き起こす可能性がある
  • verbose: 診断メッセージを表示するかどうか
  • precalculate: 予測値と残差の計算を説明文作成時に行うかどうか
  • model_info: モデルに関する情報を提供する名前付きリスト
  • type: モデルの種類に関する情報で、classificationまたはregressionを指定
  • label: モデルの一意な名前
explainer <-
  DALEX::explain(
    model = model_trained_better_workflow %>% extract_fit_parsnip(),
    data = df_test_baked %>% select(!Survived),
    y = as.integer(df_test_baked %>% pull(Survived)),
    label = "randomforest"
  )
## Preparation of a new explainer is initiated
##   -> model label       :  randomforest 
##   -> data              :  179  rows  8  cols 
##   -> data              :  tibble converted into a data.frame 
##   -> target variable   :  179  values 
##   -> predict function  :  yhat.model_fit  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package parsnip , ver. 1.0.3 , task classification (  default  ) 
##   -> predicted values  :  numerical, min =  0.02319849 , mean =  0.3965378 , max =  0.9908621  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  0.01477724 , mean =  0.988937 , max =  1.962641  
##   A new explainer has been created!
class(explainer)
## [1] "explainer"

explain関数を適用すると、explainerクラスのオブジェクトが作成される。このオブジェクトは以下を含むリスト。

  • model: 説明されるモデル
  • data: モデルが適用されたデータ
  • y: データに対応する従属変数の観測値
  • y_hat: データにモデルを適用して得られた予測値
  • residuals: yy_hat に基づいて計算された残差
  • predict_function: モデルの予測値を得るために使用される関数
  • residual_function: 残差を求めるために使用される関数
  • class: モデルのクラス
  • label: モデルのラベル。
  • model_info: モデルに関する情報を提供する名前付きリスト

explainerクラスのオブジェクトには、モデルの説明を作成するために必要なすべての要素が含まれる。

インスタンスレベル

インスタンスレベルでは、モデルが特定の1つの観測値に対してどのように予測をもたらすかを理解するのに役立つ。つまり、タイタニックデータであれば、乗客1人のレベルの話。例えば、説明変数の値を調整して、仮想的な人物データを作成すると、その仮想人物の生存予測はどうなるのか、どの変数が有効なのか、特定の変数を上下させるとどうなるのか、これらの説明を得たいときにインスタンスレベルでの手法は役に立つ。

Break-down Plots

Break-down Plotsは、単一のオブザベーション(タイタニックデータであればいち個人)に対するモデルの予測を理解しようとするとき、どの変数が予測結果に対して、どのように貢献しているのかを調べられる。Break-down Plotsは、加算的でわかりやすいという利点をもつ一方で、交互作用などを含む場合は誤解を生む可能性がある。ただ、交互作用を含めたモデルであれば簡単に拡張できる(predict_parts(type = "break_down_interactions")を指定する)。

説明のために仮想的な人物を作成する。survive_lowは生存率が低い人物で、survive_highは生存率が高い人物。

survive_low <- tibble(
  Age = 60,
  Fare = 15,
  Travelers = 1,
  Pclass_X2 = 0,
  Pclass_X3 = 1,
  Sex_male = 1,
  Embarked_Q = 0,
  Embarked_S = 1,
)

survive_high <- tibble(
  Age = 30,
  Fare = 30,
  Travelers = 2,
  Pclass_X2 = 0,
  Pclass_X3 = 0,
  Sex_male = 0,
  Embarked_Q = 0,
  Embarked_S = 0,
)

Break-down Plotsはpredict_parts関数で作成できる。survive_lowは生存率は0.127予測されるが、生存率を下げる要因として、男性であること(Sex_male=1)、等級が低いこと(Pclass_X3=1)、年齢が高いこと(Age=60)、これらによって生存率が下がっていることがわかる。

break_down_low <- predict_parts(explainer = explainer,
                                new_observation = survive_low,
                                # "break_down","shap","oscillations","break_down_interactions"を指定できる
                                type = "break_down")
plot(break_down_low) 

一方で、survive_lowは生存率は0.948と予測されるが、生存率をあげる要因として、女性であること(Sex_male=0)、等級が高いこと(Pclass_X3=0)、料金が高いこと(Fare=30)、これらによって生存率が上がっていることがわかる。

break_down_high <- predict_parts(explainer = explainer,
                                new_observation = survive_high,
                                # "break_down","shap","oscillations","break_down_interactions"を指定できる
                                type = "break_down")
plot(break_down_high) 

predict_parts関数は他にも下記の引数が取れる。

  • order: 説明変数の順序を指定
  • keep_distributions: 分布をバイオリンプロットとして表示
break_down_high2 <- predict_parts(explainer = explainer,
                                new_observation = survive_high,
                                # "break_down","shap","oscillations","break_down_interactions"を指定できる
                                type = "break_down",
                                order = c("Pclass_X2", "Embarked_Q", "Age", "Travelers",
                                          "Embarked_S", "Fare", "Pclass_X3", "Sex_male"),
                                keep_distributions = TRUE
                                )
plot(break_down_high2, plot_distributions = TRUE)

SHAP

Break-down Plotsの問題点として、同じデータ出会っても、変数の並べ方で見方が大きく変わってしまう点が挙げられる。下記の3枚のBreak-down Plotsは同じデータであるにもかかわらず、分解のされ方が変わっているため、この点には注意が必要である。

survive_middle <- tibble(
  Embarked_S = 1,
  Embarked_Q = 0,
  Sex_male = 0,
  Pclass_X3 = 1,
  Pclass_X2 = 0,
  Travelers = 1,
  Fare = 15,
  Age = 60,
)

break_down_middle1 <- predict_parts(explainer = explainer,
                                new_observation = survive_middle,
                                type = "break_down")

break_down_middle2 <- predict_parts(explainer = explainer,
                                new_observation = survive_middle,
                                type = "break_down",
                                order = c("Pclass_X2", "Embarked_Q", "Age", "Travelers",
                                          "Embarked_S", "Fare", "Pclass_X3", "Sex_male")
                                )

break_down_middle3 <- predict_parts(explainer = explainer,
                                new_observation = survive_middle,
                                type = "break_down",
                                order = c( "Travelers", "Pclass_X2", "Embarked_S", "Fare",
                                           "Pclass_X3", "Embarked_Q", "Sex_male", "Age")
                                )

plot(break_down_middle1) + plot(break_down_middle2) + plot(break_down_middle3)

この問題を解決する1つの手段として、変数の順序の影響を取り除くために、Shapley値を利用して平均値を計算する方法がある。SHapley Additive exPlanations(SHAP)は、協力ゲーム理論の分野で開発された「Shapley値」に基づいている。

赤と緑のバーでポジティブなのか、ネガティブなのかを表し、順序により大きさの違いを箱ひげ図で表現している。このプロットをみれば、生存の観点から重要な変数が女性(Sex_male=0)であること、等級が低いこと(Pclass_X3=1)がわかる。

shap_survive_high <- predict_parts(
  explainer = explainer,
  new_observation = survive_high,
  type = "shap"
)

shap_survive_middle <- predict_parts(
  explainer = explainer,
  new_observation = survive_middle,
  type = "shap"
)

shap_survive_low <- predict_parts(
  explainer = explainer,
  new_observation = survive_low,
  type = "shap"
)
plot(shap_survive_low) + plot(shap_survive_middle) + plot(shap_survive_high)

LIME

Break-down PlotsとSHAPは、計算量が大きくなりやすいため、説明変数が多くないモデルに適していると言われる。それを解決する方法の1つとして、LIMEがある。二値分類の場合であれば、LIMEは決定領域を単純な線形モデルを使用して、人口データから局所近似を作成することで、複雑なモデルの「ローカル説明者」として機能する。これが名前(LIME: Local Interpretable Model-agnostic Explanations)の由来である。

predict_surrogate関数で計算でき、n_featuresは K-LASSO 法で選択される説明変数の最大数を指定し、n_permutationsで局所モデル近似のためにサンプリング人工データの数を指定する。十分な分散が計算できないダミー変数などについては、ワーニングが表示される。

計算結果のfeatureカラムには、K-LASSO法で非ゼロの係数が与えられた説明変数が示される。feature_valueカラムは提供された観測値の値であり、feature_descカラムは、元の説明変数がどのように変換されたかを表す。連続変数は特定のしきい値で分類されている。feature_weightカラムは、K-LASSO 法で選択された変数の推定された係数を表し、model_interceptカラムは切片を表す。これらの値の近似によってモデルを解釈できるようにする。

model_type.dalex_explainer <- DALEXtra::model_type.dalex_explainer
predict_model.dalex_explainer <- DALEXtra::predict_model.dalex_explainer

lime_survie_high <- predict_surrogate(
  explainer = explainer,
  new_observation = survive_high,
  n_features = 10,
  n_permutations = 1000,
  type = "lime"
)

lime_survie_high
## # A tibble: 8 × 11
##   model_type case  model_r2 model_int…¹ model…² feature featu…³ featur…⁴ featu…⁵
##   <chr>      <chr>    <dbl>       <dbl>   <dbl> <chr>     <dbl>    <dbl> <chr>  
## 1 regression 1        0.238       0.760   0.870 Age          30  9.52e-3 28.0 <…
## 2 regression 1        0.238       0.760   0.870 Fare         30  1.09e-2 13.0 <…
## 3 regression 1        0.238       0.760   0.870 Travel…       2  1.02e-1 Travel…
## 4 regression 1        0.238       0.760   0.870 Pclass…       0 -3.34e-3 Pclass…
## 5 regression 1        0.238       0.760   0.870 Pclass…       0  3.24e-4 Pclass…
## 6 regression 1        0.238       0.760   0.870 Sex_ma…       0 -8.70e-3 Sex_ma…
## 7 regression 1        0.238       0.760   0.870 Embark…       0  6.36e-4 Embark…
## 8 regression 1        0.238       0.760   0.870 Embark…       0 -2.06e-3 Embark…
## # … with 2 more variables: data <list>, prediction <dbl>, and abbreviated
## #   variable names ¹​model_intercept, ²​model_prediction, ³​feature_value,
## #   ⁴​feature_weight, ⁵​feature_desc

棒グラフの長さは(絶対値)を示し、色は推定された係数の符号を示している。赤がネガティブで、青はポジティブ。

plot(lime_survie_high)

LIMEの計算結果をlocalModelパッケージと組み合わせることで、各変数と生存率の関係をより詳細に理解できる。例えば、Ageを使って可視化すると、Ageの値が40歳あたりを基準に、年齢が高くなると生存確率が低下していることがわかる

localmodel_survie_high <- predict_surrogate(
  explainer = explainer,
  new_observation = survive_high,
  seed = 1989,
  size = 1000,
  type = "localModel"
)

plot_interpretable_feature(localmodel_survie_high, "Age")

データセットレベル

データセットレベルでは、インスタンスレベルとは異なり、モデルの予測に各変数がどのように機能するかを理解することが目的である。例えば、モデルの変数の中で重要な変数を知りたい場合、変数の高低がモデルの予測に与える影響などを知ることができる。

Variable Importance Measures

変数重要度はモデルの予測において、重要な変数を知ることができ、モデルの予測に影響を与えない変数はモデルから除外する、重要な変数のさらなる探索のための順序付け、ドメイン知識に基づいたモデルの有効性評価などができる。

変数の重要度は、注目している説明変数がモデルから取り除かれた場合、モデルのパフォーマンスがどの程度変化するかを測定することで重要度を測定する。重要な変数であれば、モデルから取り除かれると予測性能が下がることが予想され、重要ではない変数はモデルから取り除かれても、予測には影響しない、ということ。

変数重要度だけであれば、簡単に可視化できる。

model_trained_better_workflow %>% 
  extract_fit_parsnip() %>% 
  vip(num_features = 10) + 
  theme_bw()

model_parts関数では、計算に利用する評価指標の違いから、見え方が異なっている。変数をシャッフルしてないときと、シャッフルしたときの比率(ratio)が一緒であれば1に近くなり、一緒でなければ1から遠くなる。

set.seed(1989)
vip <- model_parts(
  explainer = explainer,
  type = "ratio", # differnce, ratio
  n_sample = 1000, # サンプルする数
  B = 10 # シャッフル回数
)

plot(vip)

Partial-dependence Plot

Partial-dependence Plotは、モデルの予測値と説明変数の関係を表すもので、変数の大小が予測値にどのような影響を与えるのかを確認できる。

  • variables: 計算する説明変数を指定
  • N: ランダムサンプリングされる観測値の数
  • type: partial(default) , conditional, accumulatedから指定
  • variable_type: numericalは連続変数のみ、categoricalはカテゴリ変数のみ
pdp <- model_profile(
  explainer = explainer
  # variables = c("Age", "Fare")
  )

plot(pdp)

グループ化して、変数と予測の関係を見ることもできる。

pdp_g1 <- model_profile(explainer = explainer, 
                        variables = "Age", groups = "Sex_male")

pdp_g2 <- model_profile(explainer = explainer, 
                        variables = "Fare", groups = "Sex_male")

plot(pdp_g1) + plot(pdp_g2)

Ceteris-paribusプロファイル

Ceteris-paribusプロファイルはインスタンスレベルに書くほうが適切かもしれないが、データを絞らなければデータレベルでの振る舞いを確認できるとも考えられるので、ここでまとめておく。 おそらくIndividual Conditional Expectation(ICE)と呼ばれるものと同じ。

これは、変数の値が変化した場合にモデルの予測がどのように変化するかを示すもの。1つの線がインスタンス1つに対応する。点は実際のインスタンスの観測値を表す。グループ化PDPの結果からわかるように、どちらの変数でも上下で別れているが、これは性別の影響が関係していると思われる。

ice_age <- predict_profile(
  explainer = explainer,
  new_observation = df_test_baked,
  variables = "Age"
)

ice_fare <- predict_profile(
  explainer = explainer,
  new_observation = df_test_baked,
  variables = "Fare"
)

plot(ice_age, variables = "Age") + plot(ice_fare, variables = "Fare")