Machine learning models may provide magical predictions,...
...but explaining how many machine learning models produce their predictions is not an easy task.
General trends I've noticed:
Many recent papers
Often machine learning and computer science perspectives
Lots of European authors
Key resources for this talk:
Setting the Stage
Methods
Model Agnostic
Random Forest Specific
Neural Network Specific
Concluding Thoughts
Additional Methods and Resources
A Cautionary Conclusion
There are no agreed-upon definitions...
Interpretable Machine Learning (Molnar 2020)
Methods for Interpreting and Understanding Deep Neural Networks (Montavon, Samek, and Muller 2017)
The Mythos of Model Interpretability (Lipton 2017)
Explaining Explanations: An Overview of Interpretability of Machine Learning (Gilpin et al. 2019)
My definitions (based on a conversation with Nick Street (University of Iowa))...
Interpretability = the ability to directly use the parameters of a model to understand the mechanism of how the model makes predictions
$$\hat{y} = \hat{\beta}_0 + \hat{\beta}_1 x_1 + \cdots + \hat{\beta}_p x_p$$
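As a quick illustration of direct interpretability (a minimal sketch using R's built-in `cars` data, not the bike data introduced later), the fitted coefficients of a linear model can be read straight off the model object:

```r
# Direct interpretability: the fitted parameters are the mechanism
mod = lm(dist ~ speed, data = cars)

# Each coefficient is the estimated change in stopping distance
# for a one-unit increase in speed
coef(mod)
summary(mod)$coefficients
```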
Explainability = the ability to use the model in an indirect manner to understand the relationships in the data captured by the model
Figure from LIME paper (Ribeiro 2016)
Stop Explaining Black Box Machine Learning Models for High Stakes Decisions and Use Interpretable Models Instead by Cynthia Rudin:
"Explanations must be wrong. They cannot have perfect fidelity with respect to the original model. If the explanation was completely faithful to what the original model computes, the explanation would equal the original model..."
"...it is possible that the explanation leaves out so much information that it makes no sense."
Rudin has worked on developing machine learning models with direct interpretability
Advantages
Can be applied to any model
Convenient if comparing various types of predictive models
Disadvantages
From Interpretable Machine Learning (Molnar)
Example data in Interpretable Machine Learning - can be accessed here
# Load the bike data (creates a data frame named bike)
load("data/bike.RData")

# Fit a random forest
library(dplyr)
bike_mod = randomForest::randomForest(
  x = bike %>% dplyr::select(-cnt),
  y = bike$cnt
)
| | season | yr | mnth | holiday | weekday | workingday | weathersit | temp | hum | windspeed | cnt | days_since_2011 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | SPRING | 2011 | JAN | NO HOLIDAY | SAT | NO WORKING DAY | MISTY | 8.175849 | 80.5833 | 10.749882 | 985 | 0 |
| 2 | SPRING | 2011 | JAN | NO HOLIDAY | SUN | NO WORKING DAY | MISTY | 9.083466 | 69.6087 | 16.652113 | 801 | 1 |
| 3 | SPRING | 2011 | JAN | NO HOLIDAY | MON | WORKING DAY | GOOD | 1.229108 | 43.7273 | 16.636703 | 1349 | 2 |
Purpose: Visualize marginal relationship between one (or two) predictors and model predictions
Estimated partial dependence function:
$$\hat{f}_{x_{int}}(x_{int}) = \frac{1}{n} \sum_{i=1}^{n} \hat{f}\left(x_{int}, x^{(i)}_{other}\right)$$
# Create a "predictor" object that
# holds the model and the data
library(iml)
library(ggplot2)
bike_pred = Predictor$new(
  model = bike_mod,
  data = bike
)

# Compute the partial dependence
# function for humidity and temperature
pdp = FeatureEffect$new(
  predictor = bike_pred,
  feature = c("hum", "temp"),
  method = "pdp"
)

# Create the partial dependence plot
pdp$plot() +
  viridis::scale_fill_viridis(option = "D") +
  labs(x = "Humidity", y = "Temperature", fill = "Prediction")
Partial dependence plot with two variables
Purpose: Similar to partial dependence plots, but consider each observation separately instead of taking an average.
Estimated individual conditional expectation function:
$$\hat{f}^{(i)}_{x_{int}}(x_{int}) = \hat{f}\left(x_{int}, x^{(i)}_{other}\right)$$
# Compute the ICE function
ice = FeatureEffect$new(
  predictor = bike_pred,
  feature = "temp",
  method = "ice"
)

# Create the plot
plot(ice)
ICE plot for temperature
"Sometimes it can be hard to tell whether the ICE curves differ between individuals because they start at different predictions. A simple solution is to center the curves at a certain point in the feature and display only the difference in the prediction to this point." - Molnar
# Center the ICE function for
# temperature at the minimum
# temperature and include the pdp
ice_centered = FeatureEffect$new(
  predictor = bike_pred,
  feature = "temp",
  center.at = min(bike$temp),
  method = "pdp+ice"
)

# Create the plot
plot(ice_centered)
Accumulated Local Effects (ALE) Plots
Apley and Zhu (2016)
Feature Interaction Plots
Code for the plots on the previous slide
Accumulated Local Effects (ALE) Plot
# Compute the ALEs
ale = FeatureEffect$new(
  predictor = bike_pred,
  feature = c("hum", "temp"),
  method = "ale",
  grid.size = 40
)

# Plot the ALEs
plot(ale) +
  scale_x_continuous("Relative Humidity") +
  scale_y_continuous("Temperature") +
  viridis::scale_fill_viridis(option = "D") +
  labs(fill = "ALE")
Feature Interaction Plot
# Compute the interaction metrics
int = Interaction$new(
  predictor = bike_pred,
  grid.size = 100,
  feature = "season"
)

# Plot the interaction metrics
plot(int) +
  scale_x_continuous("2-way interaction strength")
Provide a nice overview of the predictions across all of the features
PCP plot with bike data (made with ggpcp)
Code for the plot on the previous slide
# Determine order of features
# (bike_vi holds the random forest feature importance values
#  computed in the random forest section)
bike_ft_ordered = bike_vi %>%
  arrange(desc(IncNodePurity)) %>%
  pull(var)

# Create the pcp
bike %>%
  mutate(rf_pred = predict(bike_mod)) %>%
  ggplot(aes(color = rf_pred)) +
  ggpcp::geom_pcp(aes(vars = dplyr::vars(all_of(bike_ft_ordered))), alpha = 0.4) +
  viridis::scale_color_viridis(option = "D") +
  labs(x = "Features ordered by feature importance (left to right)",
       y = "Standardized Feature Value",
       color = "Random Forest Prediction") +
  theme(legend.position = "bottom") +
  guides(color = guide_colourbar(barwidth = 15))
The rfviz R package provides interactive parallel coordinate plots for random forest models, and the approach could be extended to other machine learning models.
Background
Idea originally proposed by Breiman (2001) for random forests
Adapted to be used with all models by Fisher, Rudin, and Dominici (2018)
Concept
Measure feature importance by seeing how much the prediction error is affected when a feature is permuted
important feature: one that affects the prediction error when changed
non-important feature: one that does not affect the prediction error when changed
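A rough sketch of the permutation idea for a single feature, assuming the bike data and bike_mod random forest from earlier (the iml::FeatureImp code below does this properly, with repetitions and for every feature):

```r
# Illustrate permutation importance by hand for one feature
set.seed(2021)

mae = function(pred) mean(abs(bike$cnt - pred))

# Prediction error on the original data
base_mae = mae(predict(bike_mod, newdata = bike))

# Prediction error after shuffling (permuting) temperature
bike_perm = bike
bike_perm$temp = sample(bike_perm$temp)
perm_mae = mae(predict(bike_mod, newdata = bike_perm))

# Importance of temp: ratio of permuted error to original error
perm_mae / base_mae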
Permutation feature importance of bike data random forest
# Create the predictor
# (seemingly FeatureImp requires y)
bike_pred = Predictor$new(
  model = bike_mod,
  data = bike,
  y = bike$cnt
)

# Compute the feature importance values
bike_imp = FeatureImp$new(
  predictor = bike_pred,
  loss = 'mae'
)

# Plot the feature importance values
plot(bike_imp)
Point = median permutation importance
Bars = 5th and 95th permutation importance quantiles
Permutation based feature importance method that returns p-values
Example from the paper comparing Gini importance values to their permutation feature importance method with p-values
Three additional measures for feature importance:
Individual Conditional Importance (ICI)
Partial Importance (PI)
Shapley Feature Importance (SFIMP)
ICI and PI available in the featureImportance R package
Idea: Use an interpretable model to explain a black-box model
Procedure:
Train a black-box model
Obtain predictions from black-box model on a set of data (training data or other)
Fit an interpretable model (linear regression model, tree, etc)
black-box predictions ~ predictor variables
Cautions: How to know if the global surrogate is a good enough approximation of the complex model?
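A minimal sketch of this procedure for the bike random forest, using a regression tree from rpart as the surrogate (variable names here are illustrative; the fidelity check at the end is one way to address the caution above):

```r
# Fit a regression tree to the random forest's predictions
library(rpart)
library(dplyr)

surrogate_data = bike %>%
  dplyr::select(-cnt) %>%
  mutate(rf_pred = predict(bike_mod))

surrogate_tree = rpart(
  rf_pred ~ .,
  data = surrogate_data,
  control = rpart.control(maxdepth = 3)
)

# Fidelity check: how much of the variation in the black-box
# predictions does the surrogate capture?
cor(predict(surrogate_tree), surrogate_data$rf_pred)^2
```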
Using a classification tree as the global surrogate for the random forest model fit to the sine data
LIME = Local Interpretable Model-Agnostic Explanations
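Since the deck uses the iml package elsewhere, here is a hedged sketch of a LIME-style local surrogate with iml's LocalModel (the observation index 285 is chosen to match the Shapley example later; k is the number of features kept in the local model):

```r
# Fit a local (LIME-style) surrogate for one observation
library(iml)

bike_x = bike[names(bike) != "cnt"]
predictor = Predictor$new(model = bike_mod, data = bike_x)

lime_mod = LocalModel$new(
  predictor = predictor,
  x.interest = bike_x[285, ],
  k = 3
)

# Inspect and plot the local explanation
lime_mod$results
plot(lime_mod)
```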
Idea: Use game theory to determine contributions of predictor variables to one prediction of interest
Game Theory Connection: Shapley values are "a method for assigning payouts to players depending on their contribution to the total payout."
| Game Theory Term | Machine Learning Meaning |
|---|---|
| collaborative game | the machine learning model's prediction for one observation |
| players | predictor variables |
| payout | contribution of a predictor variable to the prediction |
| gain | actual prediction - average prediction for all instances |
Interpretation: "The value of the jth feature contributed ϕj to the prediction of this particular instance compared to the average prediction for the dataset."
# Select obs of interest and prepare data
x_int = bike[names(bike) != 'cnt'][285, ]

# Compute prediction values
avg_pred = mean(predict(bike_mod))
actual_pred = predict(bike_mod, newdata = bike[names(bike) != 'cnt'][285, ])
diff_pred = actual_pred - avg_pred

# Compute shapley values
predictor = Predictor$new(model = bike_mod, data = bike[names(bike) != 'cnt'])
shapley = Shapley$new(predictor = predictor, x.interest = x_int)

# Create the plot
plot(shapley) +
  scale_y_continuous("Feature value contribution") +
  ggtitle(sprintf("Actual prediction: %.0f\nAverage prediction: %.0f\nDifference: %.0f",
                  actual_pred, avg_pred, diff_pred))
Shapley values for one observation from the bike rental random forest
Idea: Aggregation of many trees (bootstrap data and randomly select predictors for each tree)
Mean decrease in impurity (gini importance): measures the average improvement in node purity for a predictor variable
# Extract the importance values
bike_rfimp <- bike_mod$importance

# Put the feature importance in a df
bike_vi <- data.frame(var = rownames(bike_rfimp), bike_rfimp) %>%
  arrange(IncNodePurity)

# Create a feature importance plot
bike_vi %>%
  mutate(var = factor(x = var, levels = bike_vi$var)) %>%
  ggplot(aes(x = var, y = IncNodePurity)) +
  geom_col() +
  coord_flip() +
  labs(x = "Feature")
Bike random forest feature importance plot
Cut points from all trees for two predictor variables
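One hedged way to recreate this kind of plot for a single feature (temp is just an example) is to pull the split points out of every tree with randomForest::getTree; repeating for a second feature gives the two-variable version shown on the slide:

```r
# Collect the cut points used for temp across all trees in the forest
library(purrr)
library(dplyr)

temp_cuts = map_dfr(seq_len(bike_mod$ntree), function(k) {
  randomForest::getTree(bike_mod, k = k, labelVar = TRUE) %>%
    filter(`split var` == "temp") %>%
    transmute(tree = k, cut_point = `split point`)
})

# Quick look at where the forest splits on temperature
hist(temp_cuts$cut_point, main = "Cut points for temp", xlab = "temp")
```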
R package for visually exploring random forests fit using randomForestSRC or randomForest
library(ggRandomForests)
Out-of-bag errors versus number of trees
plot(gg_error(bike_mod)) + theme_gray()
Variable importance plot
plot(gg_vimp(bike_mod)) + theme_gray()
Previously mentioned...R package for interacting with parallel coordinate plots for random forests
# Prepare data
rfprep <- rfviz::rf_prep(x = bike[names(bike) != "cnt"], y = bike$cnt)

# View plots
rfviz::rf_viz(rfprep, input = TRUE, imp = TRUE, cmd = TRUE, hl_color = 'black')
Method that creates plots similar to partial dependence plots
From the paper:
"We suggest to first use feature contributions, a method to decompose trees by splitting features, and then subsequently perform projections. The advantages of forest floor over partial dependence plots is that interactions are not masked by averaging."
R package: forestFloor
Forest floor plots (figure from the paper)
Idea: Combination of many non-linear regression models
Image source
Idea: Determine values of predictor variables that maximize activation functions at a specific "location" in the neural network
Formula Version: For a node in the network with activation $a(\mathbf{x})$, find the input that maximizes it: $\mathbf{x}^\star = \arg\max_{\mathbf{x}} a(\mathbf{x})$
Image from Olah, Mordvintsev, and Schubert (2017)
Purpose: To identify the features that are important for making a prediction for a single observation
Concept: Uses the back-propagation algorithm to compute gradient values for each predictor variable, which indicate how much that predictor variable influences the prediction
In practice:
Commonly used with convolutional neural networks to identify important pixels in an image
Many algorithms for creating saliency maps
Image from Simonyan, Vedaldi, and Zisserman (2014)
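The figure above uses image data; as a loose tabular analogue (entirely an assumption, not from the talk), the gradient of a small neural network's prediction with respect to each numeric input can be approximated with finite differences rather than back-propagation:

```r
# Fit a tiny neural network on a few numeric features of the bike data
library(nnet)
set.seed(2021)
nn_mod = nnet(cnt ~ temp + hum + windspeed, data = bike,
              size = 5, linout = TRUE, trace = FALSE)

# Finite-difference "saliency" for one observation:
# how much does the prediction change when each feature is nudged?
x0 = bike[285, c("temp", "hum", "windspeed")]
eps = 1e-3

sapply(names(x0), function(v) {
  x_eps = x0
  x_eps[[v]] = x_eps[[v]] + eps
  as.numeric(predict(nn_mod, newdata = x_eps) - predict(nn_mod, newdata = x0)) / eps
})
```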
Idea: Make use of the Grand Tour to visualize behaviors of neural networks
Model Agnostic Methods
Scoped Rules (Anchors) Ribeiro, Singh, Guestrin (2018)
Sensitivity Analyses Cortez and Embrechts (2013)
Example-Based Explanations
Counterfactual examples
Adversarial examples
Prototypes and criticisms
Influential instances
Model Specific
Neural networks: See articles on Distill
Tree ensembles:
General Model Viz
Many more....
Additional Resources for Overviews of Explainable Machine Learning
Mohseni, Zarei, and Ragan (2019)
Review of method types
model agnostic versus model specific
global versus local explanations
static versus interactive
models versus metrics
Good News
many methods to try out
lots of research opportunities
opportunity for creating useful visualizations
Cautions
this is a relatively new field
unsure which are the most trusted methods
a seemingly simple method may not be so simple
model based methods
add an additional layer of complexity to an already complex situation
almost seems naive to expect a simple model to capture the complex relationship in a black-box model
Machine learning models may provide magical predictions,...