Grid Search with Catlearn and Catlearn Supplementals

This is a brief demonstration of the grid search helper functions provided in catlearn.suppls for fitting the DIVA model. Both parallelized and non-parallelized versions are available.

load libraries

library(catlearn)
library(catlearn.suppls)
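
If you don't already have the packages, catlearn is on CRAN; catlearn.suppls is not, and is distributed through GitHub. A hypothetical install sketch (the repository path is a placeholder; substitute the location of your copy):

install.packages("catlearn")
# devtools::install_github("<user>/catlearn.suppls")  # placeholder repository path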

initialize variables. Create a named list with an entry for each hyperparameter you plan to test in the grid search, and specify how many random model initializations you'd like to average across when calculating response probabilities for each parameter combination. You also need a set of training inputs; get_test_inputs supplies standard test category structures (here, type IV).

# # # parameter list
short_param_list <- list(beta_val      = c(0, 1, 2, 3),
                         learning_rate = c(.05, .10, .15),
                         num_hids      = c(4, 5, 6))

# long_param_list <- list(beta_val = c(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
#                    learning_rate = c(.05, .10, .15, .20, .25, .35),
#                         num_hids = c(4, 5, 6, 10, 20))

# # # number of initializations
num_inits <- 4

# # # data
# input_list <- get_test_inputs('type1')
input_list <- get_test_inputs('type4')

# # # fit type
fit_type <- 'bestacc'

# # # fit vector
crit_fit_vector <- NULL
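
Before running anything, it's worth sanity-checking the size of the grid. A quick base-R calculation of the number of parameter combinations and total model runs:

# # # how big is this grid?
prod(lengths(short_param_list))              # 4 x 3 x 3 = 36 combinations
prod(lengths(short_param_list)) * num_inits  # 36 x 4 = 144 model runs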

run it

# # # single core
system.time(gs_output <- diva_grid_search(short_param_list, num_inits, input_list))
##    user  system elapsed 
##    3.95    0.00    3.99
# # # parallelized
system.time(gs_output <- diva_grid_search_par(short_param_list, num_inits, input_list))
##    user  system elapsed 
##    0.03    0.00    3.04
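
The parallelized version shows almost no user time because the work happens in worker processes; the wall-clock speedup you get depends on how many cores are available. Base R's parallel package will report that (whether diva_grid_search_par uses all of them depends on its implementation):

library(parallel)
detectCores()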

examine the results with plot_training

plot_training(lapply(gs_output, function(x) x$resp_probs))

[figure: training curves for all parameter settings]

examine the detailed results of a grid search run

# # # how many parameter settings did we have?
(n_models <- length(gs_output))
## [1] 36
# # # what was the accuracy distribution?
final_accuracy <- sapply(gs_output, function(x) x$resp_probs[12])  # numeric vector, one value per setting
plot(1:n_models, final_accuracy)

[figure: final-block accuracy for each of the 36 parameter settings]
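
For a quick numeric view of the same distribution, base R's summary works directly on the accuracy vector:

summary(final_accuracy)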

# # # what parameter setting had the best performance?
gs_output[[which.max(final_accuracy)]]$params
##    beta_val learning_rate num_hids
## 33        0          0.15        6
# # # what parameter setting had the worst performance?
gs_output[[which.min(final_accuracy)]]$params
##   beta_val learning_rate num_hids
## 7        2           0.1        4
# # # plot 'em
plot_training(list(gs_output[[which.max(final_accuracy)]]$resp_probs, 
  gs_output[[which.min(final_accuracy)]]$resp_probs))

[figure: training curves for the best and worst parameter settings]

# # # what comes as output for each parameter setting?
names(gs_output[[which.max(final_accuracy)]])
## [1] "resp_probs" "params"     "st"
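Because each element's params is a one-row data frame (as printed above), you can flatten the whole grid into a single summary table and sort by accuracy. A minimal base-R sketch:

# # # one row per parameter setting, plus its final accuracy
grid_summary <- do.call(rbind, lapply(gs_output, function(x) x$params))
grid_summary$final_acc <- final_accuracy
head(grid_summary[order(-grid_summary$final_acc), ])
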
# # # plot the training curves for a parameter subset (hid units = 5)
hidunit5_respprobs <- list()
for (i in seq_along(gs_output)) {
  if (gs_output[[i]]$params$num_hids == 5) {
    hidunit5_respprobs[[paste0(i)]] <- gs_output[[i]]$resp_probs
  }
}
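
The same subset can be built without an explicit loop; an equivalent base-R sketch using Filter (note the list names will differ from the loop version, which keys by index):

# # # equivalent, loop-free version
hidunit5_respprobs <- lapply(
  Filter(function(x) x$params$num_hids == 5, gs_output),
  function(x) x$resp_probs)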

# # # how many?
(n_models <- length(hidunit5_respprobs))
## [1] 12
# # # accuracy?
plot(1:n_models, unlist(lapply(hidunit5_respprobs, function(x) x[[12]])))

[figure: final-block accuracy for the num_hids = 5 subset]

plot_training(hidunit5_respprobs)

[figure: training curves for the num_hids = 5 subset]