From 3eeb6279b20cab8d1758610d0aaab3a2953e8732 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 23 Apr 2020 23:47:34 -0500 Subject: [PATCH] [R-package] added tests on LGBM_BoosterResetTrainingData_R --- R-package/tests/testthat/test_lgb.Booster.R | 63 +++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index 6b1cc20f957a..77b51906f3d9 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -311,3 +311,66 @@ test_that("Booster$rollback_one_iter() should work as expected", { logloss <- bst$eval_train()[[1L]][["value"]] expect_equal(logloss, 0.027915146) }) + +test_that("Booster$update() passing a train_set works as expected", { + set.seed(708L) + data(agaricus.train, package = "lightgbm") + nrounds <- 2L + + # train with 2 rounds and then update + bst <- lightgbm( + data = as.matrix(agaricus.train$data) + , label = agaricus.train$label + , num_leaves = 4L + , learning_rate = 1.0 + , nrounds = nrounds + , objective = "binary" + ) + expect_true(lgb.is.Booster(bst)) + expect_equal(bst$current_iter(), nrounds) + bst$update( + train_set = Dataset$new( + data = agaricus.train$data + , label = agaricus.train$label + ) + ) + expect_true(lgb.is.Booster(bst)) + expect_equal(bst$current_iter(), nrounds + 1L) + + # train with 3 rounds directlry + bst2 <- lightgbm( + data = as.matrix(agaricus.train$data) + , label = agaricus.train$label + , num_leaves = 4L + , learning_rate = 1.0 + , nrounds = nrounds + 1L + , objective = "binary" + ) + expect_true(lgb.is.Booster(bst2)) + expect_equal(bst2$current_iter(), nrounds + 1L) + + # model with 2 rounds + 1 update should be identical to 3 rounds + expect_equal(bst2$eval_train()[[1L]][["value"]], 0.04806585) + expect_equal(bst$eval_train()[[1L]][["value"]], bst2$eval_train()[[1L]][["value"]]) +}) + +test_that("Booster$update() throws an informative error if you provide a non-Dataset to update()", { + set.seed(708L) + data(agaricus.train, package = "lightgbm") + nrounds <- 2L + + # train with 2 rounds and then update + bst <- lightgbm( + data = as.matrix(agaricus.train$data) + , label = agaricus.train$label + , num_leaves = 4L + , learning_rate = 1.0 + , nrounds = nrounds + , objective = "binary" + ) + expect_error({ + bst$update( + train_set = data.frame(x = rnorm(10L)) + ) + }, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE) +})