Skip to content

Commit

Permalink
[R-package] Promote objective and init_score to top-level argumen…
Browse files Browse the repository at this point in the history
…ts in `lightgbm()` (#4976)

* promote objective and init_score to top-level arguments

* follow comments

* Update R-package/R/lightgbm.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* update docs

* linter

* comments

* comments

* comments

* extend test for default objective

* Update R-package/tests/testthat/test_basic.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>
  • Loading branch information
david-cortes and jameslamb authored Feb 23, 2022
1 parent 6ced58a commit 31facb4
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 1 deletion.
10 changes: 9 additions & 1 deletion R-package/R/lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ NULL
#' @inheritParams lgb_shared_params
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight Vector of observation weights. If not NULL, it is set as the weight of the dataset
#' @param objective Optimization objective (e.g. `"regression"`, `"binary"`, etc.).
#' For a list of accepted objectives, see
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html}{
#' the "Parameters" section of the documentation}.
#' @param init_score Initial score, i.e. the base prediction that lightgbm will boost from
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#' \itemize{
#' \item{\code{valids}: a list of \code{lgb.Dataset} objects, used for validation}
Expand Down Expand Up @@ -121,6 +126,8 @@ lightgbm <- function(data,
init_model = NULL,
callbacks = list(),
serializable = TRUE,
objective = "regression",
init_score = NULL,
...) {

# validate inputs early to avoid unnecessary computation
Expand All @@ -133,13 +140,14 @@ lightgbm <- function(data,

# Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
if (!lgb.is.Dataset(x = dtrain)) {
dtrain <- lgb.Dataset(data = data, label = label, weight = weight)
dtrain <- lgb.Dataset(data = data, label = label, weight = weight, init_score = init_score)
}

train_args <- list(
"params" = params
, "data" = dtrain
, "nrounds" = nrounds
, "obj" = objective
, "verbose" = verbose
, "eval_freq" = eval_freq
, "early_stopping_rounds" = early_stopping_rounds
Expand Down
9 changes: 9 additions & 0 deletions R-package/man/lightgbm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 105 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -2812,3 +2812,108 @@ for (x3_to_categorical in c(TRUE, FALSE)) {
})
}
}

test_that("lightgbm() accepts objective as function argument and under params", {
    # Assertions shared by both scenarios: the booster must report the L1
    # objective in its params and in its serialized model text, and the
    # serialized text must not contain the L2 objective line.
    expect_l1_objective <- function(booster) {
        expect_equal(booster$params$objective, "regression_l1")
        serialized_lines <- strsplit(
            x = booster$save_model_to_string()
            , split = "\n"
        )[[1L]]
        expect_true(any(serialized_lines == "objective=regression_l1"))
        expect_false(any(serialized_lines == "objective=regression_l2"))
    }

    # scenario 1: objective supplied through 'params'
    booster_via_params <- lightgbm(
        data = train$data
        , label = train$label
        , params = list(objective = "regression_l1")
        , nrounds = 5L
        , verbose = -1L
    )
    expect_l1_objective(booster_via_params)

    # scenario 2: objective supplied as a top-level keyword argument
    booster_via_argument <- lightgbm(
        data = train$data
        , label = train$label
        , objective = "regression_l1"
        , nrounds = 5L
        , verbose = -1L
    )
    expect_l1_objective(booster_via_argument)
})

test_that("lightgbm() prioritizes objective under params over objective as function argument", {
    # In both scenarios the top-level argument says "regression" while
    # 'params' asks for "regression_l1"; the value from 'params' is the one
    # expected to end up on the booster and in the serialized model text.
    expect_params_value_used <- function(booster) {
        expect_equal(booster$params$objective, "regression_l1")
        serialized_lines <- strsplit(
            x = booster$save_model_to_string()
            , split = "\n"
        )[[1L]]
        expect_true(any(serialized_lines == "objective=regression_l1"))
        expect_false(any(serialized_lines == "objective=regression_l2"))
    }

    # scenario 1: 'params' uses the key 'objective'
    booster_objective_key <- lightgbm(
        data = train$data
        , label = train$label
        , objective = "regression"
        , params = list(objective = "regression_l1")
        , nrounds = 5L
        , verbose = -1L
    )
    expect_params_value_used(booster_objective_key)

    # scenario 2: 'params' uses the key 'loss', which this test expects to be
    # resolved to the objective as well
    booster_loss_key <- lightgbm(
        data = train$data
        , label = train$label
        , objective = "regression"
        , params = list(loss = "regression_l1")
        , nrounds = 5L
        , verbose = -1L
    )
    expect_params_value_used(booster_loss_key)
})

test_that("lightgbm() accepts init_score as function argument", {
    # Baseline: a binary model trained without any initial score.
    baseline_booster <- lightgbm(
        data = train$data
        , label = train$label
        , objective = "binary"
        , nrounds = 5L
        , verbose = -1L
    )
    baseline_raw <- predict(baseline_booster, train$data, rawscore = TRUE)

    # Same setup, but training starts from the baseline's raw predictions.
    warm_booster <- lightgbm(
        data = train$data
        , label = train$label
        , init_score = baseline_raw
        , objective = "binary"
        , nrounds = 5L
        , verbose = -1L
    )
    warm_raw <- predict(warm_booster, train$data, rawscore = TRUE)

    # If init_score were ignored, both models would predict the same values;
    # at least one raw prediction has to differ.
    expect_false(all(baseline_raw == warm_raw))
})

test_that("lightgbm() defaults to 'regression' objective if objective not otherwise provided", {
    # No objective is given, neither at the top level nor inside 'params'.
    default_booster <- lightgbm(
        data = train$data
        , label = train$label
        , nrounds = 5L
        , verbose = -1L
    )
    expect_equal(default_booster$params$objective, "regression")

    # The serialized model must carry the default objective and must not
    # carry the L1 variant.
    model_string <- default_booster$save_model_to_string()
    serialized_lines <- strsplit(x = model_string, split = "\n")[[1L]]
    expect_true(any(serialized_lines == "objective=regression"))
    expect_false(any(serialized_lines == "objective=regression_l1"))
})

0 comments on commit 31facb4

Please sign in to comment.