Skip to content

Commit

Permalink
fix: check for early_stopping_rounds
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Jun 25, 2024
1 parent 892a86d commit 43ebb7a
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 7 deletions.
5 changes: 3 additions & 2 deletions R/LearnerClassifXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
#'
#' @section Early stopping:
#' Early stopping can be used to find the optimal number of boosting rounds.
#' The `early_stopping` parameter controls which set is used to monitor the performance.
#' Set `early_stopping_rounds` to an integer vaulue to monitor the performance of the model on the validation set while training.
#' For information on how to configure the valdiation set, see the *Validation* section of [`mlr3::Learner`].
#'
Expand Down Expand Up @@ -83,7 +82,9 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
tags = c("train", "hotstart", "internal_tuning"),
aggr = crate(function(x) as.integer(ceiling(mean(unlist(x)))), .parent = topenv()),
in_tune_fn = crate(function(domain, param_vals) {
assert_true(!is.null(param_vals$early_stopping), .var.name = "early stopping rounds is set")
if (is.null(param_vals$early_stopping_rounds)) {
stop("Parameter 'early_stopping_rounds' must be set to use internal tuning.")
}
assert_integerish(domain$upper, len = 1L, any.missing = FALSE) }, .parent = topenv()),
disable_in_tune = list(early_stopping_rounds = NULL)
)
Expand Down
4 changes: 3 additions & 1 deletion R/LearnerRegrXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
tags = c("train", "hotstart", "internal_tuning"),
aggr = crate(function(x) as.integer(ceiling(mean(unlist(x)))), .parent = topenv()),
in_tune_fn = crate(function(domain, param_vals) {
assert_true(!is.null(param_vals$early_stopping), .var.name = "early stopping rounds is set")
if (is.null(param_vals$early_stopping_rounds)) {
stop("Parameter 'early_stopping_rounds' must be set to use internal tuning.")
}
assert_integerish(domain$upper, len = 1L, any.missing = FALSE) }, .parent = topenv()),
disable_in_tune = list(early_stopping_rounds = NULL)
)
Expand Down
1 change: 0 additions & 1 deletion man/mlr_learners_classif.xgboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/mlr_learners_regr.xgboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test_classif_xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ test_that("validation and inner tuning", {
validate = 0.2
)
s = learner$param_set$search_space()
expect_error(learner$param_set$convert_internal_search_space(s), "early stopping")
expect_error(learner$param_set$convert_internal_search_space(s), "Parameter")
learner$param_set$set_values(early_stopping_rounds = 10)
learner$param_set$disable_internal_tuning("nrounds")
expect_equal(learner$param_set$values$early_stopping_rounds, NULL)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_regr_xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ test_that("validation and inner tuning", {
validate = 0.2
)
s = learner$param_set$search_space()
expect_error(learner$param_set$convert_internal_search_space(s), "early stopping")
expect_error(learner$param_set$convert_internal_search_space(s), "Parameter")
learner$param_set$set_values(early_stopping_rounds = 10)
learner$param_set$disable_internal_tuning("nrounds")
expect_equal(learner$param_set$values$early_stopping_rounds, NULL)
Expand Down

0 comments on commit 43ebb7a

Please sign in to comment.