I am interested in how LIME performs on other random forest models. This journal applies all of the LIME input options to a random forest model fit to the iris data.

# Load libraries
library(caret)
library(e1071)
library(furrr)
library(future)
library(gretchenalbrecht)
library(limeaid)

# Source functions
source("../../code/helper_functions.R")

Training and Testing Data

Iris is randomly split into training and testing datasets such that all of the species are represented in the testing data.

# Fix the random number generator state so the split is reproducible
set.seed(20190311)

# Draw four test cases from each of the three species of irises, which
# occupy rows 1-50, 51-100 and 101-150 of the data
selected <- c(
  sample(1:50, 4),
  sample(51:100, 4),
  sample(101:150, 4)
)

# Row numbers of every case that was not drawn for testing
cases <- 1:150
not_selected <- setdiff(cases, selected)

# Training data: every case that was not drawn for testing
iris_train <- iris[not_selected, ]

# Testing data: the drawn cases, numbered 1 to n in a new `case` column
iris_test <- iris[selected, ] %>%
  mutate(case = seq_len(n()))

Random Forest Model

A random forest model is fit to the iris training data. The predictions from the model for the testing data are shown in the table below with the actual observed values. The model gets all but one of the predictions correct: one virginica case is predicted to be versicolor.

# Random forest model run on the iris training data: train() is called with
# method 'rf', every column except Species as the predictors and Species as
# the response. No tuning or resampling arguments are passed, so train() uses
# its defaults for those.
# NOTE(review): the fit presumably depends on the RNG state left by the
# set.seed() call made before the train/test split -- confirm before
# reordering any code that draws random numbers ahead of this call.
iris_model <- train(x = iris_train %>% select(-Species), 
                    y = iris_train %>% pull(Species), 
                    method = 'rf') 

# Predict the species of each testing case; the Species and case columns are
# dropped so that only the remaining feature columns are passed to the model
iris_model_predict <- predict(
  iris_model,
  select(iris_test, -Species, -case)
)

# Pair each observed species with the model's prediction
iris_test_obs_pred <- data.frame(
  Observed = iris_test$Species,
  Predicted = iris_model_predict
)

# Render the comparison as a table with centered columns
knitr::kable(iris_test_obs_pred, align = "c")
Observed Predicted
setosa setosa
setosa setosa
setosa setosa
setosa setosa
versicolor versicolor
versicolor versicolor
versicolor versicolor
versicolor versicolor
virginica virginica
virginica versicolor
virginica virginica
virginica virginica

Apply LIME

LIME is applied using the two bin based simulation methods (quantile bins and equal bins) with 2 to 6 bins each.

# Location of the cached LIME explanations
file_iris_explanations <- "../../../data/iris_explanations.rds"

# Reuse the cached explanations when the file exists; otherwise compute the
# explanations and cache them for the next run
if (file.exists(file_iris_explanations)) {

  # Load the previously saved explanations
  iris_explanations <- readRDS(file_iris_explanations)

} else {

  # Apply LIME to the testing cases for the label "virginica", crossing the
  # two bin based simulation methods with 2 to 6 bins and three Gower powers
  iris_explanations <-
    apply_lime(
      train = iris_train %>% select(names(iris %>% select(-Species))),
      test = iris_test %>% select(-case, -Species),
      model = iris_model,
      label = "virginica",
      n_features = 3,
      sim_method = c("quantile_bins", "equal_bins"),
      nbins = 2:6,
      feature_select = "auto",
      dist_fun = "gower",
      kernel_width = NULL,
      gower_pow = c(0.5, 1, 10),
      return_perms = FALSE,
      all_fs = FALSE,
      seed = 20190914
    )

  # Cache the explanations so later runs can skip the computation
  saveRDS(iris_explanations, file = file_iris_explanations)

}

Visualizing LIME Results

All of the results in this section are in terms of the species virginica.

Feature Heatmap

# Heatmap built from the LIME explanations, with order_method = "PCA" passed
# to plot_feature_heatmap()
plot_feature_heatmap(iris_explanations$explain, order_method = "PCA") +
  # Axis and legend titles; fill and color share one title
  labs(
    y = "Case",
    color = "Complex Model Feature",
    fill = "Complex Model Feature"
  ) +
  # Discrete "last_rays" palette for fill, and the same palette reversed for
  # color
  scale_fill_gretchenalbrecht(palette = "last_rays", discrete = TRUE) +
  scale_color_gretchenalbrecht(
    palette = "last_rays",
    discrete = TRUE,
    reverse = TRUE
  ) +
  # Lay the fill legend out over three rows
  guides(fill = guide_legend(nrow = 3)) +
  # Black-and-white base theme, then hide the y axis text and ticks, blank
  # out the strip backgrounds, and move the legend below the panels
  theme_bw() +
  theme(
    legend.position = "bottom",
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    strip.text.y.right = element_text(angle = 0),
    strip.background = element_rect(color = "white", fill = "white")
  )

Assessment Metric Plot

# Assessment metric plot built from the LIME explanations, with lines added
# at 75% opacity
plot_metrics(iris_explanations$explain, add_lines = TRUE, line_alpha = 0.75) +
  # Black-and-white base theme with white strips placed outside the axes
  theme_bw() +
  theme(
    strip.placement = "outside",
    strip.background = element_rect(color = "white", fill = "white")
  )

Session Info

# Print the R version, platform and package versions of the current session
sessionInfo()
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Catalina 10.15.6
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] assertthat_0.2.1       tree_1.0-40            forcats_0.5.0         
##  [4] stringr_1.4.0          dplyr_1.0.2            purrr_0.3.4           
##  [7] readr_1.3.1            tidyr_1.1.2            tibble_3.0.3          
## [10] tidyverse_1.3.0        limeaid_0.0.1          gretchenalbrecht_0.1.0
## [13] furrr_0.1.0            future_1.18.0          e1071_1.7-3           
## [16] caret_6.0-86           ggplot2_3.3.2.9000     lattice_0.20-41       
## 
## loaded via a namespace (and not attached):
##   [1] colorspace_1.4-1     ellipsis_0.3.1       class_7.3-17        
##   [4] fs_1.5.0             rstudioapi_0.11      farver_2.0.3        
##   [7] listenv_0.8.0        fansi_0.4.1          prodlim_2019.11.13  
##  [10] lubridate_1.7.9      xml2_1.3.2           codetools_0.2-16    
##  [13] splines_4.0.2        lime_0.5.1           knitr_1.29          
##  [16] shinythemes_1.1.2    jsonlite_1.7.1       pROC_1.16.2         
##  [19] broom_0.7.0          cluster_2.1.0        dbplyr_1.4.4        
##  [22] shiny_1.5.0          compiler_4.0.2       httr_1.4.2          
##  [25] backports_1.1.10     Matrix_1.2-18        fastmap_1.0.1       
##  [28] cli_2.0.2            later_1.1.0.1        htmltools_0.5.0     
##  [31] tools_4.0.2          gtable_0.3.0         glue_1.4.2          
##  [34] reshape2_1.4.4       Rcpp_1.0.5           cellranger_1.1.0    
##  [37] vctrs_0.3.4          gdata_2.18.0         nlme_3.1-148        
##  [40] iterators_1.0.12     timeDate_3043.102    gower_0.2.2         
##  [43] xfun_0.17            globals_0.12.5       testthat_2.3.2      
##  [46] rvest_0.3.6          mime_0.9             lifecycle_0.2.0     
##  [49] gtools_3.8.2         dendextend_1.14.0    MASS_7.3-51.6       
##  [52] scales_1.1.1         ipred_0.9-9          TSP_1.1-10          
##  [55] hms_0.5.3            promises_1.1.1       parallel_4.0.2      
##  [58] yaml_2.2.1           gridExtra_2.3        rpart_4.1-15        
##  [61] stringi_1.5.3        highr_0.8            gclus_1.3.2         
##  [64] randomForest_4.6-14  foreach_1.5.0        checkmate_2.0.0     
##  [67] seriation_1.2-8      caTools_1.18.0       lava_1.6.7          
##  [70] shape_1.4.4          rlang_0.4.7          pkgconfig_2.0.3     
##  [73] bitops_1.0-6         evaluate_0.14        labeling_0.3        
##  [76] recipes_0.1.13       htmlwidgets_1.5.1    tidyselect_1.1.0    
##  [79] plyr_1.8.6           magrittr_1.5         R6_2.4.1            
##  [82] gplots_3.0.4         generics_0.0.2       DBI_1.1.0           
##  [85] pillar_1.4.6         haven_2.3.1          withr_2.2.0         
##  [88] survival_3.1-12      nnet_7.3-14          modelr_0.1.8        
##  [91] crayon_1.3.4         KernSmooth_2.23-17   rmarkdown_2.3       
##  [94] viridis_0.5.1        grid_4.0.2           readxl_1.3.1        
##  [97] data.table_1.13.0    blob_1.2.1           ModelMetrics_1.2.2.2
## [100] reprex_0.3.0         digest_0.6.25        xtable_1.8-4        
## [103] httpuv_1.5.4         stats4_4.0.2         munsell_0.5.0       
## [106] glmnet_4.0-2         registry_0.5-1       viridisLite_0.3.0