Reduce complexity of "post hook" functions for glmnet #876

Closed
@hfrick

The functions .organize_glmnet_pred(), organize_glmnet_class(), and organize_glmnet_prob() are used as the post hooks of the prediction modules for the glmnet engines of linear_reg() and logistic_reg().

For example, in line 246 of parsnip/R/linear_reg_data.R (lines 239 to 256 in a482442):

set_pred(
  model = "linear_reg",
  eng = "glmnet",
  mode = "regression",
  type = "numeric",
  value = list(
    pre = NULL,
    post = .organize_glmnet_pred,
    func = c(fun = "predict"),
    args =
      list(
        object = expr(object$fit),
        newx = expr(as.matrix(new_data[, rownames(object$fit$beta), drop = FALSE])),
        type = "response",
        s = expr(object$spec$args$penalty)
      )
  )
)

Their job is to reshape the predictions into a format that format_num(), format_class(), or format_classprobs() can work with; those formatting functions are called inside of predict.model_fit().

Since they are only ever called in this way, both here in parsnip and, for the exported .organize_glmnet_pred(), in censored and poissonreg, they will always deal with predictions for a single penalty value. They do contain code to handle predictions for multiple penalty values, but that code is not needed, so I suggest we remove it.
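As a rough sketch of what a simplified hook could look like: the function name below is hypothetical and this is not the current parsnip code. It assumes that, for a single penalty value, the engine's predict() hands the hook a one-column matrix (or a plain vector), and it uses a hand-made matrix as a stand-in for real glmnet output.

```r
# Hypothetical simplified post hook for numeric predictions.
# Assumption: `x` holds predictions for exactly one penalty value.
organize_single_penalty_pred <- function(x, object = NULL) {
  x <- as.matrix(x)
  if (ncol(x) != 1L) {
    stop("Expected predictions for a single penalty value.", call. = FALSE)
  }
  # Drop the matrix structure and any names so a bare numeric vector is returned.
  unname(x[, 1])
}

# Stand-in for a one-column prediction matrix (values are made up).
preds <- matrix(c(570.1, 163.2, 167.3), ncol = 1, dimnames = list(NULL, "s1"))
organize_single_penalty_pred(preds)
```

With the multi-penalty branch gone, the hook has no need to build a long tibble with a lambda column; it only has to return a vector that the formatting step can wrap.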

Why only ever a single penalty value? Because we check for that with .check_glmnet_penalty_predict(). Setting multi = TRUE gets around this check; that argument is meant to be used in combination with type = "raw" inside of multi_predict(). Type "raw" ensures that none of the post hook functions are called, so those predictions never go through the organize_glmnet*() functions.

Calling predict() directly with multi = TRUE also gets around this check, but that is off-label usage, and I think we would be okay not supporting it in exchange for simpler post hook functions.
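To make the role of the check concrete, here is a minimal stand-in. The name check_single_penalty() is hypothetical and the real .check_glmnet_penalty_predict() does more than this; the sketch only illustrates how a multi flag can bypass a length-one check.

```r
# Hypothetical, simplified stand-in for a single-penalty check.
# Not the parsnip implementation; it only looks at the length of `penalty`.
check_single_penalty <- function(penalty, multi = FALSE) {
  if (!multi && length(penalty) != 1L) {
    stop("`penalty` should be a single numeric value.", call. = FALSE)
  }
  invisible(penalty)
}

check_single_penalty(0.1)                # passes
check_single_penalty(1:2, multi = TRUE)  # passes: the flag bypasses the check
# check_single_penalty(1:2)              # would raise an error
```

If the multi-penalty code is removed from the post hooks, the bypass path would only remain meaningful for type = "raw", where no post hook is involved.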

library(parsnip)
data("hpc_data", package = "modeldata")
hpc <- hpc_data[1:150, c(2:5, 8)]

f_fit <- linear_reg(penalty = 0.1, mixture = 0.3) %>%
  set_engine("glmnet", nlambda = 15) %>%
  fit(input_fields ~ log(compounds) + class, data = hpc)

# regular usage errors informatively
predict(f_fit, hpc[1:3,], penalty = 1:2)
#> Error in `.check_glmnet_penalty_predict()`:
#> ! `penalty` should be a single numeric value. `multi_predict()` can be used to get multiple predictions per row of data.

# off-label usage
predict(f_fit, hpc[1:3,], penalty = 1:2, multi = TRUE)
#> # A tibble: 6 × 2
#>   .pred_values .pred_lambda
#>          <dbl>        <int>
#> 1         570.            1
#> 2         163.            1
#> 3         167.            1
#> 4         570.            2
#> 5         163.            2
#> 6         168.            2

Created on 2023-02-22 with reprex v2.0.2


Labels: feature (a feature request or enhancement)
