Skip to content

Commit fd5124c

Browse files
Merge pull request #853 from tidymodels/extract_fit_time
add extract_fit_time
2 parents eb526fa + 6d76a59 commit fd5124c

File tree

11 files changed

+96
-38
lines changed

11 files changed

+96
-38
lines changed

DESCRIPTION

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ Imports:
2525
ggplot2,
2626
globals,
2727
glue,
28-
hardhat (>= 1.1.0),
28+
hardhat (>= 1.3.1.9000),
2929
lifecycle,
3030
magrittr,
3131
pillar,
@@ -77,4 +77,6 @@ Config/testthat/edition: 3
7777
Encoding: UTF-8
7878
LazyData: true
7979
Roxygen: list(markdown = TRUE)
80+
Remotes:
81+
tidymodels/hardhat
8082
RoxygenNote: 7.3.1

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ S3method(check_args,svm_linear)
2929
S3method(check_args,svm_poly)
3030
S3method(check_args,svm_rbf)
3131
S3method(extract_fit_engine,model_fit)
32+
S3method(extract_fit_time,model_fit)
3233
S3method(extract_parameter_dials,model_spec)
3334
S3method(extract_parameter_set_dials,model_spec)
3435
S3method(extract_spec_parsnip,model_fit)
@@ -222,6 +223,7 @@ export(discrim_quad)
222223
export(discrim_regularized)
223224
export(eval_args)
224225
export(extract_fit_engine)
226+
export(extract_fit_time)
225227
export(extract_parameter_dials)
226228
export(extract_parameter_set_dials)
227229
export(extract_spec_parsnip)
@@ -376,6 +378,7 @@ importFrom(generics,varying_args)
376378
importFrom(ggplot2,autoplot)
377379
importFrom(glue,glue_collapse)
378380
importFrom(hardhat,extract_fit_engine)
381+
importFrom(hardhat,extract_fit_time)
379382
importFrom(hardhat,extract_parameter_dials)
380383
importFrom(hardhat,extract_parameter_set_dials)
381384
importFrom(hardhat,extract_spec_parsnip)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).
4+
35
# parsnip 1.2.1
46

57
* Added a missing `tidy()` method for survival analysis glmnet models (#1086).

R/extract.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,15 @@
1414
#'
1515
#' - `extract_parameter_set_dials()` returns a set of dials parameter objects.
1616
#'
17+
#' - `extract_fit_time()` returns a tibble with fit times. The fit times
18+
#' correspond to the time for the parsnip engine to fit and do not include
19+
#' other portions of the elapsed time in [parsnip::fit.model_spec()].
20+
#'
1721
#' @param x A parsnip `model_fit` object or a parsnip `model_spec` object.
1822
#' @param parameter A single string for the parameter ID.
23+
#' @param summarize A logical for whether the elapsed fit time should be
24+
#' returned as a single row or multiple rows. Doesn't support `FALSE` for
25+
#' parsnip models.
1926
#' @param ... Not currently used.
2027
#' @details
2128
#' Extracting the underlying engine fit can be helpful for describing the
@@ -127,3 +134,20 @@ eval_call_info <- function(x) {
127134
extract_parameter_dials.model_spec <- function(x, parameter, ...) {
128135
extract_parameter_dials(extract_parameter_set_dials(x), parameter)
129136
}
137+
138+
#' @export
139+
#' @rdname extract-parsnip
140+
extract_fit_time.model_fit <- function(x, summarize = TRUE, ...) {
141+
elapsed <- x[["elapsed"]][["elapsed"]][["elapsed"]]
142+
143+
if (is.na(elapsed) || is.null(elapsed)) {
144+
rlang::abort(
145+
"This model was fit before `extract_fit_time()` was added."
146+
)
147+
}
148+
149+
dplyr::tibble(
150+
stage_id = class(x$spec)[1],
151+
elapsed = elapsed
152+
)
153+
}

R/fit.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,8 +453,15 @@ allow_sparse <- function(x) {
453453
#' @export
454454
print.model_fit <- function(x, ...) {
455455
cat("parsnip model object\n\n")
456-
if (!is.na(x$elapsed[["elapsed"]])) {
457-
cat("Fit time: ", prettyunits::pretty_sec(x$elapsed[["elapsed"]]), "\n")
456+
457+
if (is.null(x$elapsed$print) && !is.na(x$elapsed[["elapsed"]])) {
458+
elapsed <- x$elapsed[["elapsed"]]
459+
cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n")
460+
}
461+
462+
if (isTRUE(x$elapsed$print)) {
463+
elapsed <- x$elapsed$elapsed[["elapsed"]]
464+
cat("Fit time: ", prettyunits::pretty_sec(elapsed), "\n")
458465
}
459466

460467
if (inherits(x$fit, "try-error")) {

R/fit_helpers.R

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -44,29 +44,19 @@ form_form <-
4444
spec = object
4545
)
4646

47-
if (control$verbosity > 1L) {
48-
elapsed <- system.time(
49-
res$fit <- eval_mod(
50-
fit_call,
51-
capture = control$verbosity == 0,
52-
catch = control$catch,
53-
envir = env,
54-
...
55-
),
56-
gcFirst = FALSE
57-
)
58-
} else {
47+
elapsed <- system.time(
5948
res$fit <- eval_mod(
6049
fit_call,
6150
capture = control$verbosity == 0,
6251
catch = control$catch,
6352
envir = env,
6453
...
65-
)
66-
elapsed <- list(elapsed = NA_real_)
67-
}
54+
),
55+
gcFirst = FALSE
56+
)
6857
res$preproc <- list(y_var = all.vars(rlang::f_lhs(env$formula)))
69-
res$elapsed <- elapsed
58+
res$elapsed <- list(elapsed = elapsed, print = control$verbosity > 1L)
59+
7060
res
7161
}
7262

@@ -102,35 +92,24 @@ xy_xy <- function(object, env, control, target = "none", ...) {
10292

10393
res <- list(lvl = levels(env$y), spec = object)
10494

105-
if (control$verbosity > 1L) {
106-
elapsed <- system.time(
107-
res$fit <- eval_mod(
108-
fit_call,
109-
capture = control$verbosity == 0,
110-
catch = control$catch,
111-
envir = env,
112-
...
113-
),
114-
gcFirst = FALSE
115-
)
116-
} else {
95+
elapsed <- system.time(
11796
res$fit <- eval_mod(
11897
fit_call,
11998
capture = control$verbosity == 0,
12099
catch = control$catch,
121100
envir = env,
122101
...
123-
)
124-
elapsed <- list(elapsed = NA_real_)
125-
}
102+
),
103+
gcFirst = FALSE
104+
)
126105

127106
if (is.atomic(env$y)) {
128107
y_name <- character(0)
129108
} else {
130109
y_name <- colnames(env$y)
131110
}
132111
res$preproc <- list(y_var = y_name)
133-
res$elapsed <- elapsed
112+
res$elapsed <- list(elapsed = elapsed, print = control$verbosity > 1L)
134113
res
135114
}
136115

@@ -176,9 +155,9 @@ xy_form <- function(object, env, control, ...) {
176155
check_outcome(env$y, object)
177156

178157
encoding_info <- get_encoding(class(object)[1])
179-
encoding_info <-
158+
encoding_info <-
180159
vctrs::vec_slice(
181-
encoding_info,
160+
encoding_info,
182161
encoding_info$mode == object$mode & encoding_info$engine == object$engine
183162
)
184163

R/reexports.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,6 @@ hardhat::frequency_weights
5858
#' @export
5959
hardhat::importance_weights
6060

61+
#' @importFrom hardhat extract_fit_time
62+
#' @export
63+
hardhat::extract_fit_time

man/extract-parsnip.Rd

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/reexports.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/extract.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,11 @@
1818
i The parsnip extension package baguette implements support for this specification.
1919
i Please install (if needed) and load to continue.
2020

21+
# extract_fit_time() works
22+
23+
Code
24+
extract_fit_time(lm_fit)
25+
Condition
26+
Error in `extract_fit_time()`:
27+
! This model was fit before `extract_fit_time()` was added.
28+

tests/testthat/test_extract.R

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,22 @@ test_that("extract_parameter_dials doesn't error if namespaced args are used", {
9595
NA
9696
)
9797
})
98+
99+
test_that("extract_fit_time() works", {
100+
lm_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars)
101+
102+
res <- extract_fit_time(lm_fit)
103+
104+
expect_true(is_tibble(res))
105+
expect_identical(names(res), c("stage_id", "elapsed"))
106+
expect_identical(res$stage_id, "linear_reg")
107+
expect_true(is.double(res$elapsed))
108+
expect_true(res$elapsed >= 0)
109+
110+
lm_fit$elapsed$elapsed <- NULL
111+
112+
expect_snapshot(
113+
error = TRUE,
114+
extract_fit_time(lm_fit)
115+
)
116+
})

0 commit comments

Comments
 (0)