Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] Promote objective and init_score to top-level arguments in lightgbm() #4976

Merged
merged 12 commits into from
Feb 23, 2022
10 changes: 9 additions & 1 deletion R-package/R/lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ NULL
#' @inheritParams lgb_shared_params
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight vector of response values. If not NULL, will set to dataset
#' @param objective Optimization objective (e.g. `"regression"`, `"binary"`, etc.).
#' For a list of accepted objectives, see
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html}{
#' the "Parameters" section of the documentation}.
#' @param init_score initial score is the base prediction lightgbm will boost from
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#' \itemize{
#' \item{\code{valids}: a list of \code{lgb.Dataset} objects, used for validation}
Expand Down Expand Up @@ -121,6 +126,8 @@ lightgbm <- function(data,
init_model = NULL,
callbacks = list(),
serializable = TRUE,
objective = "regression",
init_score = NULL,
...) {

# validate inputs early to avoid unnecessary computation
Expand All @@ -133,13 +140,14 @@ lightgbm <- function(data,

# Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
if (!lgb.is.Dataset(x = dtrain)) {
dtrain <- lgb.Dataset(data = data, label = label, weight = weight)
dtrain <- lgb.Dataset(data = data, label = label, weight = weight, init_score = init_score)
}

train_args <- list(
"params" = params
, "data" = dtrain
, "nrounds" = nrounds
, "obj" = objective
, "verbose" = verbose
, "eval_freq" = eval_freq
, "early_stopping_rounds" = early_stopping_rounds
Expand Down
9 changes: 9 additions & 0 deletions R-package/man/lightgbm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 105 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -2812,3 +2812,108 @@ for (x3_to_categorical in c(TRUE, FALSE)) {
})
}
}

test_that("lightgbm() accepts objective as function argument and under params", {
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
bst1 <- lightgbm(
data = train$data
, label = train$label
, params = list(objective = "regression_l1")
, nrounds = 5L
, verbose = -1L
)
expect_equal(bst1$params$objective, "regression_l1")
model_txt_lines <- strsplit(
x = bst1$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(model_txt_lines == "objective=regression_l1"))
expect_false(any(model_txt_lines == "objective=regression_l2"))

bst2 <- lightgbm(
data = train$data
, label = train$label
, objective = "regression_l1"
, nrounds = 5L
, verbose = -1L
)
expect_equal(bst2$params$objective, "regression_l1")
model_txt_lines <- strsplit(
x = bst2$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(model_txt_lines == "objective=regression_l1"))
expect_false(any(model_txt_lines == "objective=regression_l2"))
})

test_that("lightgbm() prioritizes objective under params over objective as function argument", {
bst1 <- lightgbm(
data = train$data
, label = train$label
, objective = "regression"
, params = list(objective = "regression_l1")
, nrounds = 5L
, verbose = -1L
)
expect_equal(bst1$params$objective, "regression_l1")
model_txt_lines <- strsplit(
x = bst1$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(model_txt_lines == "objective=regression_l1"))
expect_false(any(model_txt_lines == "objective=regression_l2"))

bst2 <- lightgbm(
data = train$data
, label = train$label
, objective = "regression"
, params = list(loss = "regression_l1")
, nrounds = 5L
, verbose = -1L
)
expect_equal(bst2$params$objective, "regression_l1")
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
model_txt_lines <- strsplit(
x = bst2$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(model_txt_lines == "objective=regression_l1"))
expect_false(any(model_txt_lines == "objective=regression_l2"))
})

test_that("lightgbm() accepts init_score as function argument", {
bst1 <- lightgbm(
data = train$data
, label = train$label
, objective = "binary"
, nrounds = 5L
, verbose = -1L
)
pred1 <- predict(bst1, train$data, rawscore = TRUE)

bst2 <- lightgbm(
data = train$data
, label = train$label
, init_score = pred1
, objective = "binary"
, nrounds = 5L
, verbose = -1L
)
pred2 <- predict(bst2, train$data, rawscore = TRUE)

expect_true(any(pred1 != pred2))
})

test_that("lightgbm() defaults to 'regression' objective if objective not otherwise provided", {
bst <- lightgbm(
data = train$data
, label = train$label
, nrounds = 5L
, verbose = -1L
)
expect_equal(bst$params$objective, "regression")
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
model_txt_lines <- strsplit(
x = bst$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(model_txt_lines == "objective=regression"))
expect_false(any(model_txt_lines == "objective=regression_l1"))
})