Skip to content

Commit 927defa

Browse files
authored
handle NULL objective in xgb_predict() (#875)
1 parent abbab7c commit 927defa

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
* `.organize_glmnet_pred()` now expects predictions for a single penalty value (#876).
44

5+
* Fixed bug with prediction from a boosted tree model fitted with `"xgboost"` using a custom objective function (#875).
6+
57
# parsnip 1.0.4
68

79
* For censored regression models, a "reverse Kaplan-Meier" curve is computed for the censoring distribution. This can be used when evaluating this type of model (#855).

R/boost_tree.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ xgb_predict <- function(object, new_data, ...) {
399399
res <- predict(object, new_data, ...)
400400

401401
x <- switch(
402-
object$params$objective,
402+
object$params$objective %||% 3L,
403403
"binary:logitraw" = stats::binomial()$linkinv(res),
404404
"multi:softprob" = matrix(res, ncol = object$params$num_class, byrow = TRUE),
405405
res)

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,27 @@ test_that('xgboost alternate objective', {
217217

218218
xgb_fit <- spec %>% fit(mpg ~ ., data = mtcars)
219219
expect_equal(extract_fit_engine(xgb_fit)$params$objective, "reg:pseudohubererror")
220-
})
220+
expect_no_error(xgb_preds <- predict(xgb_fit, new_data = mtcars[1,]))
221+
expect_s3_class(xgb_preds, "data.frame")
222+
223+
logregobj <- function(preds, dtrain) {
224+
labels <- xgboost::getinfo(dtrain, "label")
225+
preds <- 1 / (1 + exp(-preds))
226+
grad <- preds - labels
227+
hess <- preds * (1 - preds)
228+
return(list(grad = grad, hess = hess))
229+
}
230+
231+
spec2 <-
232+
boost_tree() %>%
233+
set_engine("xgboost", objective = logregobj) %>%
234+
set_mode("classification")
221235

236+
xgb_fit2 <- spec2 %>% fit(vs ~ ., data = mtcars %>% mutate(vs = as.factor(vs)))
237+
expect_equal(rlang::eval_tidy(xgb_fit2$spec$eng_args$objective), logregobj)
238+
expect_no_error(xgb_preds2 <- predict(xgb_fit2, new_data = mtcars[1,-8]))
239+
expect_s3_class(xgb_preds2, "data.frame")
240+
})
222241

223242
test_that('submodel prediction', {
224243

0 commit comments

Comments
 (0)