Skip to content

Commit b5b03ea

Browse files
authored
register tune_args() and tunable() methods unconditionally (#869)
* register `tune_args()` and `tunable()` methods unconditionally * correct ref * register methods in the usual way * resituate `tune_args()` import, export neither `tune_args()` nor `tuneable()` generic * revert `required_pkgs()` registration
1 parent 9060711 commit b5b03ea

File tree

6 files changed

+41
-58
lines changed

6 files changed

+41
-58
lines changed

NAMESPACE

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,18 @@ S3method(translate,survival_reg)
9090
S3method(translate,svm_linear)
9191
S3method(translate,svm_poly)
9292
S3method(translate,svm_rbf)
93+
S3method(tunable,boost_tree)
94+
S3method(tunable,decision_tree)
95+
S3method(tunable,linear_reg)
96+
S3method(tunable,logistic_reg)
97+
S3method(tunable,mars)
98+
S3method(tunable,mlp)
99+
S3method(tunable,model_spec)
100+
S3method(tunable,multinomial_reg)
101+
S3method(tunable,rand_forest)
93102
S3method(tunable,survival_reg)
103+
S3method(tunable,svm_poly)
104+
S3method(tune_args,model_spec)
94105
S3method(type_sum,model_fit)
95106
S3method(type_sum,model_spec)
96107
S3method(update,C5_rules)
@@ -317,6 +328,7 @@ importFrom(generics,glance)
317328
importFrom(generics,required_pkgs)
318329
importFrom(generics,tidy)
319330
importFrom(generics,tunable)
331+
importFrom(generics,tune_args)
320332
importFrom(generics,varying_args)
321333
importFrom(ggplot2,autoplot)
322334
importFrom(glue,glue_collapse)

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
* For censored regression models, a "reverse Kaplan-Meier" curve is computed for the censoring distribution. This can be used when evaluating this type of model (#855).
44

5+
* The model specification methods for `generics::tune_args()` and
6+
`generics::tunable()` are now registered unconditionally (tidymodels/workflows#192).
7+
58
# parsnip 1.0.3
69

710
* Adds documentation and tuning infrastructure for the new `flexsurvspline` engine for the `survival_reg()` model specification from the `censored` package (@mattwarkentin, #831).

R/parsnip-package.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
## usethis namespace: start
1111
#' @importFrom dplyr arrange bind_cols bind_rows collect full_join group_by
1212
#' @importFrom dplyr mutate pull rename select starts_with summarise tally
13-
#' @importFrom generics tunable varying_args
13+
#' @importFrom generics tunable varying_args tune_args
1414
#' @importFrom glue glue_collapse
1515
#' @importFrom pillar type_sum
1616
#' @importFrom purrr as_vector imap imap_lgl map map_chr map_dbl map_df map_dfr

R/tunable.R

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
# Lazily registered in .onLoad()
1+
22
# Unit tests are in extratests
33
# nocov start
4-
tunable_model_spec <- function(x, ...) {
4+
5+
#' @export
6+
tunable.model_spec <- function(x, ...) {
57
mod_env <- rlang::ns_env("parsnip")$parsnip
68

79
if (is.null(x$engine)) {
@@ -228,8 +230,8 @@ flexsurvspline_engine_args <-
228230

229231
# ------------------------------------------------------------------------------
230232

231-
# Lazily registered in .onLoad()
232-
tunable_linear_reg <- function(x, ...) {
233+
#' @export
234+
tunable.linear_reg <- function(x, ...) {
233235
res <- NextMethod()
234236
if (x$engine == "glmnet") {
235237
res$call_info[res$name == "mixture"] <-
@@ -240,8 +242,8 @@ tunable_linear_reg <- function(x, ...) {
240242
res
241243
}
242244

243-
# Lazily registered in .onLoad()
244-
tunable_logistic_reg <- function(x, ...) {
245+
#' @export
246+
tunable.logistic_reg <- function(x, ...) {
245247
res <- NextMethod()
246248
if (x$engine == "glmnet") {
247249
res$call_info[res$name == "mixture"] <-
@@ -252,8 +254,8 @@ tunable_logistic_reg <- function(x, ...) {
252254
res
253255
}
254256

255-
# Lazily registered in .onLoad()
256-
tunable_multinomial_reg <- function(x, ...) {
257+
#' @export
258+
tunable.multinomial_reg <- function(x, ...) {
257259
res <- NextMethod()
258260
if (x$engine == "glmnet") {
259261
res$call_info[res$name == "mixture"] <-
@@ -264,8 +266,8 @@ tunable_multinomial_reg <- function(x, ...) {
264266
res
265267
}
266268

267-
# Lazily registered in .onLoad()
268-
tunable_boost_tree <- function(x, ...) {
269+
#' @export
270+
tunable.boost_tree <- function(x, ...) {
269271
res <- NextMethod()
270272
if (x$engine == "xgboost") {
271273
res <- add_engine_parameters(res, xgboost_engine_args)
@@ -287,8 +289,8 @@ tunable_boost_tree <- function(x, ...) {
287289
res
288290
}
289291

290-
# Lazily registered in .onLoad()
291-
tunable_rand_forest <- function(x, ...) {
292+
#' @export
293+
tunable.rand_forest <- function(x, ...) {
292294
res <- NextMethod()
293295
if (x$engine == "ranger") {
294296
res <- add_engine_parameters(res, ranger_engine_args)
@@ -302,17 +304,17 @@ tunable_rand_forest <- function(x, ...) {
302304
res
303305
}
304306

305-
# Lazily registered in .onLoad()
306-
tunable_mars <- function(x, ...) {
307+
#' @export
308+
tunable.mars <- function(x, ...) {
307309
res <- NextMethod()
308310
if (x$engine == "earth") {
309311
res <- add_engine_parameters(res, earth_engine_args)
310312
}
311313
res
312314
}
313315

314-
# Lazily registered in .onLoad()
315-
tunable_decision_tree <- function(x, ...) {
316+
#' @export
317+
tunable.decision_tree <- function(x, ...) {
316318
res <- NextMethod()
317319
if (x$engine == "C5.0") {
318320
res <- add_engine_parameters(res, c5_tree_engine_args)
@@ -325,8 +327,8 @@ tunable_decision_tree <- function(x, ...) {
325327
res
326328
}
327329

328-
# Lazily registered in .onLoad()
329-
tunable_svm_poly <- function(x, ...) {
330+
#' @export
331+
tunable.svm_poly <- function(x, ...) {
330332
res <- NextMethod()
331333
if (x$engine == "kernlab") {
332334
res$call_info[res$name == "degree"] <-
@@ -336,8 +338,8 @@ tunable_svm_poly <- function(x, ...) {
336338
}
337339

338340

339-
# Lazily registered in .onLoad()
340-
tunable_mlp <- function(x, ...) {
341+
#' @export
342+
tunable.mlp <- function(x, ...) {
341343
res <- NextMethod()
342344
if (x$engine == "brulee") {
343345
res <- add_engine_parameters(res, brulee_mlp_engine_args)

R/tune_args.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
2-
# Lazily registered in .onLoad()
3-
tune_args_model_spec <- function(object, full = FALSE, ...) {
1+
#' @method tune_args model_spec
2+
#' @export
3+
tune_args.model_spec <- function(object, full = FALSE, ...) {
44

55
# use the model_spec top level class as the id
66
model_type <- class(object)[1]

R/zzz.R

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,40 +14,6 @@
1414

1515
vctrs::s3_register("ggplot2::autoplot", "model_fit")
1616
vctrs::s3_register("ggplot2::autoplot", "glmnet")
17-
18-
# - If tune isn't installed, register the method (`packageVersion()` will error here)
19-
# - If tune >= 0.1.6.9001 is installed, register the method
20-
should_register_tune_args_method <- tryCatch(
21-
expr = utils::packageVersion("tune") >= "0.1.6.9001",
22-
error = function(cnd) TRUE
23-
)
24-
25-
if (should_register_tune_args_method) {
26-
# `tune_args.model_spec()` moved from tune to parsnip
27-
vctrs::s3_register("generics::tune_args", "model_spec", tune_args_model_spec)
28-
}
29-
30-
# - If tune isn't installed, register the method (`packageVersion()` will error here)
31-
# - If tune >= 0.1.6.9002 is installed, register the method
32-
should_register_tunable_method <- tryCatch(
33-
expr = utils::packageVersion("tune") >= "0.1.6.9002",
34-
error = function(cnd) TRUE
35-
)
36-
37-
if (should_register_tunable_method) {
38-
# `tunable.model_spec()` and friends moved from tune to parsnip
39-
vctrs::s3_register("generics::tunable", "model_spec", tunable_model_spec)
40-
vctrs::s3_register("generics::tunable", "linear_reg", tunable_linear_reg)
41-
vctrs::s3_register("generics::tunable", "logistic_reg", tunable_logistic_reg)
42-
vctrs::s3_register("generics::tunable", "multinomial_reg", tunable_multinomial_reg)
43-
vctrs::s3_register("generics::tunable", "boost_tree", tunable_boost_tree)
44-
vctrs::s3_register("generics::tunable", "rand_forest", tunable_rand_forest)
45-
vctrs::s3_register("generics::tunable", "mars", tunable_mars)
46-
vctrs::s3_register("generics::tunable", "decision_tree", tunable_decision_tree)
47-
vctrs::s3_register("generics::tunable", "svm_poly", tunable_svm_poly)
48-
vctrs::s3_register("generics::tunable", "mlp", tunable_mlp)
49-
}
50-
5117
}
5218

5319
# nocov end

0 commit comments

Comments
 (0)