The functions `.organize_glmnet_pred()`, `organize_glmnet_class()`, and `organize_glmnet_prob()` are used in the post hook of the prediction module for the glmnet engines of `linear_reg()` and `logistic_reg()`.
For example, see line 246 (within lines 239 to 256 at commit a482442).
Their job is to reformat the predictions into a shape that `format_num()`, `format_class()`, or `format_classprobs()` can work with; those formatters are then called inside `predict.model_fit()`.
Because these functions are only ever called this way (here in parsnip, and through the exported `.organize_glmnet_pred()` in censored and poissonreg), they always deal with predictions for a single penalty value. They do contain code to handle predictions for multiple penalty values, but we don't need it, so I suggest we remove it.
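To make the proposal concrete, here is a minimal sketch of what single-penalty versions of the three helpers could look like. The names `organize_pred_single()`, `organize_class_single()`, and `organize_prob_single()` and the `lvl` argument are hypothetical, not parsnip's API; the sketch assumes `x` is the one-column matrix that glmnet's `predict()` returns for a single value of `s`, and that for a binomial fit the probabilities refer to the second outcome level.

```r
# Hypothetical single-penalty versions of the glmnet post hook helpers.
# Assumption: `x` is a one-column matrix, as glmnet's predict() returns
# when called with a single value of `s`.

organize_pred_single <- function(x) {
  stopifnot(is.matrix(x), ncol(x) == 1)
  # numeric predictions: drop the matrix structure and any names
  unname(x[, 1])
}

organize_class_single <- function(x, lvl) {
  stopifnot(is.matrix(x), ncol(x) == 1)
  # class predictions: a factor carrying the outcome levels
  factor(unname(x[, 1]), levels = lvl)
}

organize_prob_single <- function(x, lvl) {
  stopifnot(is.matrix(x), ncol(x) == 1, length(lvl) == 2)
  # assumption: for a binomial fit, the column holds P(second level)
  p <- unname(x[, 1])
  res <- data.frame(1 - p, p)
  names(res) <- lvl
  res
}

# usage with a toy matrix standing in for glmnet output
pred_mat <- matrix(c(570, 163, 167), ncol = 1, dimnames = list(NULL, "s1"))
organize_pred_single(pred_mat)
```

With no multi-penalty branch, each helper reduces to a column extraction plus a type conversion.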
Why only ever a single penalty value? Because we check for that with `.check_glmnet_penalty_predict()`. It is possible to get around this check by setting `multi = TRUE`, which `multi_predict()` uses in combination with `type = "raw"`. Type `"raw"` ensures that none of the post hook functions are called, so those predictions never go through the `organize_glmnet*()` functions.
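The control flow described above can be sketched in base R. `check_penalty_sketch()` and `predict_sketch()` are hypothetical stand-ins, not the real `.check_glmnet_penalty_predict()` or `predict.model_fit()`; they only illustrate that the single-value check is skipped when `multi = TRUE`, and that `type = "raw"` returns engine output without running a post hook.

```r
# Hypothetical stand-in for the penalty check: a single value is required
# unless `multi = TRUE`, which is the route multi_predict() takes.
check_penalty_sketch <- function(penalty, multi = FALSE) {
  if (!multi && length(penalty) != 1) {
    stop("`penalty` should be a single numeric value.", call. = FALSE)
  }
  penalty
}

# Hypothetical stand-in for prediction dispatch: with type "raw" the engine
# output is returned untouched, so the post hook never runs.
predict_sketch <- function(raw, type, post_hook) {
  if (identical(type, "raw")) {
    return(raw)
  }
  post_hook(raw)
}

check_penalty_sketch(0.1)                # passes
check_penalty_sketch(1:2, multi = TRUE)  # passes: check bypassed
```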
Calling `predict()` directly with `multi = TRUE` does get around this check, as shown below, but that is off-label usage, and I think we would be okay not supporting it in exchange for simpler post hook functions.
``` r
library(parsnip)
data("hpc_data", package = "modeldata")
hpc <- hpc_data[1:150, c(2:5, 8)]

f_fit <- linear_reg(penalty = 0.1, mixture = 0.3) %>%
  set_engine("glmnet", nlambda = 15) %>%
  fit(input_fields ~ log(compounds) + class, data = hpc)

# regular usage errors informatively
predict(f_fit, hpc[1:3,], penalty = 1:2)
#> Error in `.check_glmnet_penalty_predict()`:
#> ! `penalty` should be a single numeric value. `multi_predict()` can be used to get multiple predictions per row of data.

# off-label usage
predict(f_fit, hpc[1:3,], penalty = 1:2, multi = TRUE)
#> # A tibble: 6 × 2
#>   .pred_values .pred_lambda
#>          <dbl>        <int>
#> 1         570.            1
#> 2         163.            1
#> 3         167.            1
#> 4         570.            2
#> 5         163.            2
#> 6         168.            2
```
Created on 2023-02-22 with reprex v2.0.2