This journal documents some of the computational issues I have run into, along with my “solutions” or “conclusions”.

# Load libraries
library(caret)
library(lime)
library(plotly)
library(randomForest)
library(tidyverse)

Plotly Dead Space Issue

I found an issue with plotly. Below is the reproducible example that I sent to Carson. He was able to fix the issue, so the example below no longer has the problem.

# Example with Plotly to show dead space in heatmap

# System information
# R version 3.5.1 (2018-07-02)
# Platform: x86_64-apple-darwin15.6.0 (64-bit)
# Running under: macOS  10.14

# Load libraries
# library(plotly) # version 4.8.0
# library(ggplot2) # version 3.1.0
# library(tidyr) # version 0.8.2

# Example from: https://plot.ly/r/heatmaps/ -------------------------------------------

# Create dataset
m <- matrix(rnorm(9), nrow = 3, ncol = 3)

# Create plotly heatmap - no dead space to be found
plot_ly(x = c("a", "b", "c"), y = c("d", "e", "f"), z = m, type = "heatmap")

# Example using ggplotly function -----------------------------------------------------

# Reshape the data for ggplot
m_gathered <- data.frame(m) %>%
  gather(key = column) %>%
  mutate(row = factor(rep(c("X1", "X2", "X3"), 3))) %>%
  select(column, row, value)

# Create ggplot heatmap
p <- ggplot(m_gathered, aes(x = column, y = row, fill = value)) +
  geom_tile()

# Apply plotly to ggplot heatmap - dead space in the middle of (X1, X1)
ggplotly(p)

# Carson's suggested fix for now
style(ggplotly(p), hoverinfo = "skip", traces = 2)

# Create ggplot heatmap without a legend
p_nolegend <- ggplot(m_gathered, aes(x = column, y = row, fill = value)) +
  geom_tile() +
  theme(legend.position = "none") 

# Apply plotly to ggplot heatmap - the dead space is gone!
ggplotly(p_nolegend)

Caret vs RandomForest

The lime function in the lime package is set up to work with models from specific packages. For example, it works with a random forest fit using the caret package, but it is not set up to work with a random forest fit directly using the randomForest package. I found a suggestion to apply the as_classifier function from the lime package to a model fit using randomForest so that the lime function will accept it, and this seemed to work. However, I wanted to check that the lime results from a model fit using caret agree with those from the same model fit using randomForest. To do this, I used the iris data. The code below fits the two models (rf_model and caret_model) and then applies the lime and explain functions to both.

# Code for comparing the output from LIME when the model is 
# fit with caret and randomForest
# Last Updated: 2018/11/13

## Set up -----------------------------------------------------------------------------

# Split up the data set
iris_test <- iris[1:5, 1:4]
iris_train <- iris[-(1:5), 1:4]
iris_lab <- iris[[5]][-(1:5)]

## LIME with caret --------------------------------------------------------------------

# Create Random Forest model on iris data
caret_model <- train(iris_train, iris_lab, method = 'rf')

# Create an explainer object
caret_explainer <- lime::lime(iris_train, caret_model)

# Explain new observation
caret_explanation <- lime::explain(iris_test, caret_explainer, n_labels = 1, n_features = 4)

## LIME with randomForest -------------------------------------------------------------

rf_model <- randomForest(iris_train, iris_lab)

# Create an explainer object
rf_explainer <- lime::lime(iris_train, model = as_classifier(rf_model))

# Explain new observation
rf_explanation <- lime::explain(iris_test, rf_explainer, n_labels = 1, n_features = 4)

To compare the lime explanations from the two models, I extracted the \(R^2\) value from the simple model, the simple model intercept, the simple model prediction, the feature values, and the feature weights from both explanation datasets. I computed the MSE between the two models for each of these variables, and I plotted the values from one model against the other in scatterplots. Since lime is based on random permutations, I would not expect the values from the two models to be identical, but I would like them to be close. The MSEs are all close to zero, and the plots suggest that the values follow the 1-1 line reasonably well. We concluded that the two versions of the explanations are reasonably exchangeable.

## Comparisons -----------------------------------------------------------------------

# Grab the numeric caret explanation variables of interest
caret_numeric <- caret_explanation %>%
  select(model_r2, model_intercept, model_prediction, feature_value, feature_weight) %>%
  gather(key = "variable", value = "caret_value")

# Grab the numeric randomForest explanation variables of interest
rf_numeric <- rf_explanation %>%
  select(model_r2, model_intercept, model_prediction, feature_value, feature_weight) %>%
  gather(key = "variable", value = "rf_value")

# Join the two
lime_numeric <- caret_numeric %>%
  mutate(rf_value = rf_numeric$rf_value, 
         variable = factor(variable))

# Look at the MSEs for the variables
lime_numeric %>%
  group_by(variable) %>%
  summarise(MSE = sum((caret_value - rf_value)^2) / dim(lime_numeric)[1])
## # A tibble: 5 x 2
##   variable                MSE
##   <fct>                 <dbl>
## 1 feature_value    0         
## 2 feature_weight   0.0000178 
## 3 model_intercept  0.00000992
## 4 model_prediction 0.00000943
## 5 model_r2         0.0000294
# Scatterplots of the randomForest versus caret variable values
ggplot(lime_numeric, aes(x = caret_value, y = rf_value)) + 
  geom_point() + 
  facet_wrap( ~ variable, scales = "free") + 
  geom_abline(intercept = 0, slope = 1)
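One detail worth checking in the MSE calculation above is the denominator. Inside summarise, dim(lime_numeric)[1] is the number of rows in the whole long data frame, not the number of rows in the current group, so each reported value is a group's sum of squared differences divided by the total row count. Since every variable contributes the same number of rows here, this should only rescale the MSEs by a constant factor (the number of variables) and leave the conclusion unchanged. The base R sketch below uses a made-up data frame (toy, sq_err, mse_group, and mse_total are all hypothetical names) to show how the two denominators differ.

```r
# Toy long-format data: two variables with different numbers of rows
toy <- data.frame(
  variable    = c("a", "a", "b", "b", "b", "b"),
  caret_value = c(1, 2, 0, 0, 0, 0),
  rf_value    = c(2, 4, 1, 1, 1, 1)
)

# Squared difference for every row
sq_err <- (toy$caret_value - toy$rf_value)^2

# Per-variable MSE: each group's sum divided by that group's row count
mse_group <- tapply(sq_err, toy$variable, mean)

# Each group's sum divided by the total row count (the pattern used above)
mse_total <- tapply(sq_err, toy$variable, sum) / nrow(toy)

mse_group
mse_total
```

With dplyr, the per-group version would be summarise(MSE = mean((caret_value - rf_value)^2)).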

Comparing Furrr Times

The explain function was taking a long time to run over all of my input options. Heike suggested that I use the furrr package, which implements the purrr mapping functions on top of the future package so that they can run in parallel. I ran and timed the code below to see how much faster it was. Using the multiprocess plan, the code took about half as long to run (107.731 seconds instead of 218.188 seconds)!

# It took about half the time when using the function from furrr!

library(furrr)
library(future)
library(tictoc)

# Slow way
plan(sequential)
tictoc::tic()
sensitivity_explain <- future_pmap(
  .l = as.list(sensitivity_inputs %>% select(-case)),
  .f = run_lime, # run_lime is one of my helper functions
  train = hamby173and252_train %>% select(rf_features),
  test = hamby224_test %>% arrange(case) %>% select(rf_features) %>% na.omit(),
  rfmodel = as_classifier(rtrees),
  label = "TRUE",
  nfeatures = 3,
  seed = FALSE
)
tictoc::toc()
# 218.188 sec elapsed

# Fast way
plan(multiprocess)
tictoc::tic()
sensitivity_explain <- future_pmap(
  .l = as.list(sensitivity_inputs %>% select(-case)),
  .f = run_lime, # run_lime is one of my helper functions
  train = hamby173and252_train %>% select(rf_features),
  test = hamby224_test %>% arrange(case) %>% select(rf_features) %>% na.omit(),
  rfmodel = as_classifier(rtrees),
  label = "TRUE",
  nfeatures = 3,
  seed = FALSE
)
tictoc::toc()
# 107.731 sec elapsed
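The timing comparison above depends on my own data and helper function, so here is a minimal self-contained sketch of the same pattern. The function slow_score and the inputs list are hypothetical stand-ins for run_lime and sensitivity_inputs. The sketch uses plan(multisession) rather than plan(multiprocess), since multiprocess has been deprecated in more recent releases of future. The actual speed-up will depend on the machine and on the number of workers.

```r
library(future)
library(furrr)

# Hypothetical stand-in for an expensive helper such as run_lime
slow_score <- function(x, y, offset) {
  Sys.sleep(0.2)          # pretend to do real work
  x^2 + y + offset
}

# Each list element supplies one argument; position i across the
# elements defines the arguments of call i
inputs <- list(x = 1:4, y = c(10, 20, 30, 40))

# Sequential baseline
plan(sequential)
time_seq <- system.time(
  res_seq <- future_pmap(inputs, slow_score, offset = 1)
)

# Parallel version on two background R sessions
plan(multisession, workers = 2)
time_par <- system.time(
  res_par <- future_pmap(inputs, slow_score, offset = 1)
)
plan(sequential)  # release the workers

# The results should match; only the elapsed time should differ
identical(res_seq, res_par)
time_seq["elapsed"]
time_par["elapsed"]
```

As in the code above, arguments that stay constant across calls (offset here) are passed after .f and forwarded to every call.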

Understanding seriation

The example below is from the paper “Getting Things in Order: An Introduction to the R Package seriation” by Hahsler, Hornik, and Buchta.

library(seriation)

data("iris")
x <- as.matrix(iris[-5])
x <- x[sample(seq_len(nrow(x))), ]
d <- dist(x)
o <- seriate(d)
str(o)
## List of 1
##  $ : 'ser_permutation_vector' int [1:150] 138 52 37 97 60 128 80 25 19 30 ...
##   ..- attr(*, "method")= chr "Spectral"
##  - attr(*, "class")= chr [1:2] "ser_permutation" "list"
class(o)
## [1] "ser_permutation" "list"
head(get_order(o), 15)
##  [1] 138  52  37  97  60 128  80  25  19  30  87 116  40 117  20
pimage(d, main = "Random")

pimage(d, o, main = "Reordered")

cbind(random = criterion(d), reordered = criterion(d, o))
##                          random    reordered
## 2SUM               3.049907e+07 1.782159e+07
## AR_deviations      9.676554e+05 9.887392e+03
## AR_events          5.603890e+05 5.492400e+04
## BAR                1.660980e+05 5.660997e+04
## Cor_R              7.763929e-03 3.719539e-01
## Gradient_raw      -1.878200e+04 9.920580e+05
## Gradient_weighted -3.521581e+04 1.771427e+06
## Inertia            2.116761e+08 3.569103e+08
## Lazy_path_length   3.081463e+04 6.705889e+03
## Least_squares      7.889743e+07 7.648857e+07
## LS                 5.691793e+06 4.487365e+06
## ME                 5.823005e+03 7.253697e+03
## Moore_stress       1.197436e+04 1.111651e+03
## Neumann_stress     6.177089e+03 5.387757e+02
## Path_length        3.944366e+02 9.128657e+01
## RGAR               5.082432e-01 4.981317e-02
pimage(x, main = "Random")

o_2mode <- c(o, ser_permutation(seq_len(ncol(x))))
pimage(x, o_2mode, main = "Reordered")
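To get a feel for what the criterion values above measure, it can help to compute one by hand. Assuming Path_length is the sum of the dissimilarities between neighbouring objects in the given order, a good ordering should make it small. The sketch below re-implements that idea in base R on a made-up one-dimensional example. The helper path_length is hypothetical and written only for illustration; it is not a function from seriation.

```r
# Six points on a line, deliberately shuffled
values <- c(3, 0, 11, 1, 10, 2)
d_toy <- dist(values)

# Hypothetical helper: sum of distances between consecutive objects in `ord`
path_length <- function(d, ord = seq_len(attr(d, "Size"))) {
  m <- as.matrix(d)
  sum(m[cbind(ord[-length(ord)], ord[-1])])
}

path_length(d_toy)                  # shuffled order: 3 + 11 + 10 + 9 + 8 = 41
path_length(d_toy, order(values))   # sorted order:   1 + 1 + 1 + 7 + 1 = 11
```

Sorting the points places similar values next to each other, which is the same effect the reordered iris distance matrix shows in the Path_length row of the table above.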

I wanted to learn how to use the order that is output from seriate to create my own plot using ggplot. Below, I create a new data frame that includes the order and use it to make two plots. Note that the “order” from seriate, which can be obtained using get_order, is a vector of case numbers arranged in plotting order (element i is the case placed in position i), not the position assigned to each case.

# Add a case variable (ordered by seriate order) and my own order variable to my dataframe
x2 <- data.frame(x) %>% 
  mutate(case = factor(1:n(), levels = get_order(o))) %>%
  arrange(case) %>%
  mutate(order = 1:n())

# Plot ordered using the case variable
x2 %>%
  gather(key = variable, value = value, 1:4) %>%
  ggplot(aes(x = variable, y = case, fill = value)) + 
  geom_tile()

# Plot ordered using the order variable
x2 %>%
  gather(key = variable, value = value, 1:4) %>%
  arrange(order) %>%
  ggplot(aes(x = variable, y = order, fill = value)) + 
  geom_tile()
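The distinction noted above, between the vector returned by get_order and the position assigned to each case, can be checked on a small made-up permutation. In the base R sketch below, ord is a hypothetical stand-in for get_order(o): element i is the case shown in position i. Inverting the permutation with order() gives the position of each case, which is what you would need to add a per-case ordering column without rearranging the rows first.

```r
# Stand-in for get_order(o): the case numbers listed in plotting order
ord <- c(3, 1, 4, 2)

# Position of each case: position[j] is where case j appears in `ord`
position <- order(ord)
position
# [1] 2 4 1 3

# Equivalent lookup with match()
identical(position, match(seq_along(ord), ord))

# Attach the position to data that is still in its original row order
cases <- data.frame(case = 1:4, value = c(10, 20, 30, 40))
cases$order <- position[cases$case]

# Sorting by the new column recovers the plotting order 3, 1, 4, 2
cases[order(cases$order), ]
```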

Session Info

sessionInfo()
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Catalina 10.15.6
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] seriation_1.2-8     forcats_0.5.0       stringr_1.4.0      
##  [4] dplyr_1.0.2         purrr_0.3.4         readr_1.3.1        
##  [7] tidyr_1.1.2         tibble_3.0.3        tidyverse_1.3.0    
## [10] randomForest_4.6-14 plotly_4.9.2.1      lime_0.5.1         
## [13] caret_6.0-86        ggplot2_3.3.2.9000  lattice_0.20-41    
## 
## loaded via a namespace (and not attached):
##   [1] colorspace_1.4-1     ellipsis_0.3.1       class_7.3-17        
##   [4] fs_1.5.0             rstudioapi_0.11      farver_2.0.3        
##   [7] prodlim_2019.11.13   fansi_0.4.1          lubridate_1.7.9     
##  [10] xml2_1.3.2           codetools_0.2-16     splines_4.0.2       
##  [13] knitr_1.29           shinythemes_1.1.2    jsonlite_1.7.1      
##  [16] pROC_1.16.2          broom_0.7.0          cluster_2.1.0       
##  [19] dbplyr_1.4.4         shiny_1.5.0          compiler_4.0.2      
##  [22] httr_1.4.2           backports_1.1.10     assertthat_0.2.1    
##  [25] Matrix_1.2-18        fastmap_1.0.1        lazyeval_0.2.2      
##  [28] cli_2.0.2            later_1.1.0.1        htmltools_0.5.0     
##  [31] tools_4.0.2          gtable_0.3.0         glue_1.4.2          
##  [34] reshape2_1.4.4       Rcpp_1.0.5           cellranger_1.1.0    
##  [37] vctrs_0.3.4          gdata_2.18.0         nlme_3.1-148        
##  [40] iterators_1.0.12     timeDate_3043.102    gower_0.2.2         
##  [43] xfun_0.17            rvest_0.3.6          mime_0.9            
##  [46] lifecycle_0.2.0      gtools_3.8.2         dendextend_1.14.0   
##  [49] MASS_7.3-51.6        scales_1.1.1         ipred_0.9-9         
##  [52] TSP_1.1-10           hms_0.5.3            promises_1.1.1      
##  [55] yaml_2.2.1           gridExtra_2.3        rpart_4.1-15        
##  [58] stringi_1.5.3        gclus_1.3.2          foreach_1.5.0       
##  [61] e1071_1.7-3          caTools_1.18.0       lava_1.6.7          
##  [64] shape_1.4.4          rlang_0.4.7          pkgconfig_2.0.3     
##  [67] bitops_1.0-6         qap_0.1-1            evaluate_0.14       
##  [70] recipes_0.1.13       htmlwidgets_1.5.1    labeling_0.3        
##  [73] tidyselect_1.1.0     plyr_1.8.6           magrittr_1.5        
##  [76] R6_2.4.1             gplots_3.0.4         generics_0.0.2      
##  [79] DBI_1.1.0            pillar_1.4.6         haven_2.3.1         
##  [82] withr_2.2.0          survival_3.1-12      nnet_7.3-14         
##  [85] modelr_0.1.8         crayon_1.3.4         KernSmooth_2.23-17  
##  [88] utf8_1.1.4           rmarkdown_2.3        viridis_0.5.1       
##  [91] grid_4.0.2           readxl_1.3.1         data.table_1.13.0   
##  [94] blob_1.2.1           ModelMetrics_1.2.2.2 reprex_0.3.0        
##  [97] digest_0.6.25        xtable_1.8-4         httpuv_1.5.4        
## [100] stats4_4.0.2         munsell_0.5.0        glmnet_4.0-2        
## [103] registry_0.5-1       viridisLite_0.3.0