Skip to content

Collect functions related to glmnet engines in one file #898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
425 changes: 425 additions & 0 deletions R/glmnet-engines.R

Large diffs are not rendered by default.

172 changes: 0 additions & 172 deletions R/glmnet.R

This file was deleted.

50 changes: 0 additions & 50 deletions R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,53 +117,3 @@ check_args.linear_reg <- function(object) {

invisible(object)
}

# ------------------------------------------------------------------------------

#' Organize glmnet predictions
#'
#' Developer-facing helper that tidies the predictions produced by glmnet
#' models: the first column of `x` is extracted and its names are removed.
#'
#' @param x Predictions as returned by the `predict()` method for glmnet models.
#' @param object An object of class `model_fit`.
#' @return The first column of `x` as an unnamed vector.
#'
#' @rdname glmnet_helpers_prediction
#' @keywords internal
#' @export
.organize_glmnet_pred <- function(x, object) {
  # Only the first column is kept; `object` is not used in this body.
  first_column <- x[, 1]
  unname(first_column)
}

# Prediction methods for `_elnet` objects. Each one is a plain alias for a
# shared glmnet implementation (`predict_glmnet`, `predict_numeric_glmnet`,
# `predict_raw_glmnet`, `multi_predict_glmnet`) defined outside this section.

#' @export
predict._elnet <- predict_glmnet

#' @export
predict_numeric._elnet <- predict_numeric_glmnet

#' @export
predict_raw._elnet <- predict_raw_glmnet

#' @export
#' @rdname multi_predict
#' @param penalty A numeric vector of penalty values.
multi_predict._elnet <- multi_predict_glmnet

# Reshape glmnet predictions made at several penalty values into the nested
# layout returned for multi-predictions.
#
# Arguments:
#   pred     Predictions with one row per observation and one column per
#            penalty value; the column names are used as the join key.
#   penalty  Penalty values, one per column of `pred`, in column order.
#
# Returns a tibble with a single list-column `.pred`. Each element is a tibble
# with columns `penalty` and `.pred` for one observation, sorted by penalty.
format_glmnet_multi_linear_reg <- function(pred, penalty) {
  param_key <- tibble(group = colnames(pred), penalty = penalty)
  pred <- as_tibble(pred)
  # seq_len() instead of 1:nrow(): the latter gives c(1, 0) for zero rows.
  pred$.row <- seq_len(nrow(pred))
  # pivot_longer() replaces the superseded gather(); the column order
  # (.row, group, .pred) is the same.
  pred <- tidyr::pivot_longer(
    pred,
    -.row,
    names_to = "group",
    values_to = ".pred"
  )
  # `multiple` only exists in newer dplyr, where omitting it would warn about
  # the one-to-many match between penalties and observations.
  if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
    pred <- full_join(param_key, pred, by = "group", multiple = "all")
  } else {
    pred <- full_join(param_key, pred, by = "group")
  }
  pred$group <- NULL
  pred <- arrange(pred, .row, penalty)
  # Split into one tibble per observation, dropping the helper column first.
  .row <- pred$.row
  pred$.row <- NULL
  pred <- split(pred, .row)
  names(pred) <- NULL
  tibble(.pred = pred)
}
65 changes: 0 additions & 65 deletions R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -164,71 +164,6 @@ prob_to_class_2 <- function(x, object) {
unname(x)
}

# Convert glmnet predictions to classes: the first column of `x` is handed to
# `prob_to_class_2()` together with the fitted `object`.
organize_glmnet_class <- function(x, object) {
  first_column <- x[, 1]
  prob_to_class_2(first_column, object)
}

# Build a two-column tibble of class probabilities from glmnet predictions.
# The first column of `x` becomes the second output column and its complement
# (1 - value) the first; the columns are then named after `object$lvl`.
organize_glmnet_prob <- function(x, object) {
  second_prob <- x[, 1]
  first_prob <- 1 - second_prob
  res <- tibble(v1 = first_prob, v2 = second_prob)
  colnames(res) <- object$lvl
  res
}

# ------------------------------------------------------------------------------

# `predict()` and `multi_predict()` methods for `_lognet` objects; both are
# aliases for shared glmnet implementations defined outside this section.

#' @export
predict._lognet <- predict_glmnet

#' @export
#' @rdname multi_predict
multi_predict._lognet <- multi_predict_glmnet

# Reshape glmnet logistic-regression predictions made at several penalty
# values into one nested tibble per observation.
#
# Arguments:
#   pred     Predictions with one row per observation and one column per
#            penalty value; the column names are used as the join key.
#   penalty  Penalty values, one per column of `pred`, in column order.
#   type     Either "class" or "prob".
#   lvl      The two outcome levels; values >= 0.5 map to `lvl[2]`.
#
# Returns a tibble with a single list-column `.pred`. Each element holds
# `penalty` plus `.pred_class` (type "class") or one `.pred_<level>` column
# per level (type "prob").
format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
  type <- rlang::arg_match(type, c("class", "prob"))

  penalty_key <- tibble(s = colnames(pred), penalty = penalty)

  # Long format: one row per (observation, penalty column) pair.
  long <- as_tibble(pred)
  long$.row <- seq_len(nrow(long))
  long <- tidyr::pivot_longer(long, -.row, names_to = "s", values_to = ".pred")

  if (type == "prob") {
    # `.pred` belongs to the second level and its complement to the first;
    # rename accordingly, then put the columns back in level order.
    long <- dplyr::mutate(long, .pred_class_2 = 1 - .pred)
    long <- rlang::set_names(long, c(".row", "s", paste0(".pred_", rev(lvl))))
    long <- dplyr::select(long, c(".row", "s", paste0(".pred_", lvl)))
  } else {
    # Hard class prediction; `.keep = "unused"` drops the consumed `.pred`.
    long <- dplyr::mutate(
      long,
      .pred_class = dplyr::if_else(.pred >= 0.5, lvl[2], lvl[1]),
      .pred_class = factor(.pred_class, levels = lvl),
      .keep = "unused"
    )
  }

  # `multiple` only exists in newer dplyr versions.
  if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
    joined <- dplyr::full_join(penalty_key, long, by = "s", multiple = "all")
  } else {
    joined <- dplyr::full_join(penalty_key, long, by = "s")
  }

  joined <- dplyr::select(joined, -s)
  joined <- dplyr::arrange(joined, penalty)
  nested <- tidyr::nest(joined, .by = .row, .key = ".pred")
  dplyr::select(nested, -.row)
}

# Class, class-probability and raw prediction methods for `_lognet` objects;
# each is an alias for a shared glmnet implementation defined outside this
# section.

#' @export
predict_class._lognet <- predict_class_glmnet

#' @export
predict_classprob._lognet <- predict_classprob_glmnet

#' @export
predict_raw._lognet <- predict_raw_glmnet

# ------------------------------------------------------------------------------

liblinear_preds <- function(results, object) {
Expand Down
68 changes: 0 additions & 68 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -459,74 +459,6 @@ stan_conf_int <- function(object, newdata) {

# ------------------------------------------------------------------------------


#' Helper functions for checking the penalty of glmnet models
#'
#' @description
#' These functions are for developer use.
#'
#' `.check_glmnet_penalty_fit()` checks that the model specification for fitting a
#' glmnet model contains a single value for `penalty`.
#'
#' `.check_glmnet_penalty_predict()` checks that the penalty value used for prediction is valid.
#' If called by `predict()`, it needs to be a single value. Multiple values are
#' allowed for `multi_predict()`.
#'
#' @param x An object of class `model_spec`.
#' @rdname glmnet_helpers
#' @keywords internal
#' @export
.check_glmnet_penalty_fit <- function(x) {
  # Evaluate the `penalty` argument stored in the specification.
  pen <- rlang::eval_tidy(x$args$penalty)

  # Exactly one value is what fitting requires: nothing to report.
  if (length(pen) == 1) {
    return(invisible(NULL))
  }

  msg <- c(
    "For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).",
    glue::glue("There are {length(pen)} values for `penalty`."),
    "To try multiple values for total regularization, use the tune package.",
    "To predict multiple penalties, use `multi_predict()`"
  )
  rlang::abort(msg)
}

#' @param penalty A penalty value to check. If `NULL`, the penalty values
#'   stored in the fitted glmnet object (`object$fit$lambda`) are used.
#' @param object An object of class `model_fit`.
#' @param multi A logical indicating if multiple values are allowed.
#' @return `.check_glmnet_penalty_predict()` returns the validated penalty
#'   value(s) and raises an error if they are not valid.
#'
#' @rdname glmnet_helpers
#' @keywords internal
#' @export
.check_glmnet_penalty_predict <- function(penalty = NULL, object, multi = FALSE) {
  fit_lambda <- object$fit$lambda

  if (is.null(penalty)) {
    penalty <- fit_lambda
  }

  # when using `predict()`, allow for a single lambda
  if (!multi) {
    if (length(penalty) != 1) {
      rlang::abort(
        glue::glue(
          "`penalty` should be a single numeric value. `multi_predict()` ",
          "can be used to get multiple predictions per row of data."
        )
      )
    }
  }

  # `penalty` can hold several values when `multi = TRUE`. `any()` keeps the
  # right-hand side of `&&` scalar, which R >= 4.3 requires.
  if (length(fit_lambda) == 1 && any(penalty != fit_lambda)) {
    # Collapse so that several penalties still give a single message.
    penalty_txt <- paste(penalty, collapse = ", ")
    rlang::abort(
      glue::glue(
        "The glmnet model was fit with a single penalty value of ",
        "{fit_lambda}. Predicting with a value of {penalty_txt} ",
        "will give incorrect results from `glmnet()`."
      )
    )
  }

  penalty
}


check_case_weights <- function(x, spec) {
if (is.null(x) | spec$engine == "spark") {
return(invisible(NULL))
Expand Down
Loading