Skip to content

Commit 03a966a

Browse files
committed
TST move test to test_multinom_reg_glmnet
1 parent 1d668ff commit 03a966a

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

tests/testthat/test_multinom_reg.R

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,4 @@ test_that('bad input', {
105105
expect_warning(translate(multinom_reg() %>% set_engine("glmnet", x = iris[,1:3], y = iris$Species)))
106106
})
107107

108-
test_that("predictions are factors with all levels", {
109-
basic <- multinom_reg() %>% set_engine("glmnet") %>% fit(Species ~ ., data = iris)
110-
nd <- iris[iris$Species == "setosa", ]
111-
yhat <- predict(basic, new_data = nd, penalty = .1)
112-
expect_is(yhat$.pred_class, "factor")
113-
expect_equal(levels(yhat$.pred_class), levels(iris$Species))
114-
yhat_multi <- multi_predict(basic, new_data = nd, penalty = .1)$.pred
115-
expect_is(yhat_multi[[1]]$.pred_class, "factor")
116-
expect_equal(levels(yhat_multi[[1]]$.pred_class), levels(iris$Species))
117-
})
108+

tests/testthat/test_multinom_reg_glmnet.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ test_that('glmnet probabilities, mulitiple lambda', {
138138
multi_predict(xy_fit, iris[rows, 1:4], penalty = lams)$.pred
139139
)
140140

141+
expect_is(yhat$.pred_class, "factor")
142+
expect_equal(levels(yhat$.pred_class), levels(iris$Species))
143+
144+
141145
expect_error(
142146
multi_predict(xy_fit, newdata = iris[rows, 1:4], penalty = lams),
143147
"Did you mean"
@@ -150,3 +154,12 @@ test_that('glmnet probabilities, mulitiple lambda', {
150154
)
151155

152156
})
157+
158+
test_that("predictions are factors with all levels", {
159+
basic <- multinom_reg() %>% set_engine("glmnet") %>% fit(Species ~ ., data = iris[rows, ])
160+
nd <- iris[iris$Species == "setosa", ]
161+
yhat <- predict(basic, new_data = nd, penalty = .1)
162+
yhat_multi <- multi_predict(basic, new_data = nd, penalty = .1)$.pred
163+
expect_is(yhat_multi[[1]]$.pred_class, "factor")
164+
expect_equal(levels(yhat_multi[[1]]$.pred_class), levels(iris$Species))
165+
})

0 commit comments

Comments
 (0)