Commit 6c5482a

simonpcouch and topepo authored
patch params argument with xgboost engine in boost_tree() (#787)
* patch `params` argument with `xgboost` engine in `boost_tree()`
* remove + add snapshots from previous PRs
* update snaps with new help-page reference

Co-authored-by: Max Kuhn <mxkuhn@gmail.com>
1 parent 9e36249 commit 6c5482a

10 files changed: +342 -52 lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 # parsnip (development version)
 
+
+* Enabled passing additional engine arguments with the xgboost `boost_tree()` engine. To supply engine-specific arguments that are documented in `xgboost::xgb.train()` as arguments to be passed via `params`, supply the list elements directly as named arguments to `set_engine()`. Read more in `?details_boost_tree_xgboost` (#787).
+
 # parsnip 1.0.0
 
 ## Model Specification Changes

R/boost_tree.R

Lines changed: 59 additions & 30 deletions
@@ -213,9 +213,6 @@ check_args.boost_tree <- function(object) {
 #' @param counts A logical. If `FALSE`, `colsample_bynode` and
 #' `colsample_bytree` are both assumed to be _proportions_ of the proportion of
 #' columns affects (instead of counts).
-#' @param objective A single string (or NULL) that defines the loss function that
-#' `xgboost` uses to create trees. See [xgboost::xgb.train()] for options. If left
-#' NULL, an appropriate loss function is chosen.
 #' @param event_level For binary classification, this is a single string of either
 #' `"first"` or `"second"` to pass along describing which level of the outcome
 #' should be considered the "event".
@@ -227,7 +224,7 @@ xgb_train <- function(
     x, y, weights = NULL,
     max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
     colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
-    validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
+    validation = 0, early_stop = NULL, counts = TRUE,
     event_level = c("first", "second"), ...) {
 
   event_level <- rlang::arg_match(event_level, c("first", "second"))
@@ -248,18 +245,6 @@ xgb_train <- function(
     }
   }
 
-  if (is.null(objective)) {
-    if (is.numeric(y)) {
-      objective <- "reg:squarederror"
-    } else {
-      if (num_class == 2) {
-        objective <- "binary:logistic"
-      } else {
-        objective <- "multi:softprob"
-      }
-    }
-  }
-
   n <- nrow(x)
   p <- ncol(x)
 
@@ -300,35 +285,79 @@ xgb_train <- function(
     colsample_bytree = colsample_bytree,
     colsample_bynode = colsample_bynode,
     min_child_weight = min(min_child_weight, n),
-    subsample = subsample,
-    objective = objective
+    subsample = subsample
   )
 
-  main_args <- list(
-    data = quote(x$data),
-    watchlist = quote(x$watchlist),
-    params = arg_list,
-    nrounds = nrounds,
-    early_stopping_rounds = early_stop
+  others <- process_others(others, arg_list)
+
+  main_args <- c(
+    list(
+      data = quote(x$data),
+      watchlist = quote(x$watchlist),
+      params = arg_list,
+      nrounds = nrounds,
+      early_stopping_rounds = early_stop
+    ),
+    others
   )
+
+  if (is.null(main_args$objective)) {
+    if (is.numeric(y)) {
+      main_args$objective <- "reg:squarederror"
+    } else {
+      if (num_class == 2) {
+        main_args$objective <- "binary:logistic"
+      } else {
+        main_args$objective <- "multi:softprob"
+      }
+    }
+  }
+
   if (!is.null(num_class) && num_class > 2) {
     main_args$num_class <- num_class
   }
 
   call <- make_call(fun = "xgb.train", ns = "xgboost", main_args)
 
-  # override or add some other args
+  eval_tidy(call, env = current_env())
+}
+
+process_others <- function(others, arg_list) {
+  guarded <- c("data", "weights", "num_class", "watchlist")
+  guarded_supplied <- names(others)[names(others) %in% guarded]
+
+  if (length(guarded_supplied) > 0) {
+    cli::cli_warn(
+      c(
+        "!" = "{cli::qty(guarded_supplied)} The argument{?s} {.arg {guarded_supplied}} \
+               {?is/are} guarded by parsnip and will not be passed to {.fun xgb.train}."
+      ),
+      class = "xgboost_guarded_warning"
+    )
+  }
 
   others <-
-    others[!(names(others) %in% c("data", "weights", "nrounds", "num_class", names(arg_list)))]
+    others[!(names(others) %in% guarded)]
+
+  if (!is.null(others$params)) {
+    cli::cli_warn(
+      c(
+        "!" = "Please supply elements of the `params` list argument as main arguments \
+               to `set_engine()` rather than as part of `params`.",
+        "i" = "See `?details_boost_tree_xgboost` for more information."
+      ),
+      class = "xgboost_params_warning"
+    )
+
+    params <- others$params[!names(others$params) %in% names(arg_list)]
+    others <- c(others[names(others) != "params"], params)
+  }
+
   if (!(any(names(others) == "verbose"))) {
     others$verbose <- 0
   }
-  if (length(others) > 0) {
-    call <- rlang::call_modify(call, !!!others)
-  }
 
-  eval_tidy(call, env = current_env())
+  others
 }
 
 recalc_param <- function(x, counts, denom) {

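The routing that `process_others()` performs in the diff above can be sketched in base R. This is only an illustration, not parsnip's code: the name `route_engine_args()` is hypothetical, and plain `warning()` stands in for `cli::cli_warn()` so the block runs without extra packages.

```r
# Hypothetical, simplified stand-in for process_others(); base R only.
route_engine_args <- function(others, arg_list) {
  guarded <- c("data", "weights", "num_class", "watchlist")
  supplied <- intersect(names(others), guarded)

  # Guarded arguments are dropped with a warning.
  if (length(supplied) > 0) {
    warning("Guarded argument(s) dropped: ", paste(supplied, collapse = ", "))
  }
  others <- others[!(names(others) %in% guarded)]

  # Elements of `params` are lifted to top-level arguments, except those
  # that duplicate a main argument already present in arg_list.
  if (!is.null(others$params)) {
    warning("Supply elements of `params` directly to set_engine().")
    params <- others$params[!(names(others$params) %in% names(arg_list))]
    others <- c(others[names(others) != "params"], params)
  }

  # Default to quiet training unless the user set `verbose`.
  if (!("verbose" %in% names(others))) {
    others$verbose <- 0
  }
  others
}

routed <- suppressWarnings(
  route_engine_args(
    others = list(
      watchlist = "ignored",
      params = list(eval_metric = "mae", eta = 0.1)
    ),
    arg_list = list(eta = 0.3, max_depth = 6)
  )
)
str(routed)
```

In this sketch `watchlist` is discarded, `eta` inside `params` loses to the main `eta` argument, and `eval_metric` ends up as a top-level entry next to the default `verbose = 0`.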
man/details_boost_tree_xgboost.Rd

Lines changed: 50 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/boost_tree_xgboost.Rmd

Lines changed: 23 additions & 1 deletion
@@ -60,6 +60,26 @@ For classification, non-numeric outcomes (i.e., factors) are internally converte
 
 ## Other details
 
+### Interfacing with the `params` argument
+
+The xgboost function that parsnip indirectly wraps, [xgboost::xgb.train()], takes most arguments via the `params` list argument. To supply engine-specific arguments that are documented in [xgboost::xgb.train()] as arguments to be passed via `params`, supply the list elements directly as named arguments to [set_engine()] rather than as elements in `params`. For example, pass a non-default evaluation metric like this:
+
+```{r}
+# good
+boost_tree() %>%
+  set_engine("xgboost", eval_metric = "mae")
+```
+
+...rather than this:
+
+```{r}
+# bad
+boost_tree() %>%
+  set_engine("xgboost", params = list(eval_metric = "mae"))
+```
+
+parsnip will then route arguments as needed. In the case that arguments are passed to `params` via [set_engine()], parsnip will warn and re-route the arguments as needed. Note, though, that arguments passed to `params` cannot be tuned.
+
 ### Sparse matrices
 
 xgboost requires the data to be in a sparse format. If your predictor data are already in this format, then use [fit_xy.model_spec()] to pass it to the model function. Otherwise, parsnip converts the data to this format.
@@ -78,9 +98,11 @@ By default, the model is trained without parallel processing. This can be change
 ```{r child = "template-early-stopping.Rmd"}
 ```
 
+Note that, since the `validation` argument provides an alternative interface to `watchlist`, the `watchlist` argument is guarded by parsnip and will be ignored (with a warning) if passed.
+
 ### Objective function
 
-parsnip chooses the objective function based on the characteristics of the outcome. To use a different loss, pass the `objective` argument to [set_engine()].
+parsnip chooses the objective function based on the characteristics of the outcome. To use a different loss, pass the `objective` argument to [set_engine()] directly.
 
 ## Examples
 
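The "Objective function" note corresponds to the fallback that this commit moves within `xgb_train()`. A minimal base-R sketch of that fallback follows; the helper name `default_objective()` and its signature are hypothetical, and only the three objective strings come from the diff.

```r
# Hypothetical helper mirroring the objective fallback shown in xgb_train().
default_objective <- function(y, num_class = NULL, objective = NULL) {
  # A user-supplied objective always wins.
  if (!is.null(objective)) {
    return(objective)
  }
  if (is.numeric(y)) {
    "reg:squarederror"
  } else if (num_class == 2) {
    "binary:logistic"
  } else {
    "multi:softprob"
  }
}

obj_reg   <- default_objective(c(1.5, 2.0, 3.2))
obj_bin   <- default_objective(factor(c("a", "b")), num_class = 2)
obj_multi <- default_objective(factor(c("a", "b", "c")), num_class = 3)
obj_user  <- default_objective(c(1, 2, 3), objective = "count:poisson")
```

The last call illustrates the documented behavior: supplying `objective` directly bypasses the outcome-based default.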
man/rmd/boost_tree_xgboost.md

Lines changed: 43 additions & 1 deletion
@@ -109,6 +109,46 @@ For classification, non-numeric outcomes (i.e., factors) are internally converte
 
 ## Other details
 
+### Interfacing with the `params` argument
+
+The xgboost function that parsnip indirectly wraps, [xgboost::xgb.train()], takes most arguments via the `params` list argument. To supply engine-specific arguments that are documented in [xgboost::xgb.train()] as arguments to be passed via `params`, supply the list elements directly as named arguments to [set_engine()] rather than as elements in `params`. For example, pass a non-default evaluation metric like this:
+
+
+```r
+# good
+boost_tree() %>%
+  set_engine("xgboost", eval_metric = "mae")
+```
+
+```
+## Boosted Tree Model Specification (unknown)
+##
+## Engine-Specific Arguments:
+##   eval_metric = mae
+##
+## Computational engine: xgboost
+```
+
+...rather than this:
+
+
+```r
+# bad
+boost_tree() %>%
+  set_engine("xgboost", params = list(eval_metric = "mae"))
+```
+
+```
+## Boosted Tree Model Specification (unknown)
+##
+## Engine-Specific Arguments:
+##   params = list(eval_metric = "mae")
+##
+## Computational engine: xgboost
+```
+
+parsnip will then route arguments as needed. In the case that arguments are passed to `params` via [set_engine()], parsnip will warn and re-route the arguments as needed. Note, though, that arguments passed to `params` cannot be tuned.
+
 ### Sparse matrices
 
 xgboost requires the data to be in a sparse format. If your predictor data are already in this format, then use [fit_xy.model_spec()] to pass it to the model function. Otherwise, parsnip converts the data to this format.
@@ -137,9 +177,11 @@ The best way to use this feature is in conjunction with an _internal validation
 
 If the model specification has `early_stop >= trees`, `early_stop` is converted to `trees - 1` and a warning is issued.
 
+Note that, since the `validation` argument provides an alternative interface to `watchlist`, the `watchlist` argument is guarded by parsnip and will be ignored (with a warning) if passed.
+
 ### Objective function
 
-parsnip chooses the objective function based on the characteristics of the outcome. To use a different loss, pass the `objective` argument to [set_engine()].
+parsnip chooses the objective function based on the characteristics of the outcome. To use a different loss, pass the `objective` argument to [set_engine()] directly.
 
 ## Examples
 
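The early-stopping rule quoted in the context lines above (`early_stop >= trees` is converted to `trees - 1` with a warning) can be illustrated with a small base-R sketch. The function name `clamp_early_stop()` and the warning text are hypothetical; parsnip's actual implementation is not shown in this commit.

```r
# Hypothetical illustration of the documented early_stop adjustment.
clamp_early_stop <- function(early_stop, trees) {
  if (!is.null(early_stop) && early_stop >= trees) {
    warning("`early_stop` was reduced to ", trees - 1, ".")
    early_stop <- trees - 1
  }
  early_stop
}

clamped   <- suppressWarnings(clamp_early_stop(early_stop = 20, trees = 15))
unchanged <- clamp_early_stop(early_stop = 5, trees = 15)
unset     <- clamp_early_stop(early_stop = NULL, trees = 15)
```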
man/xgb_train.Rd

Lines changed: 0 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# interface to param arguments
+
+    ! Please supply elements of the `params` list argument as main arguments to `set_engine()` rather than as part of `params`.
+    i See `?details_boost_tree_xgboost` for more information.
+
+---
+
+    ! Please supply elements of the `params` list argument as main arguments to `set_engine()` rather than as part of `params`.
+    i See `?details_boost_tree_xgboost` for more information.
+
+---
+
+    ! The argument `watchlist` is guarded by parsnip and will not be passed to `xgb.train()`.
+
+---
+
+    ! The arguments `watchlist` and `data` are guarded by parsnip and will not be passed to `xgb.train()`.
+
+---
+
+    ! Please supply elements of the `params` list argument as main arguments to `set_engine()` rather than as part of `params`.
+    i See `?details_boost_tree_xgboost` for more information.
+

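Both warnings in this commit carry a condition class (`xgboost_guarded_warning` and `xgboost_params_warning`), so callers can handle them selectively rather than matching on message text. The sketch below shows the general mechanism in base R. It is an assumption-laden illustration: `warn_guarded()` is a hypothetical stand-in, and `warningCondition()` replaces `cli::cli_warn(class = ...)`.

```r
# Hypothetical stand-in that signals a classed warning, as cli_warn(class = ) does.
warn_guarded <- function(arg) {
  warning(warningCondition(
    paste0("The argument `", arg, "` is guarded and will not be passed on."),
    class = "xgboost_guarded_warning"
  ))
  invisible(NULL)
}

# Handle only warnings of that class; execution resumes after the warning.
caught <- withCallingHandlers(
  {
    warn_guarded("watchlist")
    "fit continues"
  },
  xgboost_guarded_warning = function(w) {
    message("Caught: ", conditionMessage(w))
    invokeRestart("muffleWarning")
  }
)
```

A handler keyed on the class leaves unrelated warnings untouched, which is presumably why the commit attaches a class to each warning.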
tests/testthat/_snaps/proportional_hazards.md

Lines changed: 0 additions & 14 deletions
@@ -1,17 +1,3 @@
-# printing
-
-    Code
-      proportional_hazards()
-    Message
-      parsnip could not locate an implementation for `proportional_hazards` censored regression model specifications using the `survival` engine.
-      i The parsnip extension package censored implements support for this specification.
-      i Please install (if needed) and load to continue.
-    Output
-      Proportional Hazards Model Specification (censored regression)
-
-      Computational engine: survival
-
-
 # updating
 
     Code
