Skip to content

Commit 79f8c9f

Browse files
committed
Logistic regression test cases
1 parent 62075a8 commit 79f8c9f

File tree

2 files changed

+163
-6
lines changed

2 files changed

+163
-6
lines changed

R/logistic_reg.R

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
#' # Parameters can be represented by a placeholder:
5050
#' logistic_reg(regularization = varying())
5151
#' @export
52-
5352
#' @importFrom rlang expr enquo missing_arg
5453
#' @importFrom purrr map_lgl
5554
logistic_reg <-
@@ -235,19 +234,19 @@ finalize.logistic_reg <- function(x, engine = NULL, ...) {
235234
x <- check_engine(x)
236235

237236
# exceptions and error trapping here
238-
if(engine %in% c("glm", "stan_glm") & !null_value(x$args$regularization)) {
237+
if(x$engine %in% c("glm", "stan_glm") & !null_value(x$args$regularization)) {
239238
warning("The argument `regularization` cannot be used with this engine. ",
240239
"The value will be set to NULL")
241240
x$args$regularization <- quos(NULL)
242241
}
243-
if(engine %in% c("glm", "stan_glm") & !null_value(x$args$mixture)) {
242+
if(x$engine %in% c("glm", "stan_glm") & !null_value(x$args$mixture)) {
244243
warning("The argument `mixture` cannot be used with this engine. ",
245244
"The value will be set to NULL")
246245
x$args$mixture <- quos(NULL)
247246
}
248247

249248
x$method <- get_model_objects(x, x$engine)()
250-
if(!(engine %in% c("glm", "stan_glm"))) {
249+
if(!(x$engine %in% c("glm", "stan_glm"))) {
251250
real_args <- deharmonize(x$args, logistic_reg_arg_key, x$engine)
252251
# replace default args with user-specified
253252
x$method$fit <-
@@ -270,14 +269,14 @@ finalize.logistic_reg <- function(x, engine = NULL, ...) {
270269
x$method$fit <- sub_arg_values(x$method$fit, x$others, ignore = x$method$protect)
271270

272271
# remove NULL and unmodified argument values
273-
modifed_args <- if (!(engine %in% c("glm", "stan_glm")))
272+
modifed_args <- if (!(x$engine %in% c("glm", "stan_glm")))
274273
names(real_args)[!vapply(real_args, null_value, lgl(1))]
275274
else
276275
NULL
277276
modifed_args <- unique(c("family", modifed_args))
278277

279278
# glmnet can't handle NULL weights
280-
if (engine == "glmnet" & identical(x$method$fit$weights, quote(missing_arg())))
279+
if (x$engine == "glmnet" & identical(x$method$fit$weights, quote(missing_arg())))
281280
x$method$protect <- x$method$protect[x$method$protect != "weights"]
282281

283282
x$method$fit <- prune_expr(x$method$fit, x$method$protect, c(modifed_args, names(x$others)))

tests/testthat/test_logistic_reg.R

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
library(testthat)
2+
library(parsnip)
3+
4+
5+
test_that('primary arguments', {
6+
basic <- logistic_reg()
7+
basic_glm <- finalize(basic, engine = "glm")
8+
basic_glmnet <- finalize(basic, engine = "glmnet")
9+
basic_stan <- finalize(basic, engine = "stan_glm")
10+
expect_equal(basic_glm$method$fit,
11+
quote(
12+
glm(
13+
formula = formula,
14+
family = binomial(),
15+
data = data,
16+
weights = NULL
17+
)
18+
)
19+
)
20+
expect_equal(basic_glmnet$method$fit,
21+
quote(
22+
glmnet(
23+
x = as.matrix(x),
24+
y = y,
25+
family = "binomial"
26+
)
27+
)
28+
)
29+
expect_equal(basic_stan$method$fit,
30+
quote(
31+
stan_glm(
32+
formula = formula,
33+
family = binomial(),
34+
data = data,
35+
weights = NULL
36+
)
37+
)
38+
)
39+
40+
mixture <- logistic_reg(mixture = 0.128)
41+
mixture_glmnet <- finalize(mixture, engine = "glmnet")
42+
expect_equal(mixture_glmnet$method$fit,
43+
quote(
44+
glmnet(
45+
x = as.matrix(x),
46+
y = y,
47+
family = "binomial",
48+
alpha = 0.128
49+
)
50+
)
51+
)
52+
53+
regularization <- logistic_reg(regularization = 1)
54+
regularization_glmnet <- finalize(regularization, engine = "glmnet")
55+
expect_equal(regularization_glmnet$method$fit,
56+
quote(
57+
glmnet(
58+
x = as.matrix(x),
59+
y = y,
60+
family = "binomial",
61+
lambda = 1
62+
)
63+
)
64+
)
65+
66+
mixture_v <- logistic_reg(mixture = varying())
67+
mixture_v_glmnet <- finalize(mixture_v, engine = "glmnet")
68+
expect_equal(mixture_v_glmnet$method$fit,
69+
quote(
70+
glmnet(
71+
x = as.matrix(x),
72+
y = y,
73+
family = "binomial",
74+
alpha = varying()
75+
)
76+
)
77+
)
78+
})
79+
80+
test_that('engine arguments', {
81+
glm_fam <- logistic_reg(engine_args = list(family = binomial(link = "probit")))
82+
expect_equal(finalize(glm_fam, engine = "glm")$method$fit,
83+
quote(
84+
glm(
85+
formula = formula,
86+
family = binomial(link = "probit"),
87+
data = data,
88+
weights = NULL
89+
)
90+
)
91+
)
92+
93+
glmnet_nlam <- logistic_reg(engine_args = list(nlambda = 10))
94+
expect_equal(finalize(glmnet_nlam, engine = "glmnet")$method$fit,
95+
quote(
96+
glmnet(
97+
x = as.matrix(x),
98+
y = y,
99+
family = "binomial",
100+
nlambda = 10
101+
)
102+
)
103+
)
104+
105+
# these should get pass into the ... slot
106+
stan_samp <- logistic_reg(engine_args = list(chains = 1, iter = 5))
107+
expect_equal(finalize(stan_samp, engine = "stan_glm")$method$fit,
108+
quote(
109+
stan_glm(
110+
formula = formula,
111+
family = binomial(),
112+
data = data,
113+
weights = NULL,
114+
chains = 1,
115+
iter = 5
116+
)
117+
)
118+
)
119+
120+
})
121+
122+
123+
test_that('updating', {
124+
expr1 <- logistic_reg( engine_args = list(family = binomial(link = "probit")))
125+
expr1_exp <- logistic_reg(mixture = 0, engine_args = list(family = binomial(link = "probit")))
126+
127+
expr2 <- logistic_reg(mixture = varying())
128+
expr2_exp <- logistic_reg(mixture = varying(), engine_args = list(nlambda = 10))
129+
130+
expr3 <- logistic_reg(mixture = 0, regularization = varying())
131+
expr3_exp <- logistic_reg(mixture = 1)
132+
133+
expr4 <- logistic_reg(mixture = 0, engine_args = list(nlambda = 10))
134+
expr4_exp <- logistic_reg(mixture = 0, engine_args = list(nlambda = 10, pmax = 2))
135+
136+
expr5 <- logistic_reg(mixture = 1, engine_args = list(nlambda = 10))
137+
expr5_exp <- logistic_reg(mixture = 1, engine_args = list(nlambda = 10, pmax = 2))
138+
139+
expect_equal(update(expr1, mixture = 0), expr1_exp)
140+
expect_equal(update(expr2, engine_args = list(nlambda = 10)), expr2_exp)
141+
expect_equal(update(expr3, mixture = 1, fresh = TRUE), expr3_exp)
142+
expect_equal(update(expr4, engine_args = list(pmax = 2)), expr4_exp)
143+
expect_equal(update(expr5, engine_args = list(nlambda = 10, pmax = 2)), expr5_exp)
144+
145+
})
146+
147+
test_that('bad input', {
148+
expect_error(logistic_reg(ase.weights = var))
149+
expect_error(logistic_reg(mode = "regression"))
150+
expect_error(finalize(logistic_reg(), engine = "wat?"))
151+
expect_warning(finalize(logistic_reg(), engine = NULL))
152+
expect_error(finalize(logistic_reg(engine_args = list(ytest = 2)), engine = "glmnet"))
153+
expect_error(finalize(logistic_reg(formula = y ~ x)))
154+
expect_warning(finalize(logistic_reg(engine_args = list(x = x, y = y)), engine = "glmnet"))
155+
expect_warning(finalize(logistic_reg(engine_args = list(formula = y ~ x)), engine = "glm"))
156+
})
157+
158+

0 commit comments

Comments
 (0)