Skip to content

Commit ab1f399

Browse files
authored
Merge pull request #105 from tidymodels/C5-changes
C5 changes and decision tree method
2 parents a9ad95f + 4cb1494 commit ab1f399

File tree

7 files changed

+752
-4
lines changed

7 files changed

+752
-4
lines changed

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ S3method(predict_raw,"_lognet")
2828
S3method(predict_raw,"_multnet")
2929
S3method(predict_raw,model_fit)
3030
S3method(print,boost_tree)
31+
S3method(print,decision_tree)
3132
S3method(print,linear_reg)
3233
S3method(print,logistic_reg)
3334
S3method(print,mars)
@@ -41,6 +42,7 @@ S3method(print,surv_reg)
4142
S3method(print,svm_poly)
4243
S3method(print,svm_rbf)
4344
S3method(translate,boost_tree)
45+
S3method(translate,decision_tree)
4446
S3method(translate,default)
4547
S3method(translate,mars)
4648
S3method(translate,mlp)
@@ -51,6 +53,7 @@ S3method(translate,svm_rbf)
5153
S3method(type_sum,model_fit)
5254
S3method(type_sum,model_spec)
5355
S3method(update,boost_tree)
56+
S3method(update,decision_tree)
5457
S3method(update,linear_reg)
5558
S3method(update,logistic_reg)
5659
S3method(update,mars)
@@ -76,6 +79,7 @@ export(.y)
7679
export(C5.0_train)
7780
export(boost_tree)
7881
export(check_empty_ellipse)
82+
export(decision_tree)
7983
export(fit)
8084
export(fit.model_spec)
8185
export(fit_control)
@@ -107,6 +111,7 @@ export(predict_quantile.model_fit)
107111
export(predict_raw)
108112
export(predict_raw.model_fit)
109113
export(rand_forest)
114+
export(rpart_train)
110115
export(set_args)
111116
export(set_engine)
112117
export(set_mode)

R/boost_tree.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,17 +437,21 @@ C5.0_train <-
437437
other_args <- list(...)
438438
protect_ctrl <- c("minCases", "sample")
439439
protect_fit <- "trials"
440+
f_names <- names(formals(getFromNamespace("C5.0.default", "C50")))
441+
c_names <- names(formals(getFromNamespace("C5.0Control", "C50")))
440442
other_args <- other_args[!(other_args %in% c(protect_ctrl, protect_fit))]
441-
ctrl_args <- other_args[names(other_args) %in% names(formals(C50::C5.0Control))]
442-
fit_args <- other_args[names(other_args) %in% names(formals(C50::C5.0.default))]
443+
ctrl_args <- other_args[names(other_args) %in% c_names]
444+
fit_args <- other_args[names(other_args) %in% f_names]
443445

444-
ctrl <- expr(C50::C5.0Control())
446+
ctrl <- call2("C5.0Control", .ns = "C50")
445447
ctrl$minCases <- minCases
446448
ctrl$sample <- sample
447449
for(i in names(ctrl_args))
448450
ctrl[[i]] <- ctrl_args[[i]]
449451

450-
fit_call <- expr(C50::C5.0(x = x, y = y))
452+
fit_call <- call2("C5.0", .ns = "C50")
453+
fit_call$x <- expr(x)
454+
fit_call$y <- expr(y)
451455
fit_call$trials <- trials
452456
fit_call$control <- ctrl
453457
if(!is.null(weights))

R/decision_tree.R

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
# Prototype parsnip code for decision trees
2+
3+
#' General Interface for Decision Tree Models
4+
#'
5+
#' `decision_tree` is a way to generate a _specification_ of a model
6+
#' before fitting and allows the model to be created using
7+
#' different packages in R or via Spark. The main arguments for the
8+
#' model are:
9+
#' \itemize{
10+
#' \item \code{cost_complexity}: The cost/complexity parameter (a.k.a. `Cp`)
11+
#' used by CART models (`rpart` only).
12+
#' \item \code{tree_depth}: The _maximum_ depth of a tree (`rpart` and
13+
#' `spark` only).
14+
#' \item \code{min_n}: The minimum number of data points in a node
15+
#' that are required for the node to be split further.
16+
#' }
17+
#' These arguments are converted to their specific names at the
18+
#' time that the model is fit. Other options and argument can be
19+
#' set using `set_engine`. If left to their defaults
20+
#' here (`NULL`), the values are taken from the underlying model
21+
#' functions. If parameters need to be modified, `update` can be used
22+
#' in lieu of recreating the object from scratch.
23+
#'
24+
#' @inheritParams boost_tree
25+
#' @param mode A single character string for the type of model.
26+
#' Possible values for this model are "unknown", "regression", or
27+
#' "classification".
28+
#' @param cost_complexity A positive number for the the cost/complexity
29+
#' parameter (a.k.a. `Cp`) used by CART models (`rpart` only).
30+
#' @param tree_depth An integer for maximum depth of the tree.
31+
#' @param min_n An integer for the minimum number of data points
32+
#' in a node that are required for the node to be split further.
33+
#' @details
34+
#' The model can be created using the `fit()` function using the
35+
#' following _engines_:
36+
#' \itemize{
37+
#' \item \pkg{R}: `"rpart"` or `"C5.0"` (classification only)
38+
#' \item \pkg{Spark}: `"spark"`
39+
#' }
40+
#'
41+
#' Note that, for `rpart` models, but `cost_complexity` and
42+
#' `tree_depth` can be both be specified but the package will give
43+
#' precedence to `cost_complexity`. Also, `tree_depth` values
44+
#' greater than 30 `rpart` will give nonsense results on 32-bit
45+
#' machines.
46+
#'
47+
#' @section Engine Details:
48+
#'
49+
#' Engines may have pre-set default arguments when executing the
50+
#' model fit call. For this type of
51+
#' model, the template of the fit calls are::
52+
#'
53+
#' \pkg{rpart} classification
54+
#'
55+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "rpart")}
56+
#'
57+
#' \pkg{rpart} regression
58+
#'
59+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "rpart")}
60+
#'
61+
#' \pkg{C5.0} classification
62+
#'
63+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "C5.0")}
64+
#'
65+
#' \pkg{spark} classification
66+
#'
67+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "spark")}
68+
#'
69+
#' \pkg{spark} regression
70+
#'
71+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "spark")}
72+
#'
73+
#' @note For models created using the spark engine, there are
74+
#' several differences to consider. First, only the formula
75+
#' interface to via `fit` is available; using `fit_xy` will
76+
#' generate an error. Second, the predictions will always be in a
77+
#' spark table format. The names will be the same as documented but
78+
#' without the dots. Third, there is no equivalent to factor
79+
#' columns in spark tables so class predictions are returned as
80+
#' character columns. Fourth, to retain the model object for a new
81+
#' R session (via `save`), the `model$fit` element of the `parsnip`
82+
#' object should be serialized via `ml_save(object$fit)` and
83+
#' separately saved to disk. In a new session, the object can be
84+
#' reloaded and reattached to the `parsnip` object.
85+
#'
86+
#' @importFrom purrr map_lgl
87+
#' @seealso [varying()], [fit()]
88+
#' @examples
89+
#' decision_tree(mode = "classification", tree_depth = 5)
90+
#' # Parameters can be represented by a placeholder:
91+
#' decision_tree(mode = "regression", cost_complexity = varying())
92+
#' @export
93+
94+
decision_tree <-
95+
function(mode = "unknown", cost_complexity = NULL, tree_depth = NULL, min_n = NULL) {
96+
97+
args <- list(
98+
cost_complexity = enquo(cost_complexity),
99+
tree_depth = enquo(tree_depth),
100+
min_n = enquo(min_n)
101+
)
102+
103+
new_model_spec(
104+
"decision_tree",
105+
args = args,
106+
eng_args = NULL,
107+
mode = mode,
108+
method = NULL,
109+
engine = NULL
110+
)
111+
}
112+
113+
#' @export
114+
print.decision_tree <- function(x, ...) {
115+
cat("Random Forest Model Specification (", x$mode, ")\n\n", sep = "")
116+
model_printer(x, ...)
117+
118+
if(!is.null(x$method$fit$args)) {
119+
cat("Model fit template:\n")
120+
print(show_call(x))
121+
}
122+
invisible(x)
123+
}
124+
125+
# ------------------------------------------------------------------------------
126+
127+
#' @export
128+
#' @inheritParams update.boost_tree
129+
#' @param object A random forest model specification.
130+
#' @examples
131+
#' model <- decision_tree(cost_complexity = 10, min_n = 3)
132+
#' model
133+
#' update(model, cost_complexity = 1)
134+
#' update(model, cost_complexity = 1, fresh = TRUE)
135+
#' @method update decision_tree
136+
#' @rdname decision_tree
137+
#' @export
138+
update.decision_tree <-
139+
function(object,
140+
cost_complexity = NULL, tree_depth = NULL, min_n = NULL,
141+
fresh = FALSE, ...) {
142+
update_dot_check(...)
143+
args <- list(
144+
cost_complexity = enquo(cost_complexity),
145+
tree_depth = enquo(tree_depth),
146+
min_n = enquo(min_n)
147+
)
148+
149+
if (fresh) {
150+
object$args <- args
151+
} else {
152+
null_args <- map_lgl(args, null_value)
153+
if (any(null_args))
154+
args <- args[!null_args]
155+
if (length(args) > 0)
156+
object$args[names(args)] <- args
157+
}
158+
159+
new_model_spec(
160+
"decision_tree",
161+
args = object$args,
162+
eng_args = object$eng_args,
163+
mode = object$mode,
164+
method = NULL,
165+
engine = object$engine
166+
)
167+
}
168+
169+
# ------------------------------------------------------------------------------
170+
171+
#' @export
172+
translate.decision_tree <- function(x, engine = x$engine, ...) {
173+
if (is.null(engine)) {
174+
message("Used `engine = 'ranger'` for translation.")
175+
engine <- "ranger"
176+
}
177+
178+
x <- translate.default(x, engine, ...)
179+
180+
# slightly cleaner code using
181+
arg_vals <- x$method$fit$args
182+
183+
if (x$engine == "spark") {
184+
if (x$mode == "unknown") {
185+
stop(
186+
"For spark random forests models, the mode cannot be 'unknown' ",
187+
"if the specification is to be translated.",
188+
call. = FALSE
189+
)
190+
} else {
191+
arg_vals$type <- x$mode
192+
}
193+
194+
# See "Details" in ?ml_random_forest_classifier. `feature_subset_strategy`
195+
# should be character even if it contains a number.
196+
if (any(names(arg_vals) == "feature_subset_strategy") &&
197+
isTRUE(is.numeric(quo_get_expr(arg_vals$feature_subset_strategy)))) {
198+
arg_vals$feature_subset_strategy <-
199+
paste(quo_get_expr(arg_vals$feature_subset_strategy))
200+
}
201+
}
202+
203+
# add checks to error trap or change things for this method
204+
if (engine == "ranger") {
205+
if (any(names(arg_vals) == "importance"))
206+
if (isTRUE(is.logical(quo_get_expr(arg_vals$importance))))
207+
stop("`importance` should be a character value. See ?ranger::ranger.",
208+
call. = FALSE)
209+
# unless otherwise specified, classification models are probability forests
210+
if (x$mode == "classification" && !any(names(arg_vals) == "probability"))
211+
arg_vals$probability <- TRUE
212+
213+
}
214+
x$method$fit$args <- arg_vals
215+
216+
x
217+
}
218+
219+
# ------------------------------------------------------------------------------
220+
221+
check_args.decision_tree <- function(object) {
222+
if (object$engine == "C5.0" && object$mode == "regression")
223+
stop("C5.0 is classification only.", call. = FALSE)
224+
invisible(object)
225+
}
226+
227+
# ------------------------------------------------------------------------------
228+
229+
#' Decision trees via rpart
230+
#'
231+
#' `rpart_train` is a wrapper for [rpart::rpart()] tree-based models
232+
#' where all of the model arguments are in the main function.
233+
#'
234+
#' @param formula A model formula.
235+
#' @param data A data frame.
236+
#' @param cp A non-negative number for complexity parameter. Any split
237+
#' that does not decrease the overall lack of fit by a factor of
238+
#' `cp` is not attempted. For instance, with anova splitting,
239+
#' this means that the overall R-squared must increase by `cp` at
240+
#' each step. The main role of this parameter is to save computing
241+
#' time by pruning off splits that are obviously not worthwhile.
242+
#' Essentially,the user informs the program that any split which
243+
#' does not improve the fit by `cp` will likely be pruned off by
244+
#' cross-validation, and that hence the program need not pursue it.
245+
#' @param weights Optional case weights.
246+
#' @param minsplit An integer for the minimum number of observations
247+
#' that must exist in a node in order for a split to be attempted.
248+
#' @param maxdepth An integer for the maximum depth of any node
249+
#' of the final tree, with the root node counted as depth 0.
250+
#' Values greater than 30 `rpart` will give nonsense results on
251+
#' 32-bit machines. This function will truncate `maxdepth` to 30 in
252+
#' those cases.
253+
#' @param ... Other arguments to pass to either `rpart` or `rpart.control`.
254+
#' @return A fitted rpart model.
255+
#' @export
256+
rpart_train <-
257+
function(formula, data, weights = NULL, cp = 0.01, minsplit = 20, maxdepth = 30, ...) {
258+
bitness <- 8 * .Machine$sizeof.pointer
259+
if (bitness == 32 & maxdepth > 30)
260+
maxdepth <- 30
261+
262+
other_args <- list(...)
263+
protect_ctrl <- c("minsplit", "maxdepth", "cp")
264+
protect_fit <- NULL
265+
f_names <- names(formals(getFromNamespace("rpart", "rpart")))
266+
c_names <- names(formals(getFromNamespace("rpart.control", "rpart")))
267+
other_args <- other_args[!(other_args %in% c(protect_ctrl, protect_fit))]
268+
ctrl_args <- other_args[names(other_args) %in% c_names]
269+
fit_args <- other_args[names(other_args) %in% f_names]
270+
271+
ctrl <- call2("rpart.control", .ns = "rpart")
272+
ctrl$minsplit <- minsplit
273+
ctrl$maxdepth <- maxdepth
274+
ctrl$cp <- cp
275+
for(i in names(ctrl_args))
276+
ctrl[[i]] <- ctrl_args[[i]]
277+
278+
fit_call <- call2("rpart", .ns = "rpart")
279+
fit_call$formula <- expr(formula)
280+
fit_call$data <- expr(data)
281+
fit_call$control <- ctrl
282+
if(!is.null(weights))
283+
fit_call$weights <- quote(weights)
284+
285+
for(i in names(fit_args))
286+
fit_call[[i]] <- fit_args[[i]]
287+
288+
eval_tidy(fit_call)
289+
}
290+

0 commit comments

Comments
 (0)