Skip to content

Commit 74439c6

Browse files
authored
speed up get_model_spec() helper (#901)
* test `get_model_spec()` helper * speed up `get_model_spec()` helper * accommodate 0-length `libs`/`fit` ``` r boop <- list(a = 1) bench::mark( base = if (length(boop) > 0) {boop[[1]]} else {NULL}, purrr = purrr::pluck(boop, 1), dplyr = dplyr::first(boop), iterations = 100 ) #> # A tibble: 3 × 6 #> expression min median `itr/sec` mem_alloc `gc/sec` #> <bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl> #> 1 base 122.94ns 163.91ns 3179808. 0B 0 #> 2 purrr 2.01µs 2.52µs 305755. 22.66KB 0 #> 3 dplyr 5.29µs 6.07µs 141894. 8.12MB 0 ```
1 parent 0127f2d commit 74439c6

File tree

2 files changed

+49
-19
lines changed

2 files changed

+49
-19
lines changed

R/translate.R

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -106,25 +106,21 @@ get_model_spec <- function(model, mode, engine) {
106106
env_obj <- grep(model, env_obj, value = TRUE)
107107

108108
res <- list()
109-
res$libs <-
110-
rlang::env_get(m_env, paste0(model, "_pkgs")) %>%
111-
dplyr::filter(engine == !!engine) %>%
112-
purrr::pluck("pkg") %>%
113-
purrr::pluck(1)
114-
115-
res$fit <-
116-
rlang::env_get(m_env, paste0(model, "_fit")) %>%
117-
dplyr::filter(mode == !!mode & engine == !!engine) %>%
118-
dplyr::pull(value) %>%
119-
purrr::pluck(1)
120-
121-
pred_code <-
122-
rlang::env_get(m_env, paste0(model, "_predict")) %>%
123-
dplyr::filter(mode == !!mode & engine == !!engine) %>%
124-
dplyr::select(-engine, -mode)
125-
126-
res$pred <- pred_code[["value"]]
127-
names(res$pred) <- pred_code$type
109+
110+
libs <- rlang::env_get(m_env, paste0(model, "_pkgs"))
111+
libs <- vctrs::vec_slice(libs$pkg, libs$engine == engine)
112+
res$libs <- if (length(libs) > 0) {libs[[1]]} else {NULL}
113+
114+
fits <- rlang::env_get(m_env, paste0(model, "_fit"))
115+
fits <- vctrs::vec_slice(fits$value, fits$mode == mode & fits$engine == engine)
116+
res$fit <- if (length(fits) > 0) {fits[[1]]} else {NULL}
117+
118+
preds <- rlang::env_get(m_env, paste0(model, "_predict"))
119+
where <- preds$mode == mode & preds$engine == engine
120+
types <- vctrs::vec_slice(preds$type, where)
121+
values <- vctrs::vec_slice(preds$value, where)
122+
names(values) <- types
123+
res$pred <- values
128124

129125
res
130126
}

tests/testthat/test_translate.R

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,4 +309,38 @@ test_that("translate tuning paramter names", {
309309
expect_snapshot_error(.model_param_name_key(1))
310310
})
311311

312+
# ------------------------------------------------------------------------------
313+
314+
test_that("get_model_spec helper", {
315+
mod1 <- get_model_spec("linear_reg", "regression", "lm")
316+
317+
expect_type(mod1, "list")
318+
319+
expect_type(mod1$libs, "character")
320+
expect_length(mod1$libs, 1)
321+
expect_equal(mod1$libs, "stats")
322+
323+
expect_type(mod1$fit, "list")
324+
expect_length(mod1$fit, 4)
325+
expect_equal(names(mod1$fit), c("interface", "protect", "func", "defaults"))
312326

327+
expect_type(mod1$pred, "list")
328+
expect_length(mod1$pred, 4)
329+
expect_equal(names(mod1$pred), c("numeric", "conf_int", "pred_int", "raw"))
330+
331+
expect_type(mod1$pred$numeric, "list")
332+
expect_length(mod1$pred$numeric, 4)
333+
expect_equal(names(mod1$pred$numeric), c("pre", "post", "func", "args"))
334+
335+
expect_type(mod1$pred$conf_int, "list")
336+
expect_length(mod1$pred$conf_int, 4)
337+
expect_equal(names(mod1$pred$conf_int), c("pre", "post", "func", "args"))
338+
339+
expect_type(mod1$pred$pred_int, "list")
340+
expect_length(mod1$pred$pred_int, 4)
341+
expect_equal(names(mod1$pred$pred_int), c("pre", "post", "func", "args"))
342+
343+
expect_type(mod1$pred$raw, "list")
344+
expect_length(mod1$pred$raw, 4)
345+
expect_equal(names(mod1$pred$raw), c("pre", "post", "func", "args"))
346+
})

0 commit comments

Comments
 (0)