[R-package] Fixed R implementation of upper_bound() and lower_bound() for lgb.Booster (#2785)

* [R-package] Fixed R implementation of upper_bound() and lower_bound() for lgb.Booster

* [R-package] switched return type to double

* fixed R tests on Booster upper_bound() and lower_bound()

* fixed linting

* moved numeric tolerance into a global constant
jameslamb authored Feb 23, 2020
1 parent 4adb9ff commit 790c1e3
Showing 3 changed files with 62 additions and 6 deletions.
10 changes: 5 additions & 5 deletions R-package/R/lgb.Booster.R
@@ -322,9 +322,9 @@ Booster <- R6::R6Class(
     },
 
     # Get upper bound
-    upper_bound_ = function() {
+    upper_bound = function() {
 
-      upper_bound <- 0L
+      upper_bound <- 0.0
       lgb.call(
         "LGBM_BoosterGetUpperBoundValue_R"
         , ret = upper_bound
@@ -334,12 +334,12 @@ Booster <- R6::R6Class(
     },
 
     # Get lower bound
-    lower_bound_ = function() {
+    lower_bound = function() {
 
-      lower_bound <- 0L
+      lower_bound <- 0.0
       lgb.call(
         "LGBM_BoosterGetLowerBoundValue_R"
-        , ret = upper_bound
+        , ret = lower_bound
         , private$handle
       )
 
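With the rename from upper_bound_()/lower_bound_() to upper_bound()/lower_bound(), the bounds become callable as public methods on a fitted Booster. A minimal usage sketch, assuming the bundled agaricus.train data (the dataset and training parameters below are illustrative, not part of this diff):

library(lightgbm)
data(agaricus.train, package = "lightgbm")

bst <- lightgbm(
  data = agaricus.train$data
  , label = agaricus.train$label
  , num_leaves = 5L
  , nrounds = 10L
  , objective = "binary"
)

# Both methods now fill and return a double; before this fix, lower_bound()
# mistakenly passed the upper_bound variable as its ret buffer.
bst$upper_bound()
bst$lower_bound()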
34 changes: 33 additions & 1 deletion R-package/tests/testthat/test_basic.R
@@ -7,6 +7,8 @@ test <- agaricus.test
 
 windows_flag <- grepl("Windows", Sys.info()[["sysname"]])
 
+TOLERANCE <- 1e-6
+
 test_that("train and predict binary classification", {
   nrounds <- 10L
   bst <- lightgbm(
@@ -28,7 +30,7 @@ test_that("train and predict binary classification", {
   expect_equal(length(pred1), 6513L)
   err_pred1 <- sum((pred1 > 0.5) != train$label) / length(train$label)
   err_log <- record_results[1L]
-  expect_lt(abs(err_pred1 - err_log), 10e-6)
+  expect_lt(abs(err_pred1 - err_log), TOLERANCE)
 })
 
 
@@ -70,6 +72,36 @@ test_that("use of multiple eval metrics works", {
   expect_false(is.null(bst$record_evals))
 })
 
+test_that("lgb.Booster.upper_bound() and lgb.Booster.lower_bound() work as expected for binary classification", {
+  set.seed(708L)
+  nrounds <- 10L
+  bst <- lightgbm(
+    data = train$data
+    , label = train$label
+    , num_leaves = 5L
+    , nrounds = nrounds
+    , objective = "binary"
+    , metric = "binary_error"
+  )
+  expect_true(abs(bst$lower_bound() - -1.590853) < TOLERANCE)
+  expect_true(abs(bst$upper_bound() - 1.871015) < TOLERANCE)
+})
+
+test_that("lgb.Booster.upper_bound() and lgb.Booster.lower_bound() work as expected for regression", {
+  set.seed(708L)
+  nrounds <- 10L
+  bst <- lightgbm(
+    data = train$data
+    , label = train$label
+    , num_leaves = 5L
+    , nrounds = nrounds
+    , objective = "regression"
+    , metric = "l2"
+  )
+  expect_true(abs(bst$lower_bound() - 0.1513859) < TOLERANCE)
+  expect_true(abs(bst$upper_bound() - 0.9080349) < TOLERANCE)
+})
+
 test_that("lightgbm() rejects negative or 0 value passed to nrounds", {
   dtrain <- lgb.Dataset(train$data, label = train$label)
   params <- list(objective = "regression", metric = "l2,l1")
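The expected values in the two new tests are bounds on the model's raw scores; for the binary objective they fall outside [0, 1], so they are clearly pre-sigmoid. A hedged sketch of what that implies for the binary model, using the constants from the test above (the raw-score reading is an inference from the values, not something this diff states):

# Pushing the raw-score bounds through the sigmoid gives the range of
# reachable predicted probabilities for the binary model in the test.
sigmoid <- function(x) 1.0 / (1.0 + exp(-x))
sigmoid(-1.590853)  # ~0.169, smallest reachable probability
sigmoid(1.871015)   # ~0.867, largest reachable probability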
24 changes: 24 additions & 0 deletions include/LightGBM/lightgbm_R.h
@@ -336,6 +336,30 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
     LGBM_SE out,
     LGBM_SE call_state);
 
+/*!
+ * \brief Get model upper bound value.
+ * \param handle Handle of booster
+ * \param[out] out_result Result pointing to max value
+ * \return 0 on success, -1 on failure
+ */
+LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterGetUpperBoundValue_R(
+  LGBM_SE handle,
+  LGBM_SE out_result,
+  LGBM_SE call_state
+);
+
+/*!
+ * \brief Get model lower bound value.
+ * \param handle Handle of booster
+ * \param[out] out_result Result pointing to min value
+ * \return 0 on success, -1 on failure
+ */
+LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterGetLowerBoundValue_R(
+  LGBM_SE handle,
+  LGBM_SE out_result,
+  LGBM_SE call_state
+);
+
 /*!
  * \brief Get Name of eval
  * \param eval_names eval names
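These C entry points write a double through out_result, which is what makes the 0L -> 0.0 change in lgb.Booster.R necessary: lgb.call() is assumed (it is an internal, unexported helper) to take the storage mode of its output buffer from the ret argument, so an integer buffer would be read back incorrectly. A minimal sketch of the distinction in base R:

# The two literal forms from the R diff have different storage modes;
# a C routine writing a double needs the second.
typeof(0L)         # "integer"
typeof(0.0)        # "double"
storage.mode(0L)   # "integer"
storage.mode(0.0)  # "double"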
