Please keep in mind that DALEXtra now supports keras models imported from Python through the dedicated explain_keras() function, and using it is recommended.

1 Introduction

DALEX is designed to work with various black-box models such as tree ensembles, linear models and neural networks. Unfortunately, R packages that create such models are very inconsistent: different tools use different interfaces to train, validate and use models. Fortunately, DALEX can handle all of them through one unified interface.

In this vignette we show explanations for a multilayer perceptron (MLP) trained with the h2o and keras packages.

2 Classification use case - Titanic data

library(dplyr)
library(DALEX)
library(DALEXtra)
library(keras)
library(titanic)
library(fastDummies)
library(h2o)
set.seed(123)

To illustrate applications of DALEX to classification problems, we use the titanic_train dataset from the titanic package. Our goal is to predict the probability that a passenger survived the catastrophe, based on selected features: passenger class, sex, age, number of family members on board, fare and port of embarkation. In both packages we build a model with the same architecture and the same inputs.

First, we prepare and clean the data. The first difference between keras and h2o is that in keras the predictors and the explained variable have to be passed as separate numeric tensors. This means that a factor has to be encoded before it can be fed into a keras model. We use one-hot encoding via the dummy_cols() function from the fastDummies package for both h2o and keras, so that the number and type of inputs are identical.

# Data preparation and cleaning
titanic_small <- titanic_train %>%
                 select(Survived, Pclass, Sex, Age, SibSp, Parch, Fare, Embarked) %>%
                 mutate_at(c("Survived", "Sex", "Embarked"), as.factor) %>%
                 mutate(Family_members = SibSp + Parch) %>% # Calculate family members
                 na.omit() %>%
                 dummy_cols() %>%
                 select(-Sex, -Embarked, -Survived_0, -Survived_1, -Parch, -SibSp)

print(head(titanic_small))
##   Survived Pclass Age    Fare Family_members Sex_female Sex_male Embarked_
## 1        0      3  22  7.2500              1          0        1         0
## 2        1      1  38 71.2833              1          1        0         0
## 3        1      3  26  7.9250              0          1        0         0
## 4        1      1  35 53.1000              1          1        0         0
## 5        0      3  35  8.0500              0          0        1         0
## 6        0      1  54 51.8625              0          0        1         0
##   Embarked_C Embarked_Q Embarked_S
## 1          0          0          1
## 2          1          0          0
## 3          0          0          1
## 4          0          0          1
## 5          0          0          1
## 6          0          0          1
# Data preprocessing for Keras
titanic_small_y <- titanic_small %>% 
                   select(Survived) %>%
                   mutate(Survived = as.numeric(as.character(Survived))) %>%
                   as.matrix()

titanic_small_x <- titanic_small %>%
                   select(-Survived) %>%
                   as.matrix()
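For readers new to one-hot encoding, here is a minimal base-R sketch of what dummy_cols() does for a single factor. The toy data frame is hypothetical. The sketch uses model.matrix() and a hand-written loop instead of fastDummies, purely for illustration.

```r
# Hypothetical toy data with a factor similar to Embarked
toy <- data.frame(Embarked = factor(c("S", "C", "S", "Q")))

# model.matrix() with "- 1" drops the intercept, so every level gets its own 0/1 column
one_hot <- model.matrix(~ Embarked - 1, data = toy)

# The same encoding written out by hand: one indicator column per factor level
manual <- sapply(levels(toy$Embarked), function(lvl) as.integer(toy$Embarked == lvl))

print(one_hot)
print(all(manual == one_hot))
```

Unlike model.matrix(), dummy_cols() separates the variable name from the level with an underscore (Embarked_C). It also creates a column for the empty-string level, which is why Embarked_ appears in the printed data above.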

3 Models

We build the MLP in h2o with the h2o.deeplearning() function. To do this, we first need to initialize h2o and convert titanic_small to an H2OFrame.

h2o.init()
h2o.no_progress()

titanic_h2o <- as.h2o(titanic_small, destination_frame = "titanic_small")

model_h2o <- h2o.deeplearning(x = 2:11, # Columns 2-11 of titanic_small are the predictors
                              y = 1, # Column 1 (Survived) is the target
                              training_frame = titanic_h2o,
                              activation = "Rectifier", # ReLU as activation functions
                              hidden = c(16, 8), # Two hidden layers with 16 and 8 neurons
                              epochs = 100,
                              rate = 0.01, # Learning rate
                              adaptive_rate = FALSE, # Simple SGD
                              loss = "CrossEntropy")

To build a neural network in keras, we stack layers. For an MLP we use layer_dense().

model_keras <- keras_model_sequential() %>% # Initialization
  layer_dense(units = 16, # 16 neurons in first hidden layer
              activation = "relu", # ReLU as activation function
              input_shape = c(10)) %>% # Ten inputs
  layer_dense(units = 8, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")

model_keras %>% compile(
  optimizer = optimizer_sgd(lr = 0.01), # Simple SGD with learning rate 0.01
  loss = "binary_crossentropy",
  metrics = c("accuracy")
)

history <- model_keras %>% fit(
  titanic_small_x,
  titanic_small_y,
  epochs = 40,
  validation_split = 0.2
)
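To make the architecture concrete: each layer_dense() computes an affine map of its inputs followed by its activation function. The base-R sketch below runs a forward pass through a 10-16-8-1 network with ReLU hidden layers and a sigmoid output. It uses random, untrained weights and hypothetical inputs, so it only shows the shape of the computation, not the trained model.

```r
set.seed(1)
relu    <- function(z) { z[z < 0] <- 0; z }   # element-wise max(z, 0), keeps matrix dims
sigmoid <- function(z) 1 / (1 + exp(-z))

# Random weights for a 10 -> 16 -> 8 -> 1 network (illustrative, untrained)
W1 <- matrix(rnorm(10 * 16, sd = 0.1), 10, 16); b1 <- rep(0, 16)
W2 <- matrix(rnorm(16 * 8,  sd = 0.1), 16, 8);  b2 <- rep(0, 8)
W3 <- matrix(rnorm(8 * 1,   sd = 0.1), 8, 1);   b3 <- 0

forward <- function(X) {
  h1 <- relu(sweep(X %*% W1, 2, b1, "+"))    # first hidden layer, 16 units
  h2 <- relu(sweep(h1 %*% W2, 2, b2, "+"))   # second hidden layer, 8 units
  sigmoid(h2 %*% W3 + b3)                    # single output unit: a probability
}

X_toy <- matrix(runif(5 * 10), nrow = 5)      # 5 hypothetical rows with 10 inputs
p <- forward(X_toy)
print(p)
```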

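We can now compare predictions from both models. Keep in mind that h2o and keras use different implementations of the same algorithm. Many other settings, often randomized (such as initial weight values), also affect the result, so to get identical results you would have to align all of them. Note also that the h2o model is trained for 100 epochs, while the keras model is trained for 40 epochs on 80% of the data (validation_split = 0.2).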

henry <- data.frame(
  Pclass = 1,
  Age = 8,
  Fare = 72,
  Family_members = 0,
  Sex_female = 0,
  Sex_male = 1,
  Embarked_ = 0,
  Embarked_C = 1,
  Embarked_Q = 0,
  Embarked_S = 0
)

henry_h2o <- as.h2o(henry, destination_frame = "henry")
# keras reads matrix columns by position, so use the column order of the training matrix
henry_keras <- as.matrix(henry[, colnames(titanic_small_x)])
predict(model_h2o, henry_h2o) %>% print()
predict(model_keras, henry_keras) %>% print()
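h2o matches the columns of new data by name, whereas a plain matrix passed to keras is read by position. A small base-R check like the one below can confirm that a new observation has the same column order as the training matrix before conversion. The objects here are hypothetical.

```r
# Hypothetical training matrix and a new observation whose columns are in a different order
train_x <- matrix(0, nrow = 2, ncol = 3,
                  dimnames = list(NULL, c("Age", "Sex_female", "Sex_male")))
new_obs <- data.frame(Age = 8, Sex_male = 1, Sex_female = 0)

# Reorder by name so that position i means the same feature in both objects
new_obs_x <- as.matrix(new_obs[, colnames(train_x), drop = FALSE])
stopifnot(identical(colnames(new_obs_x), colnames(train_x)))
print(new_obs_x)
```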

4 The explain() function

The first step of using the DALEX package is to wrap the black-box model with metadata that unifies the model interface.

To create an explainer we use the explain() function. For models created with the h2o package we use the explain_h2o() function from the DALEXtra package.

explainer_titanic_h2o   <- DALEXtra::explain_h2o(model = model_h2o,
                                                 data = titanic_small[ , -1],
                                                 y = as.numeric(as.character(titanic_small$Survived)),
                                                 label = "MLP_h2o",
                                                 colorize = FALSE)

explainer_titanic_keras <- DALEX::explain(model = model_keras,
                                          data = titanic_small_x,
                                          y = as.numeric(titanic_small_y),
                                          type = "classification",
                                          label = "MLP_keras",
                                          colorize = FALSE)
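The core idea of an explainer is a uniform prediction interface. As a rough base-R sketch (not DALEX's actual implementation), each model is bundled with a predict function of the form function(model, newdata) that returns a plain numeric vector of probabilities, whatever the model's native output looks like. The helpers make_explainer() and predict_explainer() are hypothetical. model_fun loosely imitates a model whose predictions come back as a table with p0 and p1 columns.

```r
set.seed(42)
sim <- data.frame(x = rnorm(100))
sim$y <- rbinom(100, 1, plogis(2 * sim$x))

# Two "models" with different native prediction interfaces
model_glm <- glm(y ~ x, data = sim, family = binomial)
model_fun <- function(newdata) cbind(p0 = 1 - plogis(2 * newdata$x), p1 = plogis(2 * newdata$x))

# Hypothetical helper: bundle model, data, target and a predict function
make_explainer <- function(model, data, y, predict_function, label) {
  list(model = model, data = data, y = y, predict_function = predict_function, label = label)
}

exp_glm <- make_explainer(model_glm, sim["x"], sim$y,
                          function(m, newdata) as.numeric(predict(m, newdata, type = "response")),
                          "glm")
exp_fun <- make_explainer(model_fun, sim["x"], sim$y,
                          function(m, newdata) m(newdata)[, "p1"],   # keep P(y = 1) only
                          "function")

# Hypothetical helper: both explainers are now queried the same way
predict_explainer <- function(explainer, newdata) explainer$predict_function(explainer$model, newdata)
head(predict_explainer(exp_glm, sim["x"]))
head(predict_explainer(exp_fun, sim["x"]))
```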

5 Model performance

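The model_performance() function calculates predictions and residuals for the data stored in the explainer. Here this is the full titanic_small data, which largely overlaps with the training data, so these are not out-of-sample measures.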

mp_titinic_h2o   <- model_performance(explainer_titanic_h2o)
mp_titanic_keras <- model_performance(explainer_titanic_keras)

The generic print() function shows the performance measures and the quantiles of the residuals.

mp_titinic_h2o
## Measures for:  classification
## recall     : 0.7068966 
## precision  : 0.8874459 
## f1         : 0.7869482 
## accuracy   : 0.8445378 
## auc        : 0.8983897
## 
## Residuals:
##            0%           10%           20%           30%           40% 
## -9.954327e-01 -3.220674e-01 -1.592363e-01 -1.109982e-01 -7.308569e-02 
##           50%           60%           70%           80%           90% 
## -2.791604e-02  2.880694e-06  2.809896e-03  3.623441e-02  5.488800e-01 
##          100% 
##  9.904999e-01
mp_titanic_keras
## Measures for:  classification
## recall     : 0.4689655 
## precision  : 0.7272727 
## f1         : 0.5702306 
## accuracy   : 0.7128852 
## auc        : 0.7532775
## 
## Residuals:
##         0%        10%        20%        30%        40%        50%        60% 
## -0.8265051 -0.4426317 -0.3477609 -0.3010883 -0.2615811 -0.2057302  0.1789826 
##        70%        80%        90%       100% 
##  0.4714622  0.5191331  0.6694979  0.9073019
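The measures printed above can be reproduced from labels and predicted probabilities. The sketch below is a minimal base-R version with hypothetical toy vectors, assuming a 0.5 cutoff for turning probabilities into classes. Residuals are taken as y - yhat, and AUC is computed with the rank (Mann-Whitney) formula.

```r
y_toy    <- c(1, 0, 1, 1, 0, 0, 1, 0)
yhat_toy <- c(0.9, 0.2, 0.4, 0.8, 0.6, 0.1, 0.7, 0.3)

pred <- as.numeric(yhat_toy > 0.5)                  # assumed cutoff of 0.5
tp <- sum(pred == 1 & y_toy == 1); fp <- sum(pred == 1 & y_toy == 0)
fn <- sum(pred == 0 & y_toy == 1); tn <- sum(pred == 0 & y_toy == 0)

recall    <- tp / (tp + fn)
precision <- tp / (tp + fp)
f1        <- 2 * precision * recall / (precision + recall)
accuracy  <- (tp + tn) / length(y_toy)

# AUC via ranks: probability that a random positive is scored above a random negative
r   <- rank(yhat_toy)
n1  <- sum(y_toy == 1); n0 <- sum(y_toy == 0)
auc <- (sum(r[y_toy == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)

res <- y_toy - yhat_toy                             # residuals
print(c(recall = recall, precision = precision, f1 = f1, accuracy = accuracy, auc = auc))
print(quantile(res, probs = seq(0, 1, 0.1)))
```

For these toy vectors, recall, precision, F1 and accuracy all equal 0.75, and AUC is 15/16 = 0.9375.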

The generic plot() function shows the reversed empirical cumulative distribution function of the absolute residuals. Plots can be generated for one or more models.

plot(mp_titinic_h2o, mp_titanic_keras)
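The curve can be computed directly. For a threshold t, it gives the fraction of observations whose absolute residual exceeds t. Below is a minimal base-R sketch with hypothetical toy residuals.

```r
res_toy <- c(-0.8, -0.3, 0.1, 0.05, 0.6, -0.2, 0.9, -0.05)
abs_res <- abs(res_toy)

# ecdf() returns F(t) = share of |residual| <= t; the reversed curve is 1 - F(t)
F_abs      <- ecdf(abs_res)
thresholds <- c(0, 0.1, 0.5, 1)
reversed   <- 1 - F_abs(thresholds)
print(data.frame(threshold = thresholds, share_above = reversed))
```

A curve that falls towards zero faster means that the model's absolute residuals are smaller overall.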

The plot() function also offers an alternative comparison of residuals: with geom = "boxplot" we can compare the distributions of residuals for the selected models.

plot(mp_titinic_h2o, mp_titanic_keras, geom = "boxplot")

6 Variable importance

Using the DALEX package, we can better understand which variables are important.

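Model-agnostic variable importance is calculated by means of permutations. For a single variable, we permute its values in the dataset and compute the loss function on the permuted data. Comparing this loss with the loss on the original data shows how much the model relies on that variable: the larger the increase in loss after permutation, the more important the variable.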
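A minimal base-R sketch of permutation importance follows. It uses a logistic regression on simulated data and 1 - AUC as the loss; all names and data are hypothetical. It performs a single permutation per variable, whereas model_parts() repeats the permutations and works through the explainer's predict function. Both the raw loss after permutation and its difference from the full-model loss are shown.

```r
set.seed(123)
n   <- 500
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n), noise = rnorm(n))
sim$y <- rbinom(n, 1, plogis(2 * sim$x1 - sim$x2))
model <- glm(y ~ x1 + x2 + noise, data = sim, family = binomial)

# 1 - AUC, with AUC computed from ranks
one_minus_auc <- function(y, p) {
  r <- rank(p); n1 <- sum(y == 1); n0 <- sum(y == 0)
  1 - (sum(r[y == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

predict_p <- function(data) predict(model, newdata = data, type = "response")
loss_full <- one_minus_auc(sim$y, predict_p(sim))   # loss on the original data

importance <- sapply(c("x1", "x2", "noise"), function(v) {
  permuted <- sim
  permuted[[v]] <- sample(permuted[[v]])            # break the link between v and y
  one_minus_auc(sim$y, predict_p(permuted))         # loss after permutation ("raw")
})

print(data.frame(variable = names(importance), raw = importance,
                 difference = importance - loss_full))
```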

This method is implemented in the model_parts() function.

vi_titinic_h2o   <- model_parts(explainer_titanic_h2o)
vi_titinic_keras <- model_parts(explainer_titanic_keras)

# We can compare all models using the generic plot() function.

plot(vi_titinic_h2o, vi_titinic_keras)

The length of each interval corresponds to the importance of the variable: a longer interval means a larger loss after permutation, so the variable is more important.

For a better comparison of the models, we can anchor the variable importance at 0 by setting type = "difference".

vi_titinic_h2o   <- model_parts(explainer_titanic_h2o, type="difference")
vi_titinic_keras <- model_parts(explainer_titanic_keras, type="difference")
plot(vi_titinic_h2o, vi_titinic_keras)

7 Variable response

As before, we use the explainers to better understand the relation between a variable and the model output, this time by computing model profiles: PDP plots and ALE plots.

7.1 Partial Dependence Plot

Partial Dependence Plots (PDP) are one of the most popular methods for exploration of the relation between a continuous variable and the model outcome.

mp_age_h2o    <- model_profile(explainer_titanic_h2o, variables = "Age")
mp_age_keras  <- model_profile(explainer_titanic_keras, variables = "Age")
plot(mp_age_h2o, mp_age_keras)
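A partial dependence profile averages the model's predictions over the data while one variable is fixed at each grid value. Below is a minimal base-R sketch on simulated data with a logistic regression. The names are hypothetical, and this is not DALEX's implementation.

```r
set.seed(1)
sim <- data.frame(age = runif(300, 1, 80), fare = rexp(300, 1 / 30))
sim$y <- rbinom(300, 1, plogis(1 - 0.03 * sim$age + 0.01 * sim$fare))
model <- glm(y ~ age + fare, data = sim, family = binomial)

grid <- seq(min(sim$age), max(sim$age), length.out = 20)
pdp <- sapply(grid, function(a) {
  modified <- sim
  modified$age <- a                                   # set age to the grid value in every row
  mean(predict(model, newdata = modified, type = "response"))  # average prediction
})
print(data.frame(age = grid, pdp = pdp))
```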

7.2 Accumulated Local Effects plot

Accumulated Local Effects (ALE) plots are an extension of PDP that is better suited to highly correlated variables.

mp_age_h2o    <- model_profile(explainer_titanic_h2o, variables = "Age", type = "accumulated")
mp_age_keras  <- model_profile(explainer_titanic_keras, variables = "Age", type = "accumulated")
plot(mp_age_h2o, mp_age_keras)
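Instead of setting every row to the same value as PDP does, ALE averages local prediction differences within small intervals of the variable. As a result, predictions are not evaluated at unrealistic combinations of correlated features. Below is a minimal base-R sketch of a first-order ALE curve on simulated, correlated data. The interval edges are taken from quantiles, and the curve is centered so that its count-weighted average is zero. All names are hypothetical, and this is not DALEX's implementation.

```r
set.seed(2)
n  <- 400
x1 <- runif(n, 0, 10)
x2 <- x1 + rnorm(n, sd = 1)                          # x2 strongly correlated with x1
y  <- rbinom(n, 1, plogis(-2 + 0.4 * x1 - 0.1 * x2))
sim   <- data.frame(x1, x2, y)
model <- glm(y ~ x1 + x2, data = sim, family = binomial)
f <- function(data) predict(model, newdata = data, type = "response")

K   <- 10
z   <- unique(quantile(sim$x1, probs = seq(0, 1, length.out = K + 1), names = FALSE))  # edges
bin <- cut(sim$x1, breaks = z, include.lowest = TRUE, labels = FALSE)                 # interval index

local_effects <- sapply(seq_len(length(z) - 1), function(k) {
  rows  <- sim[bin == k, ]
  lower <- rows; lower$x1 <- z[k]                    # move x1 to the lower edge of its interval
  upper <- rows; upper$x1 <- z[k + 1]                # ... and to the upper edge
  mean(f(upper) - f(lower))                          # average local change, x2 kept as observed
})

ale    <- c(0, cumsum(local_effects))                # accumulate effects along the edges
counts <- tabulate(bin, nbins = length(z) - 1)
ale    <- ale - sum(counts * (head(ale, -1) + tail(ale, -1)) / 2) / sum(counts)  # center
print(data.frame(x1 = z, ale = ale))
```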

8 Prediction understanding

The predict_parts() function with type = "break_down" is a wrapper around the iBreakDown package. A model prediction is visualized with Break Down Plots, which show the contribution of every variable present in the model. predict_parts() generates variable attributions for a selected prediction, and the generic plot() function shows these attributions.

pp_h2o   <- predict_parts(explainer_titanic_h2o, henry, type = "break_down")
pp_keras <- predict_parts(explainer_titanic_keras, henry_keras, type = "break_down")

plot(pp_h2o)

plot(pp_keras)
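Break-down attributions for one observation can be computed as follows. Start from the mean prediction over the data, fix the variables one at a time at the observation's values, and record how the mean prediction changes at each step. Below is a minimal base-R sketch with simulated data and a Henry-like observation (all hypothetical). The order of variables is fixed by hand here, while iBreakDown determines it automatically. The attributions do depend on that order.

```r
set.seed(3)
sim <- data.frame(age = runif(300, 1, 80), fare = rexp(300, 1 / 30), male = rbinom(300, 1, 0.6))
sim$y <- rbinom(300, 1, plogis(2 - 0.03 * sim$age + 0.01 * sim$fare - 2 * sim$male))
model <- glm(y ~ age + fare + male, data = sim, family = binomial)
mean_pred <- function(data) mean(predict(model, newdata = data, type = "response"))

new_obs   <- data.frame(age = 8, fare = 72, male = 1)   # hypothetical Henry-like passenger
var_order <- c("male", "age", "fare")                   # assumed order of variables

current  <- sim
baseline <- mean_pred(current)                          # "intercept": mean prediction on the data
contrib  <- setNames(numeric(length(var_order)), var_order)
prev     <- baseline
for (v in var_order) {
  current[[v]] <- new_obs[[v]]                          # fix v at the observation's value
  now <- mean_pred(current)
  contrib[v] <- now - prev                              # change attributed to v
  prev <- now
}
prediction <- predict(model, newdata = new_obs, type = "response")
print(c(intercept = baseline, contrib, prediction = unname(prediction)))
```

Once all variables are fixed, every row equals the observation. The intercept plus the contributions therefore adds up exactly to the model's prediction for it.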

9 Session info

sessionInfo()
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 18363)
## 
## Matrix products: default
## 
## locale:
## [1] LC_COLLATE=English_United States.1252 
## [2] LC_CTYPE=English_United States.1252   
## [3] LC_MONETARY=English_United States.1252
## [4] LC_NUMERIC=C                          
## [5] LC_TIME=English_United States.1252    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] h2o_3.32.0.1      fastDummies_1.6.3 titanic_0.1.0     keras_2.3.0.0    
## [5] DALEXtra_2.0      DALEX_2.0.1       dplyr_1.0.2      
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_1.0.5        pillar_1.4.6      compiler_4.0.2    ingredients_2.0  
##  [5] bitops_1.0-6      base64enc_0.1-3   tools_4.0.2       bit_4.0.4        
##  [9] zeallot_0.1.0     digest_0.6.25     jsonlite_1.7.1    evaluate_0.14    
## [13] lifecycle_0.2.0   tibble_3.0.3      gtable_0.3.0      lattice_0.20-41  
## [17] pkgconfig_2.0.3   rlang_0.4.7       Matrix_1.2-18     yaml_2.2.1       
## [21] xfun_0.19         stringr_1.4.0     knitr_1.30        generics_0.1.0   
## [25] vctrs_0.3.4       rappdirs_0.3.1    bit64_4.0.5       grid_4.0.2       
## [29] tidyselect_1.1.0  data.table_1.13.4 reticulate_1.18   glue_1.4.2       
## [33] R6_2.4.1          iBreakDown_1.3.1  rmarkdown_2.6     farver_2.0.3     
## [37] whisker_0.4       ggplot2_3.3.2     purrr_0.3.4       magrittr_2.0.1   
## [41] codetools_0.2-16  tfruns_1.4        scales_1.1.1      ellipsis_0.3.1   
## [45] htmltools_0.5.0   colorspace_1.4-1  labeling_0.3      tensorflow_2.2.0 
## [49] stringi_1.5.3     RCurl_1.98-1.2    munsell_0.5.0     crayon_1.3.4