|
| 1 | +# Prototype parsnip code for decision trees |
| 2 | + |
#' General Interface for Decision Tree Models
#'
#' `decision_tree` generates a _specification_ of a model before it is
#' fit, so that the same specification can be fit using different
#' packages in R or via Spark. The main arguments for the model are:
#' \itemize{
#'   \item \code{cost_complexity}: The cost/complexity parameter (a.k.a. `Cp`)
#'   used by CART models (`rpart` only).
#'   \item \code{tree_depth}: The _maximum_ depth of a tree (`rpart` and
#'   `spark` only).
#'   \item \code{min_n}: The minimum number of data points in a node
#'   that are required for the node to be split further.
#' }
#' These arguments are converted to their engine-specific names at the
#' time that the model is fit. Other options and arguments can be
#' set using `set_engine`. If left to their defaults here (`NULL`),
#' the values are taken from the underlying model functions. If
#' parameters need to be modified, `update` can be used in lieu of
#' recreating the object from scratch.
#'
#' @inheritParams boost_tree
#' @param mode A single character string for the type of model.
#'  Possible values for this model are "unknown", "regression", or
#'  "classification".
#' @param cost_complexity A positive number for the cost/complexity
#'   parameter (a.k.a. `Cp`) used by CART models (`rpart` only).
#' @param tree_depth An integer for the maximum depth of the tree.
#' @param min_n An integer for the minimum number of data points
#'   in a node that are required for the node to be split further.
#' @details
#' The model can be created using the `fit()` function using the
#'  following _engines_:
#' \itemize{
#' \item \pkg{R}: `"rpart"` or `"C5.0"` (classification only)
#' \item \pkg{Spark}: `"spark"`
#' }
#'
#' Note that, for `rpart` models, both `cost_complexity` and
#'  `tree_depth` can be specified but the package will give
#'  precedence to `cost_complexity`. Also, for `tree_depth` values
#'  greater than 30, `rpart` will give nonsense results on 32-bit
#'  machines.
#'
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
#'  model fit call. For this type of model, the templates of the
#'  fit calls are:
#'
#' \pkg{rpart} classification
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "rpart")}
#'
#' \pkg{rpart} regression
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "rpart")}
#'
#' \pkg{C5.0} classification
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "C5.0")}
#'
#' \pkg{spark} classification
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "spark")}
#'
#' \pkg{spark} regression
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "spark")}
#'
#' @note For models created using the spark engine, there are
#'  several differences to consider. First, only the formula
#'  interface via `fit` is available; using `fit_xy` will
#'  generate an error. Second, the predictions will always be in a
#'  spark table format. The names will be the same as documented but
#'  without the dots. Third, there is no equivalent to factor
#'  columns in spark tables so class predictions are returned as
#'  character columns. Fourth, to retain the model object for a new
#'  R session (via `save`), the `model$fit` element of the `parsnip`
#'  object should be serialized via `ml_save(object$fit)` and
#'  separately saved to disk. In a new session, the object can be
#'  reloaded and reattached to the `parsnip` object.
#'
#' @importFrom purrr map_lgl
#' @seealso [varying()], [fit()]
#' @examples
#' decision_tree(mode = "classification", tree_depth = 5)
#' # Parameters can be represented by a placeholder:
#' decision_tree(mode = "regression", cost_complexity = varying())
#' @export

decision_tree <-
  function(mode = "unknown",
           cost_complexity = NULL,
           tree_depth = NULL,
           min_n = NULL) {

    # Capture the main arguments unevaluated (as quosures) so that
    # placeholders and expressions survive until the specification is used.
    main_args <- list(
      cost_complexity = enquo(cost_complexity),
      tree_depth = enquo(tree_depth),
      min_n = enquo(min_n)
    )

    # A new specification has no engine, engine arguments, or fit method yet.
    new_model_spec(
      "decision_tree",
      args = main_args,
      eng_args = NULL,
      mode = mode,
      method = NULL,
      engine = NULL
    )
  }
| 112 | + |
#' @export
print.decision_tree <- function(x, ...) {
  # Header names the model type and its mode.
  cat("Decision Tree Model Specification (", x$mode, ")\n\n", sep = "")
  model_printer(x, ...)

  # The fit template is only present once the specification has been
  # translated for an engine (i.e. `x$method$fit$args` has been filled in).
  if (!is.null(x$method$fit$args)) {
    cat("Model fit template:\n")
    print(show_call(x))
  }

  # Return the input invisibly so the method can be used in pipelines.
  invisible(x)
}
| 124 | + |
| 125 | +# ------------------------------------------------------------------------------ |
| 126 | + |
#' @inheritParams update.boost_tree
#' @param object A decision tree model specification.
#' @examples
#' model <- decision_tree(cost_complexity = 10, min_n = 3)
#' model
#' update(model, cost_complexity = 1)
#' update(model, cost_complexity = 1, fresh = TRUE)
#' @method update decision_tree
#' @rdname decision_tree
#' @export
update.decision_tree <-
  function(object,
           cost_complexity = NULL,
           tree_depth = NULL,
           min_n = NULL,
           fresh = FALSE,
           ...) {
    update_dot_check(...)

    # Capture the replacement values unevaluated, mirroring `decision_tree()`.
    new_args <- list(
      cost_complexity = enquo(cost_complexity),
      tree_depth = enquo(tree_depth),
      min_n = enquo(min_n)
    )

    if (!fresh) {
      # Only overwrite the arguments that were actually supplied; anything
      # flagged by `null_value` keeps its current setting.
      supplied <- new_args[!map_lgl(new_args, null_value)]
      if (length(supplied) > 0) {
        object$args[names(supplied)] <- supplied
      }
    } else {
      # A fresh update replaces every main argument, including unset ones.
      object$args <- new_args
    }

    # Rebuild the specification; the fit method is reset to NULL.
    new_model_spec(
      "decision_tree",
      args = object$args,
      eng_args = object$eng_args,
      mode = object$mode,
      method = NULL,
      engine = object$engine
    )
  }
| 168 | + |
| 169 | +# ------------------------------------------------------------------------------ |
| 170 | + |
#' @export
translate.decision_tree <- function(x, engine = x$engine, ...) {
  # With no engine set, fall back to rpart, one of the engines documented
  # for decision trees (rpart, C5.0 and spark).
  if (is.null(engine)) {
    message("Used `engine = 'rpart'` for translation.")
    engine <- "rpart"
  }

  x <- translate.default(x, engine, ...)

  # Work on a local copy of the fit arguments and write it back at the end.
  arg_vals <- x$method$fit$args

  if (x$engine == "spark") {
    # The spark fit call is given the mode via its `type` argument, so an
    # unknown mode cannot be translated.
    if (x$mode == "unknown") {
      stop(
        "For spark decision tree models, the mode cannot be 'unknown' ",
        "if the specification is to be translated.",
        call. = FALSE
      )
    } else {
      arg_vals$type <- x$mode
    }
  }

  # NOTE(review): the ranger-specific `importance`/`probability` handling and
  # the spark `feature_subset_strategy` coercion that used to live here are
  # random forest arguments and were dropped; this assumes `translate.default`
  # rejects engines that are not registered for decision trees -- confirm.
  x$method$fit$args <- arg_vals

  x
}
| 218 | + |
| 219 | +# ------------------------------------------------------------------------------ |
| 220 | + |
# Validate a decision tree specification: the C5.0 engine cannot be combined
# with regression mode. Returns `object` invisibly when the check passes.
check_args.decision_tree <- function(object) {
  # isTRUE() guards against a NULL engine or mode: `NULL == "C5.0"` is
  # logical(0), which `&&` / `if` cannot use as a condition.
  is_c5 <- isTRUE(object$engine == "C5.0")
  is_regression <- isTRUE(object$mode == "regression")

  if (is_c5 && is_regression) {
    stop("C5.0 is classification only.", call. = FALSE)
  }
  invisible(object)
}
| 226 | + |
| 227 | +# ------------------------------------------------------------------------------ |
| 228 | + |
#' Decision trees via rpart
#'
#' `rpart_train` is a wrapper for [rpart::rpart()] tree-based models
#'  where all of the model arguments are in the main function.
#'
#' @param formula A model formula.
#' @param data A data frame.
#' @param cp A non-negative number for complexity parameter. Any split
#'  that does not decrease the overall lack of fit by a factor of
#'  `cp` is not attempted. For instance, with anova splitting,
#'  this means that the overall R-squared must increase by `cp` at
#'  each step. The main role of this parameter is to save computing
#'  time by pruning off splits that are obviously not worthwhile.
#'  Essentially, the user informs the program that any split which
#'  does not improve the fit by `cp` will likely be pruned off by
#'  cross-validation, and that hence the program need not pursue it.
#' @param weights Optional case weights.
#' @param minsplit An integer for the minimum number of observations
#'  that must exist in a node in order for a split to be attempted.
#' @param maxdepth An integer for the maximum depth of any node
#'  of the final tree, with the root node counted as depth 0.
#'  For values greater than 30, `rpart` will give nonsense results on
#'  32-bit machines. This function will truncate `maxdepth` to 30 in
#'  those cases.
#' @param ... Other arguments to pass to either `rpart` or `rpart.control`.
#'  Arguments whose names match neither function are ignored.
#' @return A fitted rpart model.
#' @export
rpart_train <-
  function(formula, data, weights = NULL, cp = 0.01, minsplit = 20,
           maxdepth = 30, ...) {
    # Cap the depth on 32-bit builds only; `&&` because this is a scalar test.
    bitness <- 8 * .Machine$sizeof.pointer
    if (bitness == 32 && maxdepth > 30) {
      maxdepth <- 30
    }

    other_args <- list(...)
    protect_ctrl <- c("minsplit", "maxdepth", "cp")
    protect_fit <- NULL
    f_names <- names(formals(getFromNamespace("rpart", "rpart")))
    c_names <- names(formals(getFromNamespace("rpart.control", "rpart")))

    # Drop protected arguments by *name* (not by value), then route what is
    # left to either `rpart.control()` or `rpart()` based on their formals.
    is_protected <- names(other_args) %in% c(protect_ctrl, protect_fit)
    other_args <- other_args[!is_protected]
    ctrl_args <- other_args[names(other_args) %in% c_names]
    fit_args <- other_args[names(other_args) %in% f_names]

    # Build the `rpart::rpart.control()` call; the main arguments go first and
    # any extra control options from `...` are appended afterwards.
    ctrl <- call2("rpart.control", .ns = "rpart")
    ctrl$minsplit <- minsplit
    ctrl$maxdepth <- maxdepth
    ctrl$cp <- cp
    for (arg_name in names(ctrl_args)) {
      ctrl[[arg_name]] <- ctrl_args[[arg_name]]
    }

    # `formula` and `data` stay symbolic; they are resolved in this function's
    # frame when the call is evaluated.
    fit_call <- call2("rpart", .ns = "rpart")
    fit_call$formula <- expr(formula)
    fit_call$data <- expr(data)
    fit_call$control <- ctrl
    if (!is.null(weights)) {
      # Inline the weight values instead of the symbol `weights`: model.frame()
      # evaluates extra arguments in `data` and then in the formula's
      # environment, where the symbol would not refer to this argument.
      fit_call$weights <- weights
    }

    for (arg_name in names(fit_args)) {
      fit_call[[arg_name]] <- fit_args[[arg_name]]
    }

    eval_tidy(fit_call)
  }
| 290 | + |
0 commit comments