Skip to content

Commit d8f273c

Browse files
authored
glmnet multi_predict(): Check type for all modes (#900)
* check type for all modes * bump version in anticipation of #897 being merged first * update NEWS * update to merge PR
1 parent 22c87a8 commit d8f273c

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.0.4.9002
3+
Version: 1.0.4.9003
44
Authors@R: c(
55
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
* Several internal functions (to help work with `Surv` objects) were added as a standalone file that can be used in other packages via `usethis::use_standalone("tidymodels/parsnip")`.
1818

19+
* `multi_predict()` methods for `linear_reg()`, `logistic_reg()`, and `multinomial_reg()` models fitted with the `"glmnet"` engine now check the `type` better and error accordingly (#900).
20+
1921
* Rather than being implemented in each method, the check for the `new_data` argument being mistakenly passed as `newdata` to `multi_predict()` now happens in the generic. Packages re-exporting the `multi_predict()` generic and implementing now-duplicate checks may see new failures and can remove their own analogous checks. This check already existed in all `predict()` methods (via `predict.model_fit()`) and all parsnip `multi_predict()` methods (#525).
2022

2123
* `logistic_reg()` will now warn at `fit()` when the outcome has more than two levels (#545).

R/glmnet-engines.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,20 @@ multi_predict_glmnet <- function(object,
173173
type = NULL,
174174
penalty = NULL,
175175
...) {
176+
type <- check_pred_type(object, type)
177+
check_spec_pred_type(object, type)
178+
if (type == "prob") {
179+
check_spec_levels(object)
180+
}
181+
182+
dots <- list(...)
183+
176184
if (object$spec$mode == "classification") {
177185
if (is_quosure(penalty)) {
178186
penalty <- eval_tidy(penalty)
179187
}
180188
}
181189

182-
dots <- list(...)
183-
184190
object$spec <- eval_args(object$spec)
185191

186192
if (is.null(penalty)) {
@@ -195,12 +201,6 @@ multi_predict_glmnet <- function(object,
195201
model_type <- class(object$spec)[1]
196202

197203
if (object$spec$mode == "classification") {
198-
if (is.null(type)) {
199-
type <- "class"
200-
}
201-
if (!(type %in% c("class", "prob", "link", "raw"))) {
202-
rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.")
203-
}
204204
if (type == "prob" |
205205
model_type == "logistic_reg") {
206206
dots$type <- "response"

0 commit comments

Comments
 (0)