Skip to content

Commit 3cb0914

Browse files
authored
Collect functions related to glmnet engines in one file (#898)
* move helpers to check the penalty * group predict vs multi_predict * move elnet helpers * move lognet helpers * move multnet helpers * rename file * update docs to reflect new location
1 parent 92b2bd5 commit 3cb0914

File tree

9 files changed

+431
-433
lines changed

9 files changed

+431
-433
lines changed

R/glmnet-engines.R

Lines changed: 425 additions & 0 deletions
Large diffs are not rendered by default.

R/glmnet.R

Lines changed: 0 additions & 172 deletions
This file was deleted.

R/linear_reg.R

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -117,53 +117,3 @@ check_args.linear_reg <- function(object) {
117117

118118
invisible(object)
119119
}
120-
121-
# ------------------------------------------------------------------------------
122-
123-
#' Organize glmnet predictions
124-
#'
125-
#' This function is for developer use and organizes predictions from glmnet
126-
#' models.
127-
#'
128-
#' @param x Predictions as returned by the `predict()` method for glmnet models.
129-
#' @param object An object of class `model_fit`.
130-
#'
131-
#' @rdname glmnet_helpers_prediction
132-
#' @keywords internal
133-
#' @export
134-
.organize_glmnet_pred <- function(x, object) {
135-
unname(x[, 1])
136-
}
137-
138-
#' @export
139-
predict._elnet <- predict_glmnet
140-
141-
#' @export
142-
predict_numeric._elnet <- predict_numeric_glmnet
143-
144-
#' @export
145-
predict_raw._elnet <- predict_raw_glmnet
146-
147-
#' @export
148-
#'@rdname multi_predict
149-
#' @param penalty A numeric vector of penalty values.
150-
multi_predict._elnet <- multi_predict_glmnet
151-
152-
format_glmnet_multi_linear_reg <- function(pred, penalty) {
153-
param_key <- tibble(group = colnames(pred), penalty = penalty)
154-
pred <- as_tibble(pred)
155-
pred$.row <- 1:nrow(pred)
156-
pred <- gather(pred, group, .pred, -.row)
157-
if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
158-
pred <- full_join(param_key, pred, by = "group", multiple = "all")
159-
} else {
160-
pred <- full_join(param_key, pred, by = "group")
161-
}
162-
pred$group <- NULL
163-
pred <- arrange(pred, .row, penalty)
164-
.row <- pred$.row
165-
pred$.row <- NULL
166-
pred <- split(pred, .row)
167-
names(pred) <- NULL
168-
tibble(.pred = pred)
169-
}

R/logistic_reg.R

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -164,71 +164,6 @@ prob_to_class_2 <- function(x, object) {
164164
unname(x)
165165
}
166166

167-
organize_glmnet_class <- function(x, object) {
168-
prob_to_class_2(x[, 1], object)
169-
}
170-
171-
organize_glmnet_prob <- function(x, object) {
172-
res <- tibble(v1 = 1 - x[, 1], v2 = x[, 1])
173-
colnames(res) <- object$lvl
174-
res
175-
}
176-
177-
# ------------------------------------------------------------------------------
178-
179-
#' @export
180-
predict._lognet <- predict_glmnet
181-
182-
#' @export
183-
#' @rdname multi_predict
184-
multi_predict._lognet <- multi_predict_glmnet
185-
186-
format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
187-
188-
type <- rlang::arg_match(type, c("class", "prob"))
189-
190-
penalty_key <- tibble(s = colnames(pred), penalty = penalty)
191-
192-
pred <- as_tibble(pred)
193-
pred$.row <- seq_len(nrow(pred))
194-
pred <- tidyr::pivot_longer(pred, -.row, names_to = "s", values_to = ".pred")
195-
196-
if (type == "class") {
197-
pred <- pred %>%
198-
dplyr::mutate(.pred_class = dplyr::if_else(.pred >= 0.5, lvl[2], lvl[1]),
199-
.pred_class = factor(.pred_class, levels = lvl),
200-
.keep = "unused")
201-
} else {
202-
pred <- pred %>%
203-
dplyr::mutate(.pred_class_2 = 1 - .pred) %>%
204-
rlang::set_names(c(".row", "s", paste0(".pred_", rev(lvl)))) %>%
205-
dplyr::select(c(".row", "s", paste0(".pred_", lvl)))
206-
}
207-
208-
if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
209-
pred <- dplyr::full_join(penalty_key, pred, by = "s", multiple = "all")
210-
} else {
211-
pred <- dplyr::full_join(penalty_key, pred, by = "s")
212-
}
213-
214-
pred <- pred %>%
215-
dplyr::select(-s) %>%
216-
dplyr::arrange(penalty) %>%
217-
tidyr::nest(.by = .row, .key = ".pred") %>%
218-
dplyr::select(-.row)
219-
220-
pred
221-
}
222-
223-
#' @export
224-
predict_class._lognet <- predict_class_glmnet
225-
226-
#' @export
227-
predict_classprob._lognet <- predict_classprob_glmnet
228-
229-
#' @export
230-
predict_raw._lognet <- predict_raw_glmnet
231-
232167
# ------------------------------------------------------------------------------
233168

234169
liblinear_preds <- function(results, object) {

R/misc.R

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -459,74 +459,6 @@ stan_conf_int <- function(object, newdata) {
459459

460460
# ------------------------------------------------------------------------------
461461

462-
463-
#' Helper functions for checking the penalty of glmnet models
464-
#'
465-
#' @description
466-
#' These functions are for developer use.
467-
#'
468-
#' `.check_glmnet_penalty_fit()` checks that the model specification for fitting a
469-
#' glmnet model contains a single value.
470-
#'
471-
#' `.check_glmnet_penalty_predict()` checks that the penalty value used for prediction is valid.
472-
#' If called by `predict()`, it needs to be a single value. Multiple values are
473-
#' allowed for `multi_predict()`.
474-
#'
475-
#' @param x An object of class `model_spec`.
476-
#' @rdname glmnet_helpers
477-
#' @keywords internal
478-
#' @export
479-
.check_glmnet_penalty_fit <- function(x) {
480-
pen <- rlang::eval_tidy(x$args$penalty)
481-
482-
if (length(pen) != 1) {
483-
rlang::abort(c(
484-
"For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).",
485-
glue::glue("There are {length(pen)} values for `penalty`."),
486-
"To try multiple values for total regularization, use the tune package.",
487-
"To predict multiple penalties, use `multi_predict()`"
488-
))
489-
}
490-
}
491-
492-
#' @param penalty A penalty value to check.
493-
#' @param object An object of class `model_fit`.
494-
#' @param multi A logical indicating if multiple values are allowed.
495-
#'
496-
#' @rdname glmnet_helpers
497-
#' @keywords internal
498-
#' @export
499-
.check_glmnet_penalty_predict <- function(penalty = NULL, object, multi = FALSE) {
500-
if (is.null(penalty)) {
501-
penalty <- object$fit$lambda
502-
}
503-
504-
# when using `predict()`, allow for a single lambda
505-
if (!multi) {
506-
if (length(penalty) != 1) {
507-
rlang::abort(
508-
glue::glue(
509-
"`penalty` should be a single numeric value. `multi_predict()` ",
510-
"can be used to get multiple predictions per row of data.",
511-
)
512-
)
513-
}
514-
}
515-
516-
if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda) {
517-
rlang::abort(
518-
glue::glue(
519-
"The glmnet model was fit with a single penalty value of ",
520-
"{object$fit$lambda}. Predicting with a value of {penalty} ",
521-
"will give incorrect results from `glmnet()`."
522-
)
523-
)
524-
}
525-
526-
penalty
527-
}
528-
529-
530462
check_case_weights <- function(x, spec) {
531463
if (is.null(x) | spec$engine == "spark") {
532464
return(invisible(NULL))

0 commit comments

Comments
 (0)