Skip to content

Commit e9abdf5

Browse files
topepoEmilHvitfeldtDavisVaughan
authored
glmnet autoplot method for #642 (#643)
* glmnet autoplot method * workflow methods * add pkgdown entry * makes labels appear on best_penalty line if present * Apply suggestions from code review Co-authored-by: Davis Vaughan <davis@rstudio.com> * move ggrepl to suggests * doc updates * anmespace functions and check for glmnet package * remove workflow method * move ... up in order * A note about ggrepl * more consistent test files * fix nocov tags Co-authored-by: Emil Hvitfeldt <emilhhvitfeldt@gmail.com> Co-authored-by: Davis Vaughan <davis@rstudio.com>
1 parent 45ec8e6 commit e9abdf5

17 files changed

+239
-65
lines changed

DESCRIPTION

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 0.1.7.9005
3+
Version: 0.1.7.9006
44
Authors@R: c(
55
person("Max", "Kuhn", , "max@rstudio.com", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "davis@rstudio.com", role = "aut"),
@@ -21,6 +21,7 @@ Imports:
2121
cli,
2222
dplyr (>= 0.8.0.1),
2323
generics (>= 0.1.0.9000),
24+
ggplot2,
2425
globals,
2526
glue,
2627
hardhat (>= 0.1.6.9001),
@@ -41,7 +42,7 @@ Suggests:
4142
dials (>= 0.0.10.9001),
4243
earth,
4344
tensorflow,
44-
ggplot2,
45+
ggrepel,
4546
keras,
4647
kernlab,
4748
kknn,

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ export(C5.0_train)
165165
export(C5_rules)
166166
export(add_rowindex)
167167
export(augment)
168+
export(autoplot)
168169
export(bag_mars)
169170
export(bag_tree)
170171
export(bart)
@@ -314,13 +315,15 @@ importFrom(generics,glance)
314315
importFrom(generics,required_pkgs)
315316
importFrom(generics,tidy)
316317
importFrom(generics,varying_args)
318+
importFrom(ggplot2,autoplot)
317319
importFrom(glue,glue_collapse)
318320
importFrom(hardhat,extract_fit_engine)
319321
importFrom(hardhat,extract_parameter_dials)
320322
importFrom(hardhat,extract_parameter_set_dials)
321323
importFrom(hardhat,extract_spec_parsnip)
322324
importFrom(hardhat,tune)
323325
importFrom(magrittr,"%>%")
326+
importFrom(purrr,"%||%")
324327
importFrom(purrr,as_vector)
325328
importFrom(purrr,imap)
326329
importFrom(purrr,imap_lgl)

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,13 @@
4040

4141
* `varying_args()` is soft-deprecated in favor of `tune_args()`.
4242

43+
* An `autoplot()` method was added for glmnet objects, showing the coefficient paths versus the penalty values (#642).
44+
4345
* parsnip is now more robust working with keras and tensorflow for a larger range of versions (#596).
4446

4547
* xgboost engines now use the new `iterationrange` parameter instead of the deprecated `ntreelimit` (#656).
4648

49+
4750
# parsnip 0.1.7
4851

4952
## Model Specification Changes

R/0_imports.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#' @importFrom generics varying_args
44
#' @importFrom glue glue_collapse
55
#' @importFrom purrr as_vector imap imap_lgl map map_chr map_dbl map_df map_dfr
6-
#' @importFrom purrr map_lgl
6+
#' @importFrom purrr map_lgl %||%
77
#' @importFrom rlang abort call2 caller_env current_env enquo enquos eval_tidy
88
#' @importFrom rlang expr get_expr is_empty is_missing is_null is_quosure
99
#' @importFrom rlang is_symbolic lgl missing_arg quo_get_expr quos sym syms
@@ -16,4 +16,5 @@
1616
#' @importFrom utils capture.output getFromNamespace globalVariables head
1717
#' @importFrom utils methods stack
1818
#' @importFrom vctrs vec_size vec_unique
19+
#' @importFrom ggplot2 autoplot
1920
NULL

R/aaa.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ function(results, object) {
8282
}
8383

8484
# ------------------------------------------------------------------------------
85-
# nocov
85+
# nocov start
8686

8787
utils::globalVariables(
8888
c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group',
@@ -91,7 +91,9 @@ utils::globalVariables(
9191
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
9292
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
9393
"compute_intercept", "remove_intercept", "estimate", "term",
94-
"call_info", "component", "component_id", "func", "pkg", ".order", "item", "tunable")
94+
"call_info", "component", "component_id", "func", "tunable", "label",
95+
"pkg", ".order", "item", "tunable"
96+
)
9597
)
9698

9799
# nocov end

R/autoplot.R

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#' Create a ggplot for a model object
2+
#'
3+
#' This method provides a good visualization method for model results.
4+
#' Currently, only methods for glmnet models are implemented.
5+
#'
6+
#' @param object A model fit object.
7+
#' @param min_penalty A single, non-negative number for the smallest penalty
8+
#' value that should be shown in the plot. If left `NULL`, the whole data
9+
#' range is used.
10+
#' @param best_penalty A single, non-negative number that will show a vertical
11+
#' line marker. If left `NULL`, no line is shown. When this argument is used,
12+
#' the \pkg{ggrepl} package is required.
13+
#' @param top_n A non-negative integer for how many model predictors to label.
14+
#' The top predictors are ranked by their absolute coefficient value. For
15+
#' multinomial or multivariate models, the `top_n` terms are selected within
16+
#' class or response, respectively.
17+
#' @param ... For [autoplot.glmnet()], options to pass to
18+
#' [ggrepel::geom_label_repel()]. Otherwise, this argument is ignored.
19+
#' @return A ggplot object with penalty on the x-axis and coefficients on the
20+
#' y-axis. For multinomial or multivariate models, the plot is faceted.
21+
#' @details The \pkg{glmnet} package will need to be attached or loaded for
22+
#' its `autoplot()` method to work correctly.
23+
#'
24+
# registered in zzz.R
25+
autoplot.model_fit <- function(object, ...) {
26+
autoplot(object$fit, ...)
27+
}
28+
29+
# glmnet is not a formal dependency here.
30+
# unit tests are located at https://github.com/tidymodels/extratests
31+
# nocov start
32+
33+
# registered in zzz.R
34+
#' @rdname autoplot.model_fit
35+
autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL,
36+
top_n = 3L) {
37+
autoplot_glmnet(object, min_penalty, best_penalty, top_n, ...)
38+
}
39+
40+
41+
map_glmnet_coefs <- function(x) {
42+
coefs <- coef(x)
43+
# If parsnip is used to fit the model, glmnet should be attached and this will
44+
# work. If an object is loaded from a new session, they will need to load the
45+
# package.
46+
if (is.null(coefs)) {
47+
rlang::abort("Please load the glmnet package before running `autoplot()`.")
48+
}
49+
p <- x$dim[1]
50+
if (is.list(coefs)) {
51+
classes <- names(coefs)
52+
coefs <- purrr::map(coefs, reformat_coefs, p = p, penalty = x$lambda)
53+
coefs <- purrr::map2_dfr(coefs, classes, ~ dplyr::mutate(.x, class = .y))
54+
} else {
55+
coefs <- reformat_coefs(coefs, p = p, penalty = x$lambda)
56+
}
57+
coefs
58+
}
59+
60+
reformat_coefs <- function(x, p, penalty) {
61+
x <- as.matrix(x)
62+
num_estimates <- nrow(x)
63+
if (num_estimates > p) {
64+
# The intercept is first
65+
x <- x[-(num_estimates - p),, drop = FALSE]
66+
}
67+
term_lab <- rownames(x)
68+
colnames(x) <- paste(seq_along(penalty))
69+
x <- tibble::as_tibble(x)
70+
x$term <- term_lab
71+
x <- tidyr::pivot_longer(x, cols = -term, names_to = "index", values_to = "estimate")
72+
x$penalty <- rep(penalty, p)
73+
x$index <- NULL
74+
x
75+
}
76+
77+
top_coefs <- function(x, top_n = 5) {
78+
x %>%
79+
dplyr::group_by(term) %>%
80+
dplyr::arrange(term, dplyr::desc(abs(estimate))) %>%
81+
dplyr::slice(1) %>%
82+
dplyr::ungroup() %>%
83+
dplyr::arrange(dplyr::desc(abs(estimate))) %>%
84+
dplyr::slice(1:top_n)
85+
}
86+
87+
autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) {
88+
check_penalty_value(min_penalty)
89+
90+
tidy_coefs <-
91+
map_glmnet_coefs(x) %>%
92+
dplyr::filter(penalty >= min_penalty)
93+
94+
actual_min_penalty <- min(tidy_coefs$penalty)
95+
num_terms <- length(unique(tidy_coefs$term))
96+
top_n <- min(top_n[1], num_terms)
97+
if (top_n < 0) {
98+
top_n <- 0
99+
}
100+
101+
has_groups <- any(names(tidy_coefs) == "class")
102+
103+
# Keep the large values
104+
if (has_groups) {
105+
label_coefs <-
106+
tidy_coefs %>%
107+
dplyr::group_nest(class) %>%
108+
dplyr::mutate(data = purrr::map(data, top_coefs, top_n = top_n)) %>%
109+
dplyr::select(class, data) %>%
110+
tidyr::unnest(cols = data)
111+
} else {
112+
if (is.null(best_penalty)) {
113+
label_coefs <- tidy_coefs %>%
114+
top_coefs(top_n)
115+
} else {
116+
label_coefs <- tidy_coefs %>%
117+
dplyr::filter(penalty > best_penalty) %>%
118+
dplyr::filter(penalty == min(penalty)) %>%
119+
dplyr::arrange(dplyr::desc(abs(estimate))) %>%
120+
dplyr::slice(seq_len(top_n))
121+
}
122+
}
123+
124+
label_coefs <-
125+
label_coefs %>%
126+
dplyr::mutate(penalty = best_penalty %||% actual_min_penalty) %>%
127+
dplyr::mutate(label = gsub(".pred_no_", "", term))
128+
129+
# plot the paths and highlight the large values
130+
p <-
131+
tidy_coefs %>%
132+
ggplot2::ggplot(ggplot2::aes(x = penalty, y = estimate, group = term, col = term))
133+
134+
if (has_groups) {
135+
p <- p + ggplot2::facet_wrap(~ class)
136+
}
137+
138+
if (!is.null(best_penalty)) {
139+
check_penalty_value(best_penalty)
140+
p <- p + ggplot2::geom_vline(xintercept = best_penalty, lty = 3)
141+
}
142+
143+
p <- p +
144+
ggplot2::geom_line(alpha = .4, show.legend = FALSE) +
145+
ggplot2::scale_x_log10()
146+
147+
if(top_n > 0) {
148+
rlang::check_installed("ggrepel")
149+
p <- p +
150+
ggrepel::geom_label_repel(
151+
data = label_coefs,
152+
ggplot2::aes(y = estimate, label = label),
153+
show.legend = FALSE,
154+
...
155+
)
156+
}
157+
p
158+
}
159+
160+
check_penalty_value <- function(x) {
161+
cl <- match.call()
162+
arg_val <- as.character(cl$x)
163+
if (!is.vector(x) || length(x) != 1 || !is.numeric(x) || x < 0) {
164+
msg <- paste0("Argument '", arg_val, "' should be a single, non-negative value.")
165+
rlang::abort(msg)
166+
}
167+
invisible(x)
168+
}
169+
170+
# nocov end

R/reexports.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#' @importFrom ggplot2 autoplot
2+
#' @export
3+
ggplot2::autoplot
14

25
#' @importFrom magrittr %>%
36
#' @export

R/zzz.R

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
s3_register("generics::required_pkgs", "model_fit")
1313
s3_register("generics::required_pkgs", "model_spec")
1414

15+
s3_register("ggplot2::autoplot", "model_fit")
16+
s3_register("ggplot2::autoplot", "glmnet")
17+
1518
# - If tune isn't installed, register the method (`packageVersion()` will error here)
1619
# - If tune >= 0.1.6.9001 is installed, register the method
1720
should_register_tune_args_method <- tryCatch(
@@ -108,61 +111,3 @@ s3_register <- function(generic, class, method = NULL) {
108111

109112
# nocov end
110113

111-
112-
113-
#' ## nocov start
114-
#'
115-
#' data_obj <- ls(pattern = "_data$")
116-
#' data_obj <- data_obj[data_obj != "prepare_data"]
117-
#'
118-
#' data_names <-
119-
#' map_dfr(
120-
#' data_obj,
121-
#' function(x) {
122-
#' module <- names(get(x))
123-
#' if (length(module) > 1) {
124-
#' module <- table(module)
125-
#' module <- as_tibble(module)
126-
#' module$object <- x
127-
#' module
128-
#' } else
129-
#' module <- NULL
130-
#' module
131-
#' }
132-
#' )
133-
#'
134-
#' if(any(data_names$n > 1)) {
135-
#' print(data_names[data_names$n > 1,])
136-
#' rlang::abort("Some models have duplicate module names.")
137-
#' }
138-
#' rm(data_names)
139-
#'
140-
#' # ------------------------------------------------------------------------------
141-
#'
142-
#' engine_objects <- ls(pattern = "_engines$")
143-
#' engine_objects <- engine_objects[engine_objects != "possible_engines"]
144-
#'
145-
#' get_engine_info <- function(x) {
146-
#' y <- x
147-
#' y <- get(y)
148-
#' z <- stack(y)
149-
#' z$mode <- rownames(y)
150-
#' z$model <- gsub("_engines$", "", x)
151-
#' z$object <- x
152-
#' z <- z[z$values,]
153-
#' z <- z[z$mode != "unknown",]
154-
#' z$values <- NULL
155-
#' names(z)[1] <- "engine"
156-
#' z$engine <- as.character(z$engine)
157-
#' z
158-
#' }
159-
#'
160-
#' engine_info <-
161-
#' purrr::map_df(
162-
#' parsnip:::engine_objects,
163-
#' get_engine_info
164-
#' )
165-
#'
166-
#' rm(engine_objects)
167-
#'
168-
#' ## nocov end

_pkgdown.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ reference:
4848
- svm_rbf
4949
- title: Infrastructure
5050
contents:
51+
- autoplot.model_fit
5152
- add_rowindex
5253
- augment.model_fit
5354
- descriptors

man/autoplot.model_fit.Rd

Lines changed: 42 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)