Skip to content

Commit f7786ef

Browse files
khotilov and larskotthoff
authored and committed
xgboost: expose watchlist and callbacks (#1859)
* xgboost: expose watchlist and callbacks; remove silent from params; set default lambda=1; add tweedie_variance_power param * disable TODO linter
1 parent e3f74dd commit f7786ef

File tree

4 files changed

+45
-22
lines changed

4 files changed

+45
-22
lines changed

R/RLearner_classif_xgboost.R

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ makeRLearner.classif.xgboost = function() {
77
# we pass all of what goes in 'params' directly to ... of xgboost
88
# makeUntypedLearnerParam(id = "params", default = list()),
99
makeDiscreteLearnerParam(id = "booster", default = "gbtree", values = c("gbtree", "gblinear", "dart")),
10-
makeIntegerLearnerParam(id = "silent", default = 0L, tunable = FALSE),
10+
makeUntypedLearnerParam(id = "watchlist", default = NULL, tunable = FALSE),
1111
makeNumericLearnerParam(id = "eta", default = 0.3, lower = 0, upper = 1),
1212
makeNumericLearnerParam(id = "gamma", default = 0, lower = 0),
1313
makeIntegerLearnerParam(id = "max_depth", default = 6L, lower = 1L),
@@ -16,7 +16,7 @@ makeRLearner.classif.xgboost = function() {
1616
makeNumericLearnerParam(id = "colsample_bytree", default = 1, lower = 0, upper = 1),
1717
makeNumericLearnerParam(id = "colsample_bylevel", default = 1, lower = 0, upper = 1),
1818
makeIntegerLearnerParam(id = "num_parallel_tree", default = 1L, lower = 1L),
19-
makeNumericLearnerParam(id = "lambda", default = 0, lower = 0),
19+
makeNumericLearnerParam(id = "lambda", default = 1, lower = 0),
2020
makeNumericLearnerParam(id = "lambda_bias", default = 0, lower = 0),
2121
makeNumericLearnerParam(id = "alpha", default = 0, lower = 0),
2222
makeUntypedLearnerParam(id = "objective", default = "binary:logistic", tunable = FALSE),
@@ -26,6 +26,7 @@ makeRLearner.classif.xgboost = function() {
2626
makeNumericLearnerParam(id = "missing", default = NULL, tunable = FALSE, when = "both",
2727
special.vals = list(NA, NA_real_, NULL)),
2828
makeIntegerVectorLearnerParam(id = "monotone_constraints", default = 0, lower = -1, upper = 1),
29+
makeNumericLearnerParam(id = "tweedie_variance_power", lower = 1, upper = 2, default = 1.5, requires = quote(objective == "reg:tweedie")),
2930
makeIntegerLearnerParam(id = "nthread", lower = 1L, tunable = FALSE),
3031
makeIntegerLearnerParam(id = "nrounds", default = 1L, lower = 1L),
3132
# FIXME nrounds seems to have no default in xgboost(), if it has 1, par.vals is redundant
@@ -38,7 +39,14 @@ makeRLearner.classif.xgboost = function() {
3839
makeDiscreteLearnerParam(id = "sample_type", default = "uniform", values = c("uniform", "weighted"), requires = quote(booster == "dart")),
3940
makeDiscreteLearnerParam(id = "normalize_type", default = "tree", values = c("tree", "forest"), requires = quote(booster == "dart")),
4041
makeNumericLearnerParam(id = "rate_drop", default = 0, lower = 0, upper = 1, requires = quote(booster == "dart")),
41-
makeNumericLearnerParam(id = "skip_drop", default = 0, lower = 0, upper = 1, requires = quote(booster == "dart"))
42+
makeNumericLearnerParam(id = "skip_drop", default = 0, lower = 0, upper = 1, requires = quote(booster == "dart")),
43+
# TODO: uncomment the following after the next CRAN update, and set max_depth's lower = 0L
44+
#makeLogicalLearnerParam(id = "one_drop", default = FALSE, requires = quote(booster == "dart")),
45+
#makeDiscreteLearnerParam(id = "tree_method", default = "exact", values = c("exact", "hist"), requires = quote(booster != "gblinear")),
46+
#makeDiscreteLearnerParam(id = "grow_policy", default = "depthwise", values = c("depthwise", "lossguide"), requires = quote(tree_method == "hist")),
47+
#makeIntegerLearnerParam(id = "max_leaves", default = 0L, lower = 0L, requires = quote(grow_policy == "lossguide")),
48+
#makeIntegerLearnerParam(id = "max_bin", default = 256L, lower = 2L, requires = quote(tree_method == "hist")),
49+
makeUntypedLearnerParam(id = "callbacks", default = list(), tunable = FALSE)
4250
),
4351
par.vals = list(nrounds = 1L, verbose = 0L),
4452
properties = c("twoclass", "multiclass", "numerics", "prob", "weights", "missings", "featimp"),
@@ -54,8 +62,6 @@ trainLearner.classif.xgboost = function(.learner, .task, .subset, .weights = NUL
5462

5563
td = getTaskDesc(.task)
5664
parlist = list(...)
57-
parlist$data = data.matrix(getTaskData(.task, .subset, target.extra = TRUE)$data)
58-
parlist$label = match(as.character(getTaskData(.task, .subset, target.extra = TRUE)$target), td$class.levels) - 1
5965
nc = length(td$class.levels)
6066

6167
if (is.null(parlist$objective))
@@ -68,10 +74,17 @@ trainLearner.classif.xgboost = function(.learner, .task, .subset, .weights = NUL
6874
if (parlist$objective %in% c("multi:softprob", "multi:softmax"))
6975
parlist$num_class = nc
7076

77+
task.data = getTaskData(.task, .subset, target.extra = TRUE)
78+
label = match(as.character(task.data$target), td$class.levels) - 1
79+
parlist$data = xgboost::xgb.DMatrix(data = data.matrix(task.data$data), label = label)
80+
7181
if (!is.null(.weights))
72-
parlist$data = xgboost::xgb.DMatrix(data = parlist$data, label = parlist$label, weight = .weights)
82+
xgboost::setinfo(parlist$data, "weight", .weights)
83+
84+
if (is.null(parlist$watchlist))
85+
parlist$watchlist = list(train = parlist$data)
7386

74-
do.call(xgboost::xgboost, parlist)
87+
do.call(xgboost::xgb.train, parlist)
7588
}
7689

7790
#' @export
@@ -131,5 +144,3 @@ getFeatureImportanceLearner.classif.xgboost = function(.learner, .model, ...) {
131144
fiv = imp$Gain
132145
setNames(fiv, imp$Feature)
133146
}
134-
135-

R/RLearner_regr_xgboost.R

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ makeRLearner.regr.xgboost = function() {
77
# we pass all of what goes in 'params' directly to ... of xgboost
88
#makeUntypedLearnerParam(id = "params", default = list()),
99
makeDiscreteLearnerParam(id = "booster", default = "gbtree", values = c("gbtree", "gblinear", "dart")),
10-
makeIntegerLearnerParam(id = "silent", default = 0L, tunable = FALSE),
10+
makeUntypedLearnerParam(id = "watchlist", default = NULL, tunable = FALSE),
1111
makeNumericLearnerParam(id = "eta", default = 0.3, lower = 0, upper = 1),
1212
makeNumericLearnerParam(id = "gamma", default = 0, lower = 0),
1313
makeIntegerLearnerParam(id = "max_depth", default = 6L, lower = 1L),
@@ -16,16 +16,17 @@ makeRLearner.regr.xgboost = function() {
1616
makeNumericLearnerParam(id = "colsample_bytree", default = 1, lower = 0, upper = 1),
1717
makeNumericLearnerParam(id = "colsample_bylevel", default = 1, lower = 0, upper = 1),
1818
makeIntegerLearnerParam(id = "num_parallel_tree", default = 1L, lower = 1L),
19-
makeNumericLearnerParam(id = "lambda", default = 0, lower = 0),
19+
makeNumericLearnerParam(id = "lambda", default = 1, lower = 0),
2020
makeNumericLearnerParam(id = "lambda_bias", default = 0, lower = 0),
2121
makeNumericLearnerParam(id = "alpha", default = 0, lower = 0),
2222
makeUntypedLearnerParam(id = "objective", default = "reg:linear", tunable = FALSE),
2323
makeUntypedLearnerParam(id = "eval_metric", default = "rmse", tunable = FALSE),
2424
makeNumericLearnerParam(id = "base_score", default = 0.5, tunable = FALSE),
25-
25+
makeNumericLearnerParam(id = "max_delta_step", lower = 0, default = 0),
2626
makeNumericLearnerParam(id = "missing", default = NULL, tunable = FALSE, when = "both",
2727
special.vals = list(NA, NA_real_, NULL)),
2828
makeIntegerVectorLearnerParam(id = "monotone_constraints", default = 0, lower = -1, upper = 1),
29+
makeNumericLearnerParam(id = "tweedie_variance_power", lower = 1, upper = 2, default = 1.5, requires = quote(objective == "reg:tweedie")),
2930
makeIntegerLearnerParam(id = "nthread", lower = 1L, tunable = FALSE),
3031
makeIntegerLearnerParam(id = "nrounds", default = 1L, lower = 1L),
3132
# FIXME nrounds seems to have no default in xgboost(), if it has 1, par.vals is redundant
@@ -35,9 +36,17 @@ makeRLearner.regr.xgboost = function() {
3536
requires = quote(verbose == 1L)),
3637
makeIntegerLearnerParam(id = "early_stopping_rounds", default = NULL, lower = 1L, special.vals = list(NULL), tunable = FALSE),
3738
makeLogicalLearnerParam(id = "maximize", default = NULL, special.vals = list(NULL), tunable = FALSE),
39+
makeDiscreteLearnerParam(id = "sample_type", default = "uniform", values = c("uniform", "weighted"), requires = quote(booster == "dart")),
3840
makeDiscreteLearnerParam(id = "normalize_type", default = "tree", values = c("tree", "forest"), requires = quote(booster == "dart")),
3941
makeNumericLearnerParam(id = "rate_drop", default = 0, lower = 0, upper = 1, requires = quote(booster == "dart")),
40-
makeNumericLearnerParam(id = "skip_drop", default = 0, lower = 0, upper = 1, requires = quote(booster == "dart"))
42+
makeNumericLearnerParam(id = "skip_drop", default = 0, lower = 0, upper = 1, requires = quote(booster == "dart")),
43+
# TODO: uncomment the following after the next CRAN update, and set max_depth's lower = 0L
44+
#makeLogicalLearnerParam(id = "one_drop", default = FALSE, requires = quote(booster == "dart")),
45+
#makeDiscreteLearnerParam(id = "tree_method", default = "exact", values = c("exact", "hist"), requires = quote(booster != "gblinear")),
46+
#makeDiscreteLearnerParam(id = "grow_policy", default = "depthwise", values = c("depthwise", "lossguide"), requires = quote(tree_method == "hist")),
47+
#makeIntegerLearnerParam(id = "max_leaves", default = 0L, lower = 0L, requires = quote(grow_policy == "lossguide")),
48+
#makeIntegerLearnerParam(id = "max_bin", default = 256L, lower = 2L, requires = quote(tree_method == "hist")),
49+
makeUntypedLearnerParam(id = "callbacks", default = list(), tunable = FALSE)
4150
),
4251
par.vals = list(nrounds = 1L, verbose = 0L),
4352
properties = c("numerics", "weights", "featimp", "missings"),
@@ -52,16 +61,19 @@ makeRLearner.regr.xgboost = function() {
5261
trainLearner.regr.xgboost = function(.learner, .task, .subset, .weights = NULL, ...) {
5362
parlist = list(...)
5463

55-
parlist$label = getTaskData(.task, .subset, target.extra = TRUE)$target
56-
parlist$data = data.matrix(getTaskData(.task, .subset, target.extra = TRUE)$data)
57-
5864
if (is.null(parlist$objective))
5965
parlist$objective = "reg:linear"
6066

67+
task.data = getTaskData(.task, .subset, target.extra = TRUE)
68+
parlist$data = xgboost::xgb.DMatrix(data = data.matrix(task.data$data), label = task.data$target)
69+
6170
if (!is.null(.weights))
62-
parlist$data = xgboost::xgb.DMatrix(data = parlist$data, label = parlist$label, weight = .weights)
71+
xgboost::setinfo(parlist$data, "weight", .weights)
72+
73+
if (is.null(parlist$watchlist))
74+
parlist$watchlist = list(train = parlist$data)
6375

64-
do.call(xgboost::xgboost, parlist)
76+
do.call(xgboost::xgb.train, parlist)
6577
}
6678

6779
#' @export

tests/testthat/helper_lint.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ if (isLintrVersionOk() && require("lintr", quietly = TRUE) && require("rex", qui
266266
seq = lintr::seq_linter,
267267
unneeded.concatenation = lintr::unneeded_concatenation_linter,
268268
trailing.whitespace = lintr::trailing_whitespace_linter,
269-
todo.comment = lintr::todo_comment_linter(todo = "todo"), # is case-insensitive
269+
#todo.comment = lintr::todo_comment_linter(todo = "todo"), # is case-insensitive
270270
spaces.inside = lintr::spaces_inside_linter,
271271
infix.spaces = infix.spaces.linter,
272272
object.naming = object.naming.linter)

tests/testthat/test_regr_xgboost.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ test_that("regr_xgboost", {
3131
})
3232

3333
test_that("xgboost works with different 'missing' arg vals", {
34-
lrn = makeLearner("classif.xgboost", missing = NA_real_)
35-
lrn = makeLearner("classif.xgboost", missing = NA)
36-
lrn = makeLearner("classif.xgboost", missing = NULL)
34+
lrn = makeLearner("regr.xgboost", missing = NA_real_)
35+
lrn = makeLearner("regr.xgboost", missing = NA)
36+
lrn = makeLearner("regr.xgboost", missing = NULL)
3737
})
3838

3939

0 commit comments

Comments (0)