Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] Speed-up lgb.importance() #6364

Merged
merged 18 commits into from
Apr 10, 2024
Prev Previous commit
Next Next commit
switch to expected_n_trees
  • Loading branch information
mayer79 committed Mar 19, 2024
commit 520194668c0269725127e44c62a2592139916a01
18 changes: 10 additions & 8 deletions R-package/tests/testthat/test_lgb.model.dt.tree.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
NROUNDS <- 10L
MAX_DEPTH <- 3L
N <- nrow(iris)
X <- data.matrix(iris[FEAT])
X <- data.matrix(iris[2L:4L])
FEAT <- colnames(X)
NCLASS <- nlevels(iris[, 5L])

Expand Down Expand Up @@ -43,9 +43,9 @@ models <- list(reg = model_reg, bin = model_binary, multi = model_multiclass)

for (model_name in names(models)) {
model <- models[[model_name]]
nrounds <- NROUNDS
expected_n_trees <- NROUNDS
if (model_name == "multi") {
nrounds <- nrounds * NCLASS
expected_n_trees <- NROUNDS * NCLASS
}
df <- as.data.frame(lgb.model.dt.tree(model))
df_list <- split(df, f = df$tree_index, drop = TRUE)
Expand All @@ -54,7 +54,7 @@ for (model_name in names(models)) {
df_internal <- df[is.na(df$leaf_index), ]

test_that("lgb.model.dt.tree() returns the right number of trees", {
expect_equal(length(unique(df$tree_index)), nrounds)
expect_equal(length(unique(df$tree_index)), expected_n_trees)
})

test_that("num_iteration can return less trees", {
Expand All @@ -65,7 +65,7 @@ for (model_name in names(models)) {
})

test_that("Tree index from lgb.model.dt.tree() is in 0:(NROUNS-1)", {
expect_equal(unique(df$tree_index), (0L:(nrounds - 1L)))
expect_equal(unique(df$tree_index), (0L:(expected_n_trees - 1L)))
})

test_that("Depth calculated from lgb.model.dt.tree() respects max.depth", {
Expand All @@ -75,14 +75,14 @@ for (model_name in names(models)) {
test_that("Each tree from lgb.model.dt.tree() has single root node", {
expect_equal(
unname(sapply(df_list, function(df) sum(df$depth == 0L)))
, rep(1L, nrounds)
, rep(1L, expected_n_trees)
)
})

test_that("Each tree from lgb.model.dt.tree() has two depth 1 nodes", {
expect_equal(
unname(sapply(df_list, function(df) sum(df$depth == 1L)))
, rep(2L, nrounds)
, rep(2L, expected_n_trees)
)
})

Expand All @@ -107,7 +107,9 @@ for (model_name in names(models)) {
})

test_that("non-leaves from lgb.model.dt.tree() do not have leaf info", {
leaf_node_cols <- c("leaf_index", "leaf_parent", "leaf_value", "leaf_count")
leaf_node_cols <- c(
"leaf_index", "leaf_parent", "leaf_value", "leaf_count"
)
expect_true(all(is.na(df_internal[leaf_node_cols])))
})

Expand Down
Loading