Description
First of all, thank you all so much for the effort put into parsnip. It's a really great package!
As a first-time user, I found it difficult to track down which parsnip arguments the original engine arguments were mapped to. I knew very well what the native xgboost parameter `colsample_bytree` meant for tuning, but I had to dig into the source code to find that parsnip translates it into `mtry`. As a result, translating some of my current models took some trial and error on my side.
I think it would be beneficial if the mapping between original and standardized arguments were documented either (i) in each model function's help page or (ii) in a complete reference table, e.g.:
model | engine | parsnip | original |
---|---|---|---|
boost_tree | xgboost | tree_depth | max_depth |
boost_tree | xgboost | trees | nrounds |
boost_tree | xgboost | learn_rate | eta |
boost_tree | xgboost | mtry | colsample_bytree |
boost_tree | xgboost | min_n | min_child_weight |
boost_tree | xgboost | loss_reduction | gamma |
boost_tree | xgboost | sample_size | subsample |
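Once such a table exists as data, the reverse lookup I needed (from a native argument like `colsample_bytree` to its parsnip name) becomes a simple subset. The sketch below hand-copies only the xgboost rows from the table above; `mapping` and `to_parsnip()` are hypothetical names for illustration, not part of parsnip.

```r
# Hand-copied xgboost rows from the proposed reference table (illustrative only)
mapping <- data.frame(
  model    = "boost_tree",
  engine   = "xgboost",
  parsnip  = c("tree_depth", "trees", "learn_rate", "mtry",
               "min_n", "loss_reduction", "sample_size"),
  original = c("max_depth", "nrounds", "eta", "colsample_bytree",
               "min_child_weight", "gamma", "subsample"),
  stringsAsFactors = FALSE
)

# Hypothetical helper: native engine argument name -> parsnip argument name.
# Returns NA when the argument has no standardized parsnip equivalent in `table`.
to_parsnip <- function(original_arg, engine, table = mapping) {
  hit <- table$parsnip[table$engine == engine & table$original == original_arg]
  if (length(hit) == 0) NA_character_ else hit
}

to_parsnip("colsample_bytree", engine = "xgboost")
```

The same subset run in the other direction (filtering on `parsnip` and returning `original`) would answer "what does `mtry` become in engine X?".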
I am not sure how one would go about implementing this. Below is a very hacky solution I threw together that lists the translation of all current arguments.
library(tidyverse)
# Helper function: extract engine / parsnip / original from every set_model_arg() call in a file
extract_params <- function(content) {
  content %>%
    str_replace_all("\\n", "") %>%                # collapse the whole file onto one line
    str_extract_all("set_model_arg.*?\\)") %>%    # one string per set_model_arg() call
    purrr::pluck(1) %>%
    tibble(line = .) %>%
    mutate(engine   = str_extract(line, "(?<=eng = ).*?(?=,)"),
           parsnip  = str_extract(line, "(?<=parsnip = ).*?(?=,)"),
           original = str_extract(line, "(?<=original = ).*?(?=,)")) %>%
    select(-line) %>%
    mutate_all(str_replace_all, '"', "")         # strip the surrounding quotes
}
##################################### #
# Download parsnip from GITHUB ----
##################################### #
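file_name <- "parsnip.zip"
# mode = "wb" keeps the zip file intact on Windows
download.file(url = "https://github.com/tidymodels/parsnip/archive/master.zip",
              destfile = file_name, mode = "wb")
unzip(file_name)   # extracts into ./parsnip-master
file.remove(file_name)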
##################################### #
# List arguments of each model ----
##################################### #
# Each model's engine arguments are registered in R/<model>_data.R
params <- dir("parsnip-master/R", recursive = TRUE, pattern = "_data\\.R", full.names = TRUE) %>%
  tibble(file_name = .) %>%
  mutate(model = basename(file_name) %>% str_replace_all("_data.R", "")) %>%
  filter(!model %in% c("nullmodel", "convert")) %>%
  mutate(content = map_chr(file_name, read_file)) %>%
  mutate(params = pbapply::pblapply(content, extract_params)) %>%
  unnest(params) %>%
  select(model, engine, parsnip, original)
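To show what the regular expressions in `extract_params()` actually pick up, here is a base-R sketch run on a single hand-written `set_model_arg()` call. The snippet is an assumption that mirrors the shape of the calls in parsnip's `*_data.R` files (using the xgboost `mtry` row from the table); the real files may be formatted differently.

```r
# Hypothetical snippet shaped like a set_model_arg() call in a *_data.R file
snippet <- 'set_model_arg(
  model = "boost_tree",
  eng = "xgboost",
  parsnip = "mtry",
  original = "colsample_bytree",
  func = list(pkg = "dials", fun = "mtry"),
  has_submodel = FALSE
)'

# Same steps as extract_params(): drop newlines, then cut out the call text
flat <- gsub("\n", "", snippet)
call_text <- regmatches(flat, regexpr("set_model_arg.*?\\)", flat, perl = TRUE))

# Look-behind on "<field> = ", lazy match up to the next comma, then drop quotes
grab <- function(x, field) {
  pattern <- paste0("(?<=", field, " = ).*?(?=,)")
  value <- regmatches(x, regexpr(pattern, x, perl = TRUE))
  gsub('"', "", value)
}

fields <- c(engine   = grab(call_text, "eng"),
            parsnip  = grab(call_text, "parsnip"),
            original = grab(call_text, "original"))
print(fields)
```

Note that `set_model_arg.*?\\)` stops at the first closing parenthesis, which here is the one closing `list(...)`. This only works because `eng`, `parsnip` and `original` come before `func` in each call, which is part of why the approach is fragile.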
Full mapping table
params %>%
  knitr::kable()
model | engine | parsnip | original |
---|---|---|---|
boost_tree | xgboost | tree_depth | max_depth |
boost_tree | xgboost | trees | nrounds |
boost_tree | xgboost | learn_rate | eta |
boost_tree | xgboost | mtry | colsample_bytree |
boost_tree | xgboost | min_n | min_child_weight |
boost_tree | xgboost | loss_reduction | gamma |
boost_tree | xgboost | sample_size | subsample |
boost_tree | C5.0 | trees | trials |
boost_tree | C5.0 | min_n | minCases |
boost_tree | C5.0 | sample_size | sample |
boost_tree | spark | tree_depth | max_depth |
boost_tree | spark | trees | max_iter |
boost_tree | spark | learn_rate | step_size |
boost_tree | spark | mtry | feature_subset_strategy |
boost_tree | spark | min_n | min_instances_per_node |
boost_tree | spark | min_info_gain | gamma |
boost_tree | spark | sample_size | subsampling_rate |
decision_tree | rpart | tree_depth | maxdepth |
decision_tree | rpart | min_n | minsplit |
decision_tree | rpart | cost_complexity | cp |
decision_tree | C5.0 | min_n | minCases |
decision_tree | spark | tree_depth | max_depth |
decision_tree | spark | min_n | min_instances_per_node |
linear_reg | glmnet | penalty | lambda |
linear_reg | glmnet | mixture | alpha |
linear_reg | spark | penalty | reg_param |
linear_reg | spark | mixture | elastic_net_param |
logistic_reg | glmnet | penalty | lambda |
logistic_reg | glmnet | mixture | alpha |
logistic_reg | spark | penalty | reg_param |
logistic_reg | spark | mixture | elastic_net_param |
logistic_reg | keras | decay | decay |
mars | earth | num_terms | nprune |
mars | earth | prod_degree | degree |
mars | earth | prune_method | pmethod |
mlp | keras | hidden_units | hidden_units |
mlp | keras | penalty | penalty |
mlp | keras | dropout | dropout |
mlp | keras | epochs | epochs |
mlp | keras | activation | activation |
mlp | nnet | hidden_units | size |
mlp | nnet | penalty | decay |
mlp | nnet | epochs | maxit |
multinom_reg | glmnet | penalty | lambda |
multinom_reg | glmnet | mixture | alpha |
multinom_reg | spark | penalty | reg_param |
multinom_reg | spark | mixture | elastic_net_param |
multinom_reg | keras | decay | decay |
nearest_neighbor | kknn | neighbors | ks |
nearest_neighbor | kknn | weight_func | kernel |
nearest_neighbor | kknn | dist_power | distance |
rand_forest | ranger | mtry | mtry |
rand_forest | ranger | trees | num.trees |
rand_forest | ranger | min_n | min.node.size |
rand_forest | randomForest | mtry | mtry |
rand_forest | randomForest | trees | ntree |
rand_forest | randomForest | min_n | nodesize |
rand_forest | spark | mtry | feature_subset_strategy |
rand_forest | spark | trees | num_trees |
rand_forest | spark | min_n | min_instances_per_node |
surv_reg | flexsurv | dist | dist |
surv_reg | survival | dist | dist |
svm_poly | kernlab | cost | C |
svm_poly | kernlab | degree | degree |
svm_poly | kernlab | scale_factor | scale |
svm_poly | kernlab | margin | epsilon |
svm_rbf | kernlab | cost | C |
svm_rbf | kernlab | rbf_sigma | sigma |
svm_rbf | kernlab | margin | epsilon |
Mapping table by model
params %>%
  split(.$model) %>%
  map(spread, engine, original) %>%   # one column per engine
  map(knitr::kable)
$boost_tree
model | parsnip | C5.0 | spark | xgboost |
---|---|---|---|---|
boost_tree | learn_rate | NA | step_size | eta |
boost_tree | loss_reduction | NA | NA | gamma |
boost_tree | min_info_gain | NA | gamma | NA |
boost_tree | min_n | minCases | min_instances_per_node | min_child_weight |
boost_tree | mtry | NA | feature_subset_strategy | colsample_bytree |
boost_tree | sample_size | sample | subsampling_rate | subsample |
boost_tree | tree_depth | NA | max_depth | max_depth |
boost_tree | trees | trials | max_iter | nrounds |
$decision_tree
model | parsnip | C5.0 | rpart | spark |
---|---|---|---|---|
decision_tree | cost_complexity | NA | cp | NA |
decision_tree | min_n | minCases | minsplit | min_instances_per_node |
decision_tree | tree_depth | NA | maxdepth | max_depth |
$linear_reg
model | parsnip | glmnet | spark |
---|---|---|---|
linear_reg | mixture | alpha | elastic_net_param |
linear_reg | penalty | lambda | reg_param |
$logistic_reg
model | parsnip | glmnet | keras | spark |
---|---|---|---|---|
logistic_reg | decay | NA | decay | NA |
logistic_reg | mixture | alpha | NA | elastic_net_param |
logistic_reg | penalty | lambda | NA | reg_param |
$mars
model | parsnip | earth |
---|---|---|
mars | num_terms | nprune |
mars | prod_degree | degree |
mars | prune_method | pmethod |
$mlp
model | parsnip | keras | nnet |
---|---|---|---|
mlp | activation | activation | NA |
mlp | dropout | dropout | NA |
mlp | epochs | epochs | maxit |
mlp | hidden_units | hidden_units | size |
mlp | penalty | penalty | decay |
$multinom_reg
model | parsnip | glmnet | keras | spark |
---|---|---|---|---|
multinom_reg | decay | NA | decay | NA |
multinom_reg | mixture | alpha | NA | elastic_net_param |
multinom_reg | penalty | lambda | NA | reg_param |
$nearest_neighbor
model | parsnip | kknn |
---|---|---|
nearest_neighbor | dist_power | distance |
nearest_neighbor | neighbors | ks |
nearest_neighbor | weight_func | kernel |
$rand_forest
model | parsnip | randomForest | ranger | spark |
---|---|---|---|---|
rand_forest | min_n | nodesize | min.node.size | min_instances_per_node |
rand_forest | mtry | mtry | mtry | feature_subset_strategy |
rand_forest | trees | ntree | num.trees | num_trees |
$surv_reg
model | parsnip | flexsurv | survival |
---|---|---|---|
surv_reg | dist | dist | dist |
$svm_poly
model | parsnip | kernlab |
---|---|---|
svm_poly | cost | C |
svm_poly | degree | degree |
svm_poly | margin | epsilon |
svm_poly | scale_factor | scale |
$svm_rbf
model | parsnip | kernlab |
---|---|---|
svm_rbf | cost | C |
svm_rbf | margin | epsilon |
svm_rbf | rbf_sigma | sigma |