From 790c1e33e690c2e11c12277e30e83644af268ba1 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sat, 22 Feb 2020 21:57:55 -0600 Subject: [PATCH] [R-package] Fixed R implementation of upper_bound() and lower_bound() for lgb.Booster (#2785) * [R-package] Fixed R implementation of upper_bound() and lower_bound() for lgb.Booster * [R-package] switched return type to double * fixed R tests on Booster upper_bound() and lower_bound() * fixed linting * moved numeric tolerance into a global constant --- R-package/R/lgb.Booster.R | 10 ++++---- R-package/tests/testthat/test_basic.R | 34 ++++++++++++++++++++++++++- include/LightGBM/lightgbm_R.h | 24 +++++++++++++++++++ 3 files changed, 62 insertions(+), 6 deletions(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index d9c4e002b5ab..d8e3e9ce485a 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -322,9 +322,9 @@ Booster <- R6::R6Class( }, # Get upper bound - upper_bound_ = function() { + upper_bound = function() { - upper_bound <- 0L + upper_bound <- 0.0 lgb.call( "LGBM_BoosterGetUpperBoundValue_R" , ret = upper_bound @@ -334,12 +334,12 @@ Booster <- R6::R6Class( }, # Get lower bound - lower_bound_ = function() { + lower_bound = function() { - lower_bound <- 0L + lower_bound <- 0.0 lgb.call( "LGBM_BoosterGetLowerBoundValue_R" - , ret = upper_bound + , ret = lower_bound , private$handle ) diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 962bc20315af..04c4056bdfa6 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -7,6 +7,8 @@ test <- agaricus.test windows_flag <- grepl("Windows", Sys.info()[["sysname"]]) +TOLERANCE <- 1e-6 + test_that("train and predict binary classification", { nrounds <- 10L bst <- lightgbm( @@ -28,7 +30,7 @@ test_that("train and predict binary classification", { expect_equal(length(pred1), 6513L) err_pred1 <- sum((pred1 > 0.5) != train$label) / length(train$label) err_log <- record_results[1L] - expect_lt(abs(err_pred1 - err_log), 10e-6) + expect_lt(abs(err_pred1 - err_log), TOLERANCE) }) @@ -70,6 +72,36 @@ test_that("use of multiple eval metrics works", { expect_false(is.null(bst$record_evals)) }) +test_that("lgb.Booster.upper_bound() and lgb.Booster.lower_bound() work as expected for binary classification", { + set.seed(708L) + nrounds <- 10L + bst <- lightgbm( + data = train$data + , label = train$label + , num_leaves = 5L + , nrounds = nrounds + , objective = "binary" + , metric = "binary_error" + ) + expect_true(abs(bst$lower_bound() - -1.590853) < TOLERANCE) + expect_true(abs(bst$upper_bound() - 1.871015) < TOLERANCE) +}) + +test_that("lgb.Booster.upper_bound() and lgb.Booster.lower_bound() work as expected for regression", { + set.seed(708L) + nrounds <- 10L + bst <- lightgbm( + data = train$data + , label = train$label + , num_leaves = 5L + , nrounds = nrounds + , objective = "regression" + , metric = "l2" + ) + expect_true(abs(bst$lower_bound() - 0.1513859) < TOLERANCE) + expect_true(abs(bst$upper_bound() - 0.9080349) < TOLERANCE) +}) + test_that("lightgbm() rejects negative or 0 value passed to nrounds", { dtrain <- lgb.Dataset(train$data, label = train$label) params <- list(objective = "regression", metric = "l2,l1") diff --git a/include/LightGBM/lightgbm_R.h b/include/LightGBM/lightgbm_R.h index bff36238b978..c034be360661 100644 --- a/include/LightGBM/lightgbm_R.h +++ b/include/LightGBM/lightgbm_R.h @@ -336,6 +336,30 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle, LGBM_SE out, LGBM_SE call_state); +/*! +* \brief Get model upper bound value. +* \param handle Handle of booster +* \param[out] out_results Result pointing to max value +* \return 0 when succeed, -1 when failure happens +*/ +LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterGetUpperBoundValue_R( + LGBM_SE handle, + LGBM_SE out_result, + LGBM_SE call_state +); + +/*! +* \brief Get model lower bound value. +* \param handle Handle of booster +* \param[out] out_results Result pointing to min value +* \return 0 when succeed, -1 when failure happens +*/ +LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterGetLowerBoundValue_R( + LGBM_SE handle, + LGBM_SE out_result, + LGBM_SE call_state +); + /*! * \brief Get Name of eval * \param eval_names eval names