Skip to content

Commit 40ec24f

Browse files
Fixed bug where boost_tree() models couldn't be fit with 1 predictor if validation argument was used. (#994)
* add `drop = FALSE` to matrix subsetting in as_xgb_data() * add test * add news
1 parent acad61e commit 40ec24f

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
* Fixed bug where prediction on rank dificient `lm()` models produced `.pred_res` instead of `.pred`. (#985)
1212

13+
* Fixed bug where `boost_tree()` models couldn't be fit with 1 predictor if `validation` argument was used. (#994)
14+
1315
# parsnip 1.1.0
1416

1517
This release of parsnip contains a number of new features and bug fixes, accompanied by several optimizations that substantially decrease the time to `fit()` and `predict()` with the package.

R/boost_tree.R

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -435,15 +435,22 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir
435435
# Split data
436436
m <- floor(n * (1 - validation)) + 1
437437
trn_index <- sample(seq_len(n), size = max(m, 2))
438-
val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA)
438+
val_data <- xgboost::xgb.DMatrix(
439+
data = x[-trn_index, , drop = FALSE],
440+
label = y[-trn_index],
441+
missing = NA
442+
)
439443
watch_list <- list(validation = val_data)
440444

441445
info_list <- list(label = y[trn_index])
442446
if (!is.null(weights)) {
443447
info_list$weight <- weights[trn_index]
444448
}
445-
dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list)
446-
449+
dat <- xgboost::xgb.DMatrix(
450+
data = x[trn_index, , drop = FALSE],
451+
missing = NA,
452+
info = info_list
453+
)
447454

448455
} else {
449456
info_list <- list(label = y)

tests/testthat/test_boost_tree.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,12 @@ test_that('argument checks for data dimensions', {
3838
expect_equal(args$min_instances_per_node, expr(min_rows(1000, x)))
3939
})
4040

41+
test_that('boost_tree can be fit with 1 predictor if validation is used', {
42+
spec <- boost_tree(trees = 1) %>%
43+
set_engine("xgboost", validation = 0.5) %>%
44+
set_mode("regression")
45+
46+
expect_no_error(
47+
fit(spec, mpg ~ disp, data = mtcars)
48+
)
49+
})

0 commit comments

Comments
 (0)