Skip to content

Commit bc125e9

Browse files
authored
Merge pull request #485 from tidymodels/glmnet-fail-penalty
Add error for glmnet models if penalty is not exactly 1
2 parents 01168ca + 2d9c1d3 commit bc125e9

17 files changed

+70
-118
lines changed

R/engines.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ load_libs <- function(x, quiet, attach = FALSE) {
8282
#' @examples
8383
#' # First, set general arguments using the standardized names
8484
#' mod <-
85-
#' logistic_reg(mixture = 1/3) %>%
85+
#' logistic_reg(penalty = 0.01, mixture = 1/3) %>%
8686
#' # now say how you want to fit the model and another other options
8787
#' set_engine("glmnet", nlambda = 10)
8888
#' translate(mod, engine = "glmnet")

R/linear_reg.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
112112
# Since the `fit` information is gone for the penalty, we need to have an
113113
# evaluated value for the parameter.
114114
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
115+
check_glmnet_penalty(x)
115116
}
116117

117118
x

R/logistic_reg.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) {
115115
# Since the `fit` information is gone for the penalty, we need to have an
116116
# evaluated value for the parameter.
117117
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
118+
check_glmnet_penalty(x)
118119
}
119120

120121
if (engine == "LiblineaR") {

R/misc.R

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,4 +323,13 @@ stan_conf_int <- function(object, newdata) {
323323
rlang::eval_tidy(fn)
324324
}
325325

326-
326+
check_glmnet_penalty <- function(x) {
327+
if (length(x$args$penalty) != 1) {
328+
rlang::abort(c(
329+
"For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).",
330+
glue::glue("There are {length(x$args$penalty)} values for `penalty`."),
331+
"To try multiple values for total regularization, use the tune package.",
332+
"To predict multiple penalties, use `multi_predict()`"
333+
))
334+
}
335+
}

R/translate.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
#' translate(lm_spec, engine = "spark")
3939
#'
4040
#' # with a placeholder for an unknown argument value:
41-
#' translate(linear_reg(mixture = varying()), engine = "glmnet")
41+
#' translate(linear_reg(penalty = varying(), mixture = varying()), engine = "glmnet")
4242
#'
4343
#' @export
4444

man/contr_one_hot.Rd

Lines changed: 6 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/linear_reg.Rd

Lines changed: 4 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/logistic_reg.Rd

Lines changed: 4 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/multinom_reg.Rd

Lines changed: 4 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/linear-reg.Rmd

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,14 @@ Engines may have pre-set default arguments when executing the model fit call. Fo
1010
```{r lm-reg}
1111
linear_reg() %>%
1212
set_engine("lm") %>%
13-
set_mode("regression") %>%
1413
translate()
1514
```
1615

1716
## glmnet
1817

1918
```{r glmnet-csl}
20-
linear_reg() %>%
19+
linear_reg(penalty = 0.1) %>%
2120
set_engine("glmnet") %>%
22-
set_mode("regression") %>%
2321
translate()
2422
```
2523

@@ -37,7 +35,6 @@ penalty results.
3735
```{r stan-reg}
3836
linear_reg() %>%
3937
set_engine("stan") %>%
40-
set_mode("regression") %>%
4138
translate()
4239
```
4340

@@ -55,7 +52,6 @@ returned.
5552
```{r spark-reg}
5653
linear_reg() %>%
5754
set_engine("spark") %>%
58-
set_mode("regression") %>%
5955
translate()
6056
```
6157

@@ -64,7 +60,6 @@ linear_reg() %>%
6460
```{r keras-reg}
6561
linear_reg() %>%
6662
set_engine("keras") %>%
67-
set_mode("regression") %>%
6863
translate()
6964
```
7065

man/rmd/logistic-reg.Rmd

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,14 @@ For this type of model, the template of the fit calls are below.
1111
```{r glm-reg}
1212
logistic_reg() %>%
1313
set_engine("glm") %>%
14-
set_mode("classification") %>%
1514
translate()
1615
```
1716

1817
## glmnet
1918

2019
```{r glmnet-csl}
21-
logistic_reg() %>%
20+
logistic_reg(penalty = 0.1) %>%
2221
set_engine("glmnet") %>%
23-
set_mode("classification") %>%
2422
translate()
2523
```
2624

@@ -38,7 +36,6 @@ penalty results.
3836
```{r liblinear-reg}
3937
logistic_reg() %>%
4038
set_engine("LiblineaR") %>%
41-
set_mode("classification") %>%
4239
translate()
4340
```
4441

@@ -54,7 +51,6 @@ regularized regression models do not, which will result in different parameter e
5451
```{r stan-reg}
5552
logistic_reg() %>%
5653
set_engine("stan") %>%
57-
set_mode("classification") %>%
5854
translate()
5955
```
6056

@@ -72,7 +68,6 @@ returned.
7268
```{r spark-reg}
7369
logistic_reg() %>%
7470
set_engine("spark") %>%
75-
set_mode("classification") %>%
7671
translate()
7772
```
7873

@@ -81,7 +76,6 @@ logistic_reg() %>%
8176
```{r keras-reg}
8277
logistic_reg() %>%
8378
set_engine("keras") %>%
84-
set_mode("classification") %>%
8579
translate()
8680
```
8781

man/rmd/multinom-reg.Rmd

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@ For this type of model, the template of the fit calls are below.
99
## glmnet
1010

1111
```{r glmnet-cls}
12-
multinom_reg() %>%
12+
multinom_reg(penalty = 0.1) %>%
1313
set_engine("glmnet") %>%
14-
set_mode("classification") %>%
1514
translate()
1615
```
1716

@@ -29,7 +28,6 @@ penalty results.
2928
```{r nnet-cls}
3029
multinom_reg() %>%
3130
set_engine("nnet") %>%
32-
set_mode("classification") %>%
3331
translate()
3432
```
3533

@@ -38,7 +36,6 @@ multinom_reg() %>%
3836
```{r spark-cls}
3937
multinom_reg() %>%
4038
set_engine("spark") %>%
41-
set_mode("classification") %>%
4239
translate()
4340
```
4441

@@ -47,7 +44,6 @@ multinom_reg() %>%
4744
```{r keras-cls}
4845
multinom_reg() %>%
4946
set_engine("keras") %>%
50-
set_mode("classification") %>%
5147
translate()
5248
```
5349

man/set_engine.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/translate.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)