Skip to content

C5 changes and decision tree method #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ Suggests:
rmarkdown,
survival,
keras,
C50,
xgboost,
covr

5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ S3method(predict_raw,"_lognet")
S3method(predict_raw,"_multnet")
S3method(predict_raw,model_fit)
S3method(print,boost_tree)
S3method(print,decision_tree)
S3method(print,linear_reg)
S3method(print,logistic_reg)
S3method(print,mars)
Expand All @@ -41,6 +42,7 @@ S3method(print,surv_reg)
S3method(print,svm_poly)
S3method(print,svm_rbf)
S3method(translate,boost_tree)
S3method(translate,decision_tree)
S3method(translate,default)
S3method(translate,mars)
S3method(translate,mlp)
Expand All @@ -51,6 +53,7 @@ S3method(translate,svm_rbf)
S3method(type_sum,model_fit)
S3method(type_sum,model_spec)
S3method(update,boost_tree)
S3method(update,decision_tree)
S3method(update,linear_reg)
S3method(update,logistic_reg)
S3method(update,mars)
Expand All @@ -76,6 +79,7 @@ export(.y)
export(C5.0_train)
export(boost_tree)
export(check_empty_ellipse)
export(decision_tree)
export(fit)
export(fit.model_spec)
export(fit_control)
Expand Down Expand Up @@ -107,6 +111,7 @@ export(predict_quantile.model_fit)
export(predict_raw)
export(predict_raw.model_fit)
export(rand_forest)
export(rpart_train)
export(set_args)
export(set_engine)
export(set_mode)
Expand Down
12 changes: 8 additions & 4 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -437,17 +437,21 @@ C5.0_train <-
other_args <- list(...)
protect_ctrl <- c("minCases", "sample")
protect_fit <- "trials"
f_names <- names(formals(getFromNamespace("C5.0.default", "C50")))
c_names <- names(formals(getFromNamespace("C5.0Control", "C50")))
other_args <- other_args[!(other_args %in% c(protect_ctrl, protect_fit))]
ctrl_args <- other_args[names(other_args) %in% names(formals(C50::C5.0Control))]
fit_args <- other_args[names(other_args) %in% names(formals(C50::C5.0.default))]
ctrl_args <- other_args[names(other_args) %in% c_names]
fit_args <- other_args[names(other_args) %in% f_names]

ctrl <- expr(C50::C5.0Control())
ctrl <- call2("C5.0Control", .ns = "C50")
ctrl$minCases <- minCases
ctrl$sample <- sample
for(i in names(ctrl_args))
ctrl[[i]] <- ctrl_args[[i]]

fit_call <- expr(C50::C5.0(x = x, y = y))
fit_call <- call2("C5.0", .ns = "C50")
fit_call$x <- expr(x)
fit_call$y <- expr(y)
fit_call$trials <- trials
fit_call$control <- ctrl
if(!is.null(weights))
Expand Down
290 changes: 290 additions & 0 deletions R/decision_tree.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
# Prototype parsnip code for decision trees

#' General Interface for Decision Tree Models
#'
#' `decision_tree` is a way to generate a _specification_ of a model
#' before fitting and allows the model to be created using
#' different packages in R or via Spark. The main arguments for the
#' model are:
#' \itemize{
#' \item \code{cost_complexity}: The cost/complexity parameter (a.k.a. `Cp`)
#' used by CART models (`rpart` only).
#' \item \code{tree_depth}: The _maximum_ depth of a tree (`rpart` and
#' `spark` only).
#' \item \code{min_n}: The minimum number of data points in a node
#' that are required for the node to be split further.
#' }
#' These arguments are converted to their specific names at the
#' time that the model is fit. Other options and arguments can be
#' set using `set_engine`. If left to their defaults
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update` can be used
#' in lieu of recreating the object from scratch.
#'
#' @inheritParams boost_tree
#' @param mode A single character string for the type of model.
#' Possible values for this model are "unknown", "regression", or
#' "classification".
#' @param cost_complexity A positive number for the cost/complexity
#' parameter (a.k.a. `Cp`) used by CART models (`rpart` only).
#' @param tree_depth An integer for maximum depth of the tree.
#' @param min_n An integer for the minimum number of data points
#' in a node that are required for the node to be split further.
#' @details
#' The model can be created using the `fit()` function using the
#' following _engines_:
#' \itemize{
#' \item \pkg{R}: `"rpart"` or `"C5.0"` (classification only)
#' \item \pkg{Spark}: `"spark"`
#' }
#'
#' Note that, for `rpart` models, both `cost_complexity` and
#' `tree_depth` can be specified, but the package will give
#' precedence to `cost_complexity`. Also, for `tree_depth` values
#' greater than 30, `rpart` will give nonsense results on 32-bit
#' machines.
#'
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
#' model fit call. For this type of
#' model, the templates of the fit calls are:
#'
#' \pkg{rpart} classification
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "rpart")}
#'
#' \pkg{rpart} regression
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "rpart")}
#'
#' \pkg{C5.0} classification
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "C5.0")}
#'
#' \pkg{spark} classification
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "spark")}
#'
#' \pkg{spark} regression
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "spark")}
#'
#' @note For models created using the spark engine, there are
#' several differences to consider. First, only the formula
#' interface via `fit` is available; using `fit_xy` will
#' generate an error. Second, the predictions will always be in a
#' spark table format. The names will be the same as documented but
#' without the dots. Third, there is no equivalent to factor
#' columns in spark tables so class predictions are returned as
#' character columns. Fourth, to retain the model object for a new
#' R session (via `save`), the `model$fit` element of the `parsnip`
#' object should be serialized via `ml_save(object$fit)` and
#' separately saved to disk. In a new session, the object can be
#' reloaded and reattached to the `parsnip` object.
#'
#' @importFrom purrr map_lgl
#' @seealso [varying()], [fit()]
#' @examples
#' decision_tree(mode = "classification", tree_depth = 5)
#' # Parameters can be represented by a placeholder:
#' decision_tree(mode = "regression", cost_complexity = varying())
#' @export

decision_tree <-
  function(mode = "unknown", cost_complexity = NULL, tree_depth = NULL, min_n = NULL) {
    # Capture the main arguments as quosures so that they are stored
    # unevaluated in the specification.
    main_args <- list(
      cost_complexity = enquo(cost_complexity),
      tree_depth      = enquo(tree_depth),
      min_n           = enquo(min_n)
    )

    # No engine, engine arguments, or fit method are attached at this point.
    new_model_spec(
      "decision_tree",
      args     = main_args,
      eng_args = NULL,
      mode     = mode,
      method   = NULL,
      engine   = NULL
    )
  }

# Print method for `decision_tree` specifications: writes a header with the
# model mode, delegates the argument listing to `model_printer()`, and, when
# fit arguments are present in `x$method$fit$args`, also prints the fit call
# template. Returns `x` invisibly.
#' @export
print.decision_tree <- function(x, ...) {
  # The header names the model type; this is a decision tree, not a random
  # forest.
  cat("Decision Tree Model Specification (", x$mode, ")\n\n", sep = "")
  model_printer(x, ...)

  if (!is.null(x$method$fit$args)) {
    cat("Model fit template:\n")
    print(show_call(x))
  }
  invisible(x)
}

# ------------------------------------------------------------------------------

#' @export
#' @inheritParams update.boost_tree
#' @param object A decision tree model specification.
#' @examples
#' model <- decision_tree(cost_complexity = 10, min_n = 3)
#' model
#' update(model, cost_complexity = 1)
#' update(model, cost_complexity = 1, fresh = TRUE)
#' @method update decision_tree
#' @rdname decision_tree
#' @export
update.decision_tree <-
  function(object,
           cost_complexity = NULL, tree_depth = NULL, min_n = NULL,
           fresh = FALSE, ...) {
    update_dot_check(...)

    # Capture the candidate replacement values as quosures.
    new_args <- list(
      cost_complexity = enquo(cost_complexity),
      tree_depth = enquo(tree_depth),
      min_n = enquo(min_n)
    )

    if (!fresh) {
      # Only overwrite the arguments for which `null_value()` is FALSE;
      # the rest of the existing arguments are left untouched.
      supplied <- new_args[!map_lgl(new_args, null_value)]
      if (length(supplied) > 0)
        object$args[names(supplied)] <- supplied
    } else {
      # A fresh update replaces the whole set of main arguments.
      object$args <- new_args
    }

    # Rebuild the specification; the fit method is reset to NULL.
    new_model_spec(
      "decision_tree",
      args = object$args,
      eng_args = object$eng_args,
      mode = object$mode,
      method = NULL,
      engine = object$engine
    )
  }

# ------------------------------------------------------------------------------

# Translate a `decision_tree` specification for a given engine.
#
# `x` is the model specification and `engine` the engine name (defaults to
# the engine stored in the specification). When no engine is available,
# "rpart" is used and a message is emitted. The generic translation is done
# by `translate.default()`; engine-specific adjustments to the fit arguments
# are applied afterwards. Returns the translated specification.
#' @export
translate.decision_tree <- function(x, engine = x$engine, ...) {
  # Fall back to an engine that decision trees actually support.
  if (is.null(engine)) {
    message("Used `engine = 'rpart'` for translation.")
    engine <- "rpart"
  }

  x <- translate.default(x, engine, ...)

  # Work on a local copy of the fit arguments and write it back at the end.
  arg_vals <- x$method$fit$args

  if (x$engine == "spark") {
    # The spark fit call is given the mode as its `type` argument, so the
    # mode has to be known before translation.
    if (x$mode == "unknown") {
      stop(
        "For spark decision tree models, the mode cannot be 'unknown' ",
        "if the specification is to be translated.",
        call. = FALSE
      )
    } else {
      arg_vals$type <- x$mode
    }
  }

  x$method$fit$args <- arg_vals

  x
}

# ------------------------------------------------------------------------------

# Validate a `decision_tree` specification.
#
# Stops with an error when the C5.0 engine is combined with the regression
# mode; otherwise returns `object` invisibly.
check_args.decision_tree <- function(object) {
  # `isTRUE()` keeps the test well-defined when the engine (or mode) has not
  # been set yet: `NULL == "C5.0"` is `logical(0)`, which is not a valid
  # operand for `&&` / `if`.
  if (isTRUE(object$engine == "C5.0") && isTRUE(object$mode == "regression"))
    stop("C5.0 is classification only.", call. = FALSE)
  invisible(object)
}

# ------------------------------------------------------------------------------

#' Decision trees via rpart
#'
#' `rpart_train` is a wrapper for [rpart::rpart()] tree-based models
#'  where all of the model arguments are in the main function.
#'
#' @param formula A model formula.
#' @param data A data frame.
#' @param cp A non-negative number for complexity parameter. Any split
#'  that does not decrease the overall lack of fit by a factor of
#'  `cp` is not attempted. For instance, with anova splitting,
#'  this means that the overall R-squared must increase by `cp` at
#'  each step. The main role of this parameter is to save computing
#'  time by pruning off splits that are obviously not worthwhile.
#'  Essentially, the user informs the program that any split which
#'  does not improve the fit by `cp` will likely be pruned off by
#'  cross-validation, and that hence the program need not pursue it.
#' @param weights Optional case weights.
#' @param minsplit An integer for the minimum number of observations
#'  that must exist in a node in order for a split to be attempted.
#' @param maxdepth An integer for the maximum depth of any node
#'  of the final tree, with the root node counted as depth 0.
#'  For values greater than 30, `rpart` will give nonsense results on
#'  32-bit machines. This function will truncate `maxdepth` to 30 in
#'  those cases.
#' @param ... Other arguments to pass to either `rpart` or `rpart.control`.
#'  Arguments are routed by name: those matching a formal of
#'  `rpart.control` go to the control call, those matching a formal of
#'  `rpart` go to the fit call, and anything else is dropped.
#' @return A fitted rpart model.
#' @export
rpart_train <-
  function(formula, data, weights = NULL, cp = 0.01, minsplit = 20, maxdepth = 30, ...) {
    # Cap the depth on 32-bit builds (see the `maxdepth` documentation).
    # `&&` is used because this is a scalar condition inside `if`.
    bitness <- 8 * .Machine$sizeof.pointer
    if (bitness == 32 && maxdepth > 30)
      maxdepth <- 30

    other_args <- list(...)
    protect_ctrl <- c("minsplit", "maxdepth", "cp")
    protect_fit <- NULL
    f_names <- names(formals(getFromNamespace("rpart", "rpart")))
    c_names <- names(formals(getFromNamespace("rpart.control", "rpart")))

    # Drop protected arguments by *name*; comparing the argument values
    # against the protected names would discard unrelated arguments whose
    # value happens to equal one of those strings.
    other_args <- other_args[!(names(other_args) %in% c(protect_ctrl, protect_fit))]
    ctrl_args <- other_args[names(other_args) %in% c_names]
    fit_args <- other_args[names(other_args) %in% f_names]

    # Build the `rpart::rpart.control()` call from the main arguments plus
    # any extra control options.
    ctrl <- call2("rpart.control", .ns = "rpart")
    ctrl$minsplit <- minsplit
    ctrl$maxdepth <- maxdepth
    ctrl$cp <- cp
    for (i in names(ctrl_args))
      ctrl[[i]] <- ctrl_args[[i]]

    # Build the `rpart::rpart()` call. `formula`, `data` and `weights` are
    # inserted as symbols so they are resolved when the call is evaluated.
    fit_call <- call2("rpart", .ns = "rpart")
    fit_call$formula <- expr(formula)
    fit_call$data <- expr(data)
    fit_call$control <- ctrl
    if (!is.null(weights))
      fit_call$weights <- quote(weights)

    for (i in names(fit_args))
      fit_call[[i]] <- fit_args[[i]]

    eval_tidy(fit_call)
  }

Loading