I am interested in how LIME performs on other random forest models. This journal applies a range of LIME input options (both bin-based simulation methods, several numbers of bins, and several Gower exponents) to a random forest model fit to the iris data.
# Load libraries
library(caret)
library(e1071)
library(furrr)
library(future)
library(gretchenalbrecht)
library(limeaid)
# Source functions
source("../../code/helper_functions.R")
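The iris data are randomly split into training and testing datasets such that all of the species are represented in the testing data. Four cases are sampled from each species for testing, and the remaining cases are used for training.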
# Set a seed so that the train/test split is reproducible
set.seed(20190311)
# Number of cases from each species to hold out for testing
n_test_per_species <- 4
# Randomly select four cases from within each of the three species of irises.
# Row indices are split by species (in factor level order: setosa,
# versicolor, virginica), so the sampling does not rely on hard-coded row
# ranges. For the built-in iris data, whose rows are ordered by species, this
# draws the same cases as sampling from 1:50, 51:100, and 101:150 in turn.
selected <- unlist(
  lapply(
    split(seq_len(nrow(iris)), iris$Species),
    sample,
    size = n_test_per_species
  ),
  use.names = FALSE
)
# Determine the case numbers that were not selected
cases <- seq_len(nrow(iris))
not_selected <- setdiff(cases, selected)
# Split the data into training (unselected rows) and testing (selected rows)
# parts; the testing cases are numbered 1, 2, ... in the order selected
iris_train <- iris[-selected, ]
iris_test <- iris[selected, ] %>% mutate(case = seq_len(n()))
A random forest model is fit to the iris training data. The predictions from the model for the testing data are shown in the table below alongside the observed values. The model correctly predicts 11 of the 12 testing cases; one virginica case is predicted to be versicolor.
# Fit a random forest (caret method "rf") with every non-Species column of the
# training data as a predictor and Species as the response
iris_model <- train(
  x = select(iris_train, -Species),
  y = iris_train$Species,
  method = "rf"
)
# Predict the species of each testing case; the response and the `case`
# bookkeeping column are dropped so only the predictors are passed in
iris_model_predict <- predict(iris_model, select(iris_test, -Species, -case))
# Pair each observed species with the model's prediction
iris_test_obs_pred <- data.frame(
  Observed = iris_test$Species,
  Predicted = iris_model_predict
)
# Display the observed vs. predicted table with centered columns
knitr::kable(iris_test_obs_pred, align = "c")
| Observed | Predicted |
|---|---|
| setosa | setosa |
| setosa | setosa |
| setosa | setosa |
| setosa | setosa |
| versicolor | versicolor |
| versicolor | versicolor |
| versicolor | versicolor |
| versicolor | versicolor |
| virginica | virginica |
| virginica | versicolor |
| virginica | virginica |
| virginica | virginica |
LIME is applied using both bin-based simulation methods (quantile bins and equal bins) with 2 to 6 bins, and with Gower distance exponents of 0.5, 1, and 10.
# Location of the cached LIME explanations
file_iris_explanations <- "../../../data/iris_explanations.rds"
# Implement and save or load LIME explanations.
# NOTE(review): the cache is keyed only on the file's existence. If any of the
# apply_lime() arguments below change, delete the .rds file so that the
# explanations are regenerated instead of silently reusing stale results.
if (!file.exists(file_iris_explanations)) {
  # Apply LIME to the testing cases for the "virginica" label, using quantile
  # and equal-width bins (2 to 6 bins) and Gower powers of 0.5, 1, and 10.
  # Only the predictor columns are passed for both train and test.
  iris_explanations <-
    apply_lime(
      train = iris_train %>% select(-Species),
      test = iris_test %>% select(-case, -Species),
      model = iris_model,
      label = "virginica",
      n_features = 3,
      sim_method = c("quantile_bins", "equal_bins"),
      nbins = 2:6,
      feature_select = "auto",
      dist_fun = "gower",
      kernel_width = NULL,
      gower_pow = c(0.5, 1, 10),
      return_perms = FALSE,
      all_fs = FALSE,
      seed = 20190914
    )
  # Save the explanations so later knits can reuse them
  saveRDS(object = iris_explanations, file = file_iris_explanations)
} else {
  # Load the previously saved explanations
  iris_explanations <- readRDS(file_iris_explanations)
}
All of the results in this section are in terms of the species virginica.
# Feature heatmap of the LIME explanations (limeaid::plot_feature_heatmap),
# with cases ordered by the "PCA" ordering method
plot_feature_heatmap(iris_explanations$explain, order_method = "PCA") +
  scale_fill_gretchenalbrecht(palette = "last_rays", discrete = TRUE) +
  scale_color_gretchenalbrecht(
    palette = "last_rays",
    discrete = TRUE,
    reverse = TRUE
  ) +
  theme_bw() +
  # Hide the case labels and tick marks on the y-axis
  theme(axis.text.y = element_blank(), axis.ticks.y = element_blank()) +
  # White facet strips, horizontal right-hand strip text, legend at the bottom
  theme(
    strip.background = element_rect(color = "white", fill = "white"),
    strip.text.y.right = element_text(angle = 0),
    legend.position = "bottom"
  ) +
  labs(
    y = "Case",
    color = "Complex Model Feature",
    fill = "Complex Model Feature"
  ) +
  guides(fill = guide_legend(nrow = 3))
# Plot the explanation metrics from limeaid::plot_metrics(), with lines added
# (add_lines = TRUE) drawn at alpha 0.75
plot_metrics(iris_explanations$explain, add_lines = TRUE, line_alpha = 0.75) +
  theme_bw() +
  # Place facet strips outside the axes on a plain white background
  theme(
    strip.placement = "outside",
    strip.background = element_rect(color = "white", fill = "white")
  )
# Record the R version, platform, and package versions used to knit this
# journal, for reproducibility
sessionInfo()
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Catalina 10.15.6
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] assertthat_0.2.1 tree_1.0-40 forcats_0.5.0
## [4] stringr_1.4.0 dplyr_1.0.2 purrr_0.3.4
## [7] readr_1.3.1 tidyr_1.1.2 tibble_3.0.3
## [10] tidyverse_1.3.0 limeaid_0.0.1 gretchenalbrecht_0.1.0
## [13] furrr_0.1.0 future_1.18.0 e1071_1.7-3
## [16] caret_6.0-86 ggplot2_3.3.2.9000 lattice_0.20-41
##
## loaded via a namespace (and not attached):
## [1] colorspace_1.4-1 ellipsis_0.3.1 class_7.3-17
## [4] fs_1.5.0 rstudioapi_0.11 farver_2.0.3
## [7] listenv_0.8.0 fansi_0.4.1 prodlim_2019.11.13
## [10] lubridate_1.7.9 xml2_1.3.2 codetools_0.2-16
## [13] splines_4.0.2 lime_0.5.1 knitr_1.29
## [16] shinythemes_1.1.2 jsonlite_1.7.1 pROC_1.16.2
## [19] broom_0.7.0 cluster_2.1.0 dbplyr_1.4.4
## [22] shiny_1.5.0 compiler_4.0.2 httr_1.4.2
## [25] backports_1.1.10 Matrix_1.2-18 fastmap_1.0.1
## [28] cli_2.0.2 later_1.1.0.1 htmltools_0.5.0
## [31] tools_4.0.2 gtable_0.3.0 glue_1.4.2
## [34] reshape2_1.4.4 Rcpp_1.0.5 cellranger_1.1.0
## [37] vctrs_0.3.4 gdata_2.18.0 nlme_3.1-148
## [40] iterators_1.0.12 timeDate_3043.102 gower_0.2.2
## [43] xfun_0.17 globals_0.12.5 testthat_2.3.2
## [46] rvest_0.3.6 mime_0.9 lifecycle_0.2.0
## [49] gtools_3.8.2 dendextend_1.14.0 MASS_7.3-51.6
## [52] scales_1.1.1 ipred_0.9-9 TSP_1.1-10
## [55] hms_0.5.3 promises_1.1.1 parallel_4.0.2
## [58] yaml_2.2.1 gridExtra_2.3 rpart_4.1-15
## [61] stringi_1.5.3 highr_0.8 gclus_1.3.2
## [64] randomForest_4.6-14 foreach_1.5.0 checkmate_2.0.0
## [67] seriation_1.2-8 caTools_1.18.0 lava_1.6.7
## [70] shape_1.4.4 rlang_0.4.7 pkgconfig_2.0.3
## [73] bitops_1.0-6 evaluate_0.14 labeling_0.3
## [76] recipes_0.1.13 htmlwidgets_1.5.1 tidyselect_1.1.0
## [79] plyr_1.8.6 magrittr_1.5 R6_2.4.1
## [82] gplots_3.0.4 generics_0.0.2 DBI_1.1.0
## [85] pillar_1.4.6 haven_2.3.1 withr_2.2.0
## [88] survival_3.1-12 nnet_7.3-14 modelr_0.1.8
## [91] crayon_1.3.4 KernSmooth_2.23-17 rmarkdown_2.3
## [94] viridis_0.5.1 grid_4.0.2 readxl_1.3.1
## [97] data.table_1.13.0 blob_1.2.1 ModelMetrics_1.2.2.2
## [100] reprex_0.3.0 digest_0.6.25 xtable_1.8-4
## [103] httpuv_1.5.4 stats4_4.0.2 munsell_0.5.0
## [106] glmnet_4.0-2 registry_0.5-1 viridisLite_0.3.0