UPDATE: 2022-12-25 16:39:37
tidymodels
パッケージの使い方をいくつかのノートに分けてまとめている。tidymodels
パッケージは、統計モデルや機械学習モデルを構築するために必要なパッケージをコレクションしているパッケージで、非常に色んなパッケージがある。ここでは、今回はDALEX
パッケージについてまとめていく。モデルの数理的な側面や機械学習の用語などは、このノートでは扱わない。
下記の公式ドキュメントやtidymodels
パッケージに関する書籍を参考にしている。
説明可能性には、大域的(データレベル)な説明と局所的(インスタンスレベル)な説明が出てくるが、今回は双方を扱う。
DALEX
パッケージの目的DALEX
パッケージは、DrWhyという取り組みの中に存在する説明可能なAI(XAI)用のツールのコレクションの1つパッケージ。この取り組みでは、機械学習のモデルが下記の観点から検討されることを望んでいる。
モデルの中で実際に何が起こっているかを理解することや、モデルが出力する予測値がどのように生成されたのかを理解することをこの取組では目指しており、そのパッケージの1つとして、DALEX
パッケージがある。DALEX
パッケージについては下記の公式サイトおよび書籍を参考にした。
数理的な側面についてはここでは扱わないが、ここで紹介する機械学習を解釈する手法については下記の書籍がわかりやすい。
DALEX
パッケージを利用するためには、モデルが必要になるので、モデル作成を行っておく。動けば良いモデルであって、役に立つモデルではない点は注意。ここではタイタニックのデータを利用する。これは、多くの人によっては分析がなされているため、生存するために必要な特徴量が明らかになっているため、DALEX
パッケージの関数の出力を理解しやすくするため。モデルの説明は下記の通り。
Survived
: 0=死亡、1=生存Pclass
: 旅客クラス(1=1等, 2=2等, 3=3等)Sex
: 性別(male=男性, female=女性)Age
: 年齢Sibsp
: 同乗兄弟,配偶者数Parch
: 同乗親,子供数Fare
: 旅客運賃Cabin
: 客室番号Embarked
: 出港地(C=Cherbourg, Q=Queenstown,
S=Southampton)生存のために重要な特徴は下記の通り。
Sex
):
男性よりも女性や子供の方が生存率が高い。つまり、男性は生存しにくい。Age
):
年齢が若いほど生存率が高い。つまり、年齢が高いと生存しにくい。Pclass
):
1等に比べて2,3等の乗客の生存率が低い。つまり、低クラス(=2,3)だと生存しにくい。SibSp
,
Parch
):家族がいる乗客の生存率が高い。つまり、独り身だとだと生存しにくい。モデル作成を行なう。
library(tidymodels)
library(tidyverse)
library(DALEX)
library(DALEXtra)
library(lime)
library(localModel)
library(vip)
library(patchwork)
<- read_csv("https://raw.githubusercontent.com/ogrisel/parallel_ml_tutorial/master/notebooks/titanic_train.csv") %>%
df select(-Name, -PassengerId, -Ticket)
# rsample
set.seed(1989)
<- df %>% initial_split(prop = 0.8, strata = "Survived")
df_initial <- df_initial %>% training()
df_train <- df_initial %>% testing()
df_test
set.seed(1989)
<-
df_train_stratified_splits vfold_cv(df_train, v = 5, strata = "Survived")
# recipes
<- recipe(Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked, data = df_train) %>%
recipe step_impute_median(Age, Fare) %>%
step_impute_mode(Embarked) %>%
step_mutate_at(Survived, Pclass, Sex, Embarked, fn = factor) %>%
step_mutate(Travelers = SibSp + Parch + 1) %>%
step_rm(SibSp, Parch) %>%
step_dummy(all_nominal_predictors(), one_hot = FALSE)
# %>% step_normalize(all_numeric_predictors())
# parsnip
<- rand_forest(mtry = tune(), trees = tune(), min_n = tune()) %>%
model set_engine("ranger", importance = "impurity") %>%
set_mode("classification")
# workflows
<- workflow() %>%
workflow add_recipe(recipe) %>%
add_model(model)
# tune / dials
set.seed(1989)
<- workflow %>%
hyper_parameter_grid extract_parameter_set_dials() %>%
update(
mtry = mtry(range = c(4, 8)),
trees = trees(range = c(500, 3000)),
min_n = min_n(range = c(50, 100)),
%>%
) grid_latin_hypercube(size = 3)
# tune / dials
<-
workflow_tuned %>%
workflow tune_grid(
resamples = df_train_stratified_splits,
grid = hyper_parameter_grid,
metrics = metric_set(accuracy),
control = control_resamples(
extract = extract_model,
save_pred = TRUE
)
)
# workflow_tuned %>%
# collect_metrics()
<- workflow_tuned %>%
better_paramters select_best(metric = "accuracy")
<- workflow %>%
better_workflow finalize_workflow(parameters = better_paramters)
set.seed(1989)
<-
model_trained_better_workflow %>%
better_workflow fit(df_train)
model_trained_better_workflow
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
##
## • step_impute_median()
## • step_impute_mode()
## • step_mutate_at()
## • step_mutate()
## • step_rm()
## • step_dummy()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~4L, x), num.trees = ~636L, min.node.size = min_rows(~64L, x), importance = ~"impurity", num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1), probability = TRUE)
##
## Type: Probability estimation
## Number of trees: 636
## Sample size: 712
## Number of independent variables: 8
## Mtry: 4
## Target node size: 64
## Variable importance mode: impurity
## Splitrule: gini
## OOB prediction error (Brier s.): 0.1290546
あとで前処理済みデータが必要になるので、ここで作成しておく。
<- recipe %>%
df_test_baked prep() %>%
bake(df_test)
DALECX
パッケージを使う準備として、Models’ explainers
を作る必要がある。これはパッケージが異なるとモデルの内部構造やオブジェクトの中身が変わってしまうため、統一したインターフェースを提供するオブジェクトに変換する必要がある。これがexplainer
の役割で、explain
関数で作成できる。引数は下記の通り。
model
: 変換するモデルを指定data
:
モデルが適用されるデータフレーム。データは、目的変数は含まない。y
:
data
で与えられたデータに対応する説明変数predict_function
:
予測スコアを返す関数。デフォルトはpredict
関数が使用されるが、エラーを引き起こす可能性があるverbose
: 診断メッセージを表示するかどうかprecalculate
:
予測値と残差の計算を説明文作成時に行うかどうかmodel_info
:
モデルに関する情報を提供する名前付きリストtype
:
モデルの種類に関する情報で、classification
またはregression
を指定label
: モデルの一意な名前<-
explainer ::explain(
DALEXmodel = model_trained_better_workflow %>% extract_fit_parsnip(),
data = df_test_baked %>% select(!Survived),
y = as.integer(df_test_baked %>% pull(Survived)),
label = "randomforest"
)
## Preparation of a new explainer is initiated
## -> model label : randomforest
## -> data : 179 rows 8 cols
## -> data : tibble converted into a data.frame
## -> target variable : 179 values
## -> predict function : yhat.model_fit will be used ( default )
## -> predicted values : No value for predict function target column. ( default )
## -> model_info : package parsnip , ver. 1.0.3 , task classification ( default )
## -> predicted values : numerical, min = 0.02319849 , mean = 0.3965378 , max = 0.9908621
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = 0.01477724 , mean = 0.988937 , max = 1.962641
## A new explainer has been created!
class(explainer)
## [1] "explainer"
explain
関数を適用すると、explainer
クラスのオブジェクトが作成される。このオブジェクトは以下を含むリスト。
model
: 説明されるモデルdata
: モデルが適用されたデータy
: データに対応する従属変数の観測値y_hat
: データにモデルを適用して得られた予測値residuals
: y
と y_hat
に基づいて計算された残差predict_function
:
モデルの予測値を得るために使用される関数residual_function
:
残差を求めるために使用される関数class
: モデルのクラスlabel
: モデルのラベル。model_info
:
モデルに関する情報を提供する名前付きリストexplainer
クラスのオブジェクトには、モデルの説明を作成するために必要なすべての要素が含まれる。
インスタンスレベルでは、モデルが特定の1つの観測値に対してどのように予測をもたらすかを理解するのに役立つ。つまり、タイタニックデータであれば、乗客1人のレベルの話。例えば、説明変数の値を調整して、仮想的な人物データを作成すると、その仮想人物の生存予測はどうなるのか、どの変数が有効なのか、特定の変数を上下させるとどうなるのか、これらの説明を得たいときにインスタンスレベルでの手法は役に立つ。
Break-down
Plotsは、単一のオブザベーション(タイタニックデータであればいち個人)に対するモデルの予測を理解しようとするとき、どの変数が予測結果に対して、どのように貢献しているのかを調べられる。Break-down
Plotsは、加算的でわかりやすいという利点をもつ一方で、交互作用などを含む場合は誤解を生む可能性がある。ただ、交互作用を含めたモデルであれば簡単に拡張できる(predict_parts(type = "break_down_interactions")
を指定する)。
説明のために仮想的な人物を作成する。survive_low
は生存率が低い人物で、survive_high
は生存率が高い人物。
<- tibble(
survive_low Age = 60,
Fare = 15,
Travelers = 1,
Pclass_X2 = 0,
Pclass_X3 = 1,
Sex_male = 1,
Embarked_Q = 0,
Embarked_S = 1,
)
<- tibble(
survive_high Age = 30,
Fare = 30,
Travelers = 2,
Pclass_X2 = 0,
Pclass_X3 = 0,
Sex_male = 0,
Embarked_Q = 0,
Embarked_S = 0,
)
Break-down
Plotsはpredict_parts
関数で作成できる。survive_low
は生存率は0.127予測されるが、生存率を下げる要因として、男性であること(Sex_male=1
)、等級が低いこと(Pclass_X3=1
)、年齢が高いこと(Age=60
)、これらによって生存率が下がっていることがわかる。
<- predict_parts(explainer = explainer,
break_down_low new_observation = survive_low,
# "break_down","shap","oscillations","break_down_interactions"を指定できる
type = "break_down")
plot(break_down_low)
一方で、
survive_low
は生存率は0.948と予測されるが、生存率をあげる要因として、女性であること(Sex_male=0
)、等級が高いこと(Pclass_X3=0
)、料金が高いこと(Fare=30
)、これらによって生存率が上がっていることがわかる。
<- predict_parts(explainer = explainer,
break_down_high new_observation = survive_high,
# "break_down","shap","oscillations","break_down_interactions"を指定できる
type = "break_down")
plot(break_down_high)
predict_parts
関数は他にも下記の引数が取れる。
order
: 説明変数の順序を指定keep_distributions
:
分布をバイオリンプロットとして表示<- predict_parts(explainer = explainer,
break_down_high2 new_observation = survive_high,
# "break_down","shap","oscillations","break_down_interactions"を指定できる
type = "break_down",
order = c("Pclass_X2", "Embarked_Q", "Age", "Travelers",
"Embarked_S", "Fare", "Pclass_X3", "Sex_male"),
keep_distributions = TRUE
)plot(break_down_high2, plot_distributions = TRUE)
Break-down Plotsの問題点として、同じデータ出会っても、変数の並べ方で見方が大きく変わってしまう点が挙げられる。下記の3枚のBreak-down Plotsは同じデータであるにもかかわらず、分解のされ方が変わっているため、この点には注意が必要である。
<- tibble(
survive_middle Embarked_S = 1,
Embarked_Q = 0,
Sex_male = 0,
Pclass_X3 = 1,
Pclass_X2 = 0,
Travelers = 1,
Fare = 15,
Age = 60,
)
<- predict_parts(explainer = explainer,
break_down_middle1 new_observation = survive_middle,
type = "break_down")
<- predict_parts(explainer = explainer,
break_down_middle2 new_observation = survive_middle,
type = "break_down",
order = c("Pclass_X2", "Embarked_Q", "Age", "Travelers",
"Embarked_S", "Fare", "Pclass_X3", "Sex_male")
)
<- predict_parts(explainer = explainer,
break_down_middle3 new_observation = survive_middle,
type = "break_down",
order = c( "Travelers", "Pclass_X2", "Embarked_S", "Fare",
"Pclass_X3", "Embarked_Q", "Sex_male", "Age")
)
plot(break_down_middle1) + plot(break_down_middle2) + plot(break_down_middle3)
この問題を解決する1つの手段として、変数の順序の影響を取り除くために、Shapley値を利用して平均値を計算する方法がある。SHapley Additive exPlanations(SHAP)は、協力ゲーム理論の分野で開発された「Shapley値」に基づいている。
赤と緑のバーでポジティブなのか、ネガティブなのかを表し、順序により大きさの違いを箱ひげ図で表現している。このプロットをみれば、生存の観点から重要な変数が女性(Sex_male=0
)であること、等級が低いこと(Pclass_X3=1
)がわかる。
<- predict_parts(
shap_survive_high explainer = explainer,
new_observation = survive_high,
type = "shap"
)
<- predict_parts(
shap_survive_middle explainer = explainer,
new_observation = survive_middle,
type = "shap"
)
<- predict_parts(
shap_survive_low explainer = explainer,
new_observation = survive_low,
type = "shap"
)plot(shap_survive_low) + plot(shap_survive_middle) + plot(shap_survive_high)
Break-down PlotsとSHAPは、計算量が大きくなりやすいため、説明変数が多くないモデルに適していると言われる。それを解決する方法の1つとして、LIMEがある。二値分類の場合であれば、LIMEは決定領域を単純な線形モデルを使用して、人口データから局所近似を作成することで、複雑なモデルの「ローカル説明者」として機能する。これが名前(LIME: Local Interpretable Model-agnostic Explanations)の由来である。
predict_surrogate
関数で計算でき、n_features
は
K-LASSO
法で選択される説明変数の最大数を指定し、n_permutations
で局所モデル近似のためにサンプリング人工データの数を指定する。十分な分散が計算できないダミー変数などについては、ワーニングが表示される。
計算結果のfeature
カラムには、K-LASSO法で非ゼロの係数が与えられた説明変数が示される。feature_value
カラムは提供された観測値の値であり、feature_desc
カラムは、元の説明変数がどのように変換されたかを表す。連続変数は特定のしきい値で分類されている。feature_weight
カラムは、K-LASSO
法で選択された変数の推定された係数を表し、model_intercept
カラムは切片を表す。これらの値の近似によってモデルを解釈できるようにする。
<- DALEXtra::model_type.dalex_explainer
model_type.dalex_explainer <- DALEXtra::predict_model.dalex_explainer
predict_model.dalex_explainer
<- predict_surrogate(
lime_survie_high explainer = explainer,
new_observation = survive_high,
n_features = 10,
n_permutations = 1000,
type = "lime"
)
lime_survie_high
## # A tibble: 8 × 11
## model_type case model_r2 model_int…¹ model…² feature featu…³ featur…⁴ featu…⁵
## <chr> <chr> <dbl> <dbl> <dbl> <chr> <dbl> <dbl> <chr>
## 1 regression 1 0.238 0.760 0.870 Age 30 9.52e-3 28.0 <…
## 2 regression 1 0.238 0.760 0.870 Fare 30 1.09e-2 13.0 <…
## 3 regression 1 0.238 0.760 0.870 Travel… 2 1.02e-1 Travel…
## 4 regression 1 0.238 0.760 0.870 Pclass… 0 -3.34e-3 Pclass…
## 5 regression 1 0.238 0.760 0.870 Pclass… 0 3.24e-4 Pclass…
## 6 regression 1 0.238 0.760 0.870 Sex_ma… 0 -8.70e-3 Sex_ma…
## 7 regression 1 0.238 0.760 0.870 Embark… 0 6.36e-4 Embark…
## 8 regression 1 0.238 0.760 0.870 Embark… 0 -2.06e-3 Embark…
## # … with 2 more variables: data <list>, prediction <dbl>, and abbreviated
## # variable names ¹model_intercept, ²model_prediction, ³feature_value,
## # ⁴feature_weight, ⁵feature_desc
棒グラフの長さは(絶対値)を示し、色は推定された係数の符号を示している。赤がネガティブで、青はポジティブ。
plot(lime_survie_high)
LIMEの計算結果をlocalModel
パッケージと組み合わせることで、各変数と生存率の関係をより詳細に理解できる。例えば、Age
を使って可視化すると、Age
の値が40歳あたりを基準に、年齢が高くなると生存確率が低下していることがわかる
<- predict_surrogate(
localmodel_survie_high explainer = explainer,
new_observation = survive_high,
seed = 1989,
size = 1000,
type = "localModel"
)
plot_interpretable_feature(localmodel_survie_high, "Age")
データセットレベルでは、インスタンスレベルとは異なり、モデルの予測に各変数がどのように機能するかを理解することが目的である。例えば、モデルの変数の中で重要な変数を知りたい場合、変数の高低がモデルの予測に与える影響などを知ることができる。
変数重要度はモデルの予測において、重要な変数を知ることができ、モデルの予測に影響を与えない変数はモデルから除外する、重要な変数のさらなる探索のための順序付け、ドメイン知識に基づいたモデルの有効性評価などができる。
変数の重要度は、注目している説明変数がモデルから取り除かれた場合、モデルのパフォーマンスがどの程度変化するかを測定することで重要度を測定する。重要な変数であれば、モデルから取り除かれると予測性能が下がることが予想され、重要ではない変数はモデルから取り除かれても、予測には影響しない、ということ。
変数重要度だけであれば、簡単に可視化できる。
%>%
model_trained_better_workflow extract_fit_parsnip() %>%
vip(num_features = 10) +
theme_bw()
model_parts
関数では、計算に利用する評価指標の違いから、見え方が異なっている。変数をシャッフルしてないときと、シャッフルしたときの比率(ratio
)が一緒であれば1に近くなり、一緒でなければ1から遠くなる。
set.seed(1989)
<- model_parts(
vip explainer = explainer,
type = "ratio", # differnce, ratio
n_sample = 1000, # サンプルする数
B = 10 # シャッフル回数
)
plot(vip)
Partial-dependence Plotは、モデルの予測値と説明変数の関係を表すもので、変数の大小が予測値にどのような影響を与えるのかを確認できる。
variables
: 計算する説明変数を指定N
: ランダムサンプリングされる観測値の数type
: partial(default)
,
conditional
, accumulated
から指定variable_type
:
numerical
は連続変数のみ、categorical
はカテゴリ変数のみ<- model_profile(
pdp explainer = explainer
# variables = c("Age", "Fare")
)
plot(pdp)
グループ化して、変数と予測の関係を見ることもできる。
<- model_profile(explainer = explainer,
pdp_g1 variables = "Age", groups = "Sex_male")
<- model_profile(explainer = explainer,
pdp_g2 variables = "Fare", groups = "Sex_male")
plot(pdp_g1) + plot(pdp_g2)
Ceteris-paribusプロファイルはインスタンスレベルに書くほうが適切かもしれないが、データを絞らなければデータレベルでの振る舞いを確認できるとも考えられるので、ここでまとめておく。 おそらくIndividual Conditional Expectation(ICE)と呼ばれるものと同じ。
これは、変数の値が変化した場合にモデルの予測がどのように変化するかを示すもの。1つの線がインスタンス1つに対応する。点は実際のインスタンスの観測値を表す。グループ化PDPの結果からわかるように、どちらの変数でも上下で別れているが、これは性別の影響が関係していると思われる。
<- predict_profile(
ice_age explainer = explainer,
new_observation = df_test_baked,
variables = "Age"
)
<- predict_profile(
ice_fare explainer = explainer,
new_observation = df_test_baked,
variables = "Fare"
)
plot(ice_age, variables = "Age") + plot(ice_fare, variables = "Fare")