Skip to content

Commit 6820609

Browse files
authored
merge pr #732: prompt on unavailable extension packages
2 parents 64e7125 + 51ffc03 commit 6820609

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+213
-62
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
* `predict(type = "prob")` will now provide an error if the outcome variable has a level called `"class"` (#720).
1414

15+
* Model type functions will now message informatively if a needed parsnip extension package is not loaded (#731).
16+
1517
# parsnip 0.2.1
1618

1719
* Fixed a major bug in spark models induced in the previous version (#671).

R/arguments.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ set_args <- function(object, ...) {
6767
eng_args = object$eng_args,
6868
mode = object$mode,
6969
method = NULL,
70-
engine = object$engine
70+
engine = object$engine,
71+
check_missing_spec = FALSE
7172
)
7273
}
7374

R/bag_mars.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ print.bag_mars <- function(x, ...) {
4747
cat("Bagged MARS Model Specification (", x$mode, ")\n\n", sep = "")
4848
model_printer(x, ...)
4949

50-
if (!is.null(x$method$fit$args)) {
50+
if (is_printable_spec(x)) {
5151
cat("Model fit template:\n")
5252
print(show_call(x))
5353
}

R/bag_tree.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ print.bag_tree <- function(x, ...) {
5151
cat("Bagged Decision Tree Model Specification (", x$mode, ")\n\n", sep = "")
5252
model_printer(x, ...)
5353

54-
if (!is.null(x$method$fit$args)) {
54+
if (is_printable_spec(x)) {
5555
cat("Model fit template:\n")
5656
print(show_call(x))
5757
}

R/bart.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ print.bart <- function(x, ...) {
9595
cat("BART Model Specification (", x$mode, ")\n\n", sep = "")
9696
model_printer(x, ...)
9797

98-
if(!is.null(x$method$fit$args)) {
98+
if (is_printable_spec(x)) {
9999
cat("Model fit template:\n")
100100
print(show_call(x))
101101
}

R/boost_tree.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ print.boost_tree <- function(x, ...) {
8484
cat("Boosted Tree Model Specification (", x$mode, ")\n\n", sep = "")
8585
model_printer(x, ...)
8686

87-
if (!is.null(x$method$fit$args)) {
87+
if (is_printable_spec(x)) {
8888
cat("Model fit template:\n")
8989
print(show_call(x))
9090
}

R/c5_rules.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ print.C5_rules <- function(x, ...) {
6868
cat("C5.0 Model Specification (", x$mode, ")\n\n", sep = "")
6969
model_printer(x, ...)
7070

71-
if (!is.null(x$method$fit$args)) {
71+
if (is_printable_spec(x)) {
7272
cat("Model fit template:\n")
7373
print(show_call(x))
7474
}

R/cubist_rules.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ print.cubist_rules <- function(x, ...) {
9494
cat("Cubist Model Specification (", x$mode, ")\n\n", sep = "")
9595
model_printer(x, ...)
9696

97-
if (!is.null(x$method$fit$args)) {
97+
if (is_printable_spec(x)) {
9898
cat("Model fit template:\n")
9999
print(show_call(x))
100100
}

R/decision_tree.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ print.decision_tree <- function(x, ...) {
5656
cat("Decision Tree Model Specification (", x$mode, ")\n\n", sep = "")
5757
model_printer(x, ...)
5858

59-
if (!is.null(x$method$fit$args)) {
59+
if (is_printable_spec(x)) {
6060
cat("Model fit template:\n")
6161
print(show_call(x))
6262
}

R/discrim_flexible.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ print.discrim_flexible <- function(x, ...) {
5050
cat("Flexible Discriminant Model Specification (", x$mode, ")\n\n", sep = "")
5151
model_printer(x, ...)
5252

53-
if (!is.null(x$method$fit$args)) {
53+
if (is_printable_spec(x)) {
5454
cat("Model fit template:\n")
5555
print(show_call(x))
5656
}

R/discrim_linear.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ print.discrim_linear <- function(x, ...) {
5252
cat("Linear Discriminant Model Specification (", x$mode, ")\n\n", sep = "")
5353
model_printer(x, ...)
5454

55-
if (!is.null(x$method$fit$args)) {
55+
if (is_printable_spec(x)) {
5656
cat("Model fit template:\n")
5757
print(show_call(x))
5858
}

R/discrim_quad.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ print.discrim_quad <- function(x, ...) {
4646
cat("Quadratic Discriminant Model Specification (", x$mode, ")\n\n", sep = "")
4747
model_printer(x, ...)
4848

49-
if (!is.null(x$method$fit$args)) {
49+
if (is_printable_spec(x)) {
5050
cat("Model fit template:\n")
5151
print(show_call(x))
5252
}

R/discrim_regularized.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ print.discrim_regularized <- function(x, ...) {
6767
cat("Regularized Discriminant Model Specification (", x$mode, ")\n\n", sep = "")
6868
model_printer(x, ...)
6969

70-
if (!is.null(x$method$fit$args)) {
70+
if (is_printable_spec(x)) {
7171
cat("Model fit template:\n")
7272
print(show_call(x))
7373
}

R/engine_docs.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,10 @@ make_engine_list <- function(mod) {
243243

244244
get_default_engine <- function(mod, pkg = "parsnip") {
245245
cl <- rlang::call2(mod, .ns = pkg)
246-
rlang::eval_tidy(cl)$engine
246+
suppressMessages(
247+
res <- rlang::eval_tidy(cl)$engine
248+
)
249+
res
247250
}
248251

249252
#' @export

R/engines.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ set_engine <- function(object, engine, ...) {
129129
eng_args = enquos(...),
130130
mode = object$mode,
131131
method = NULL,
132-
engine = object$engine
132+
engine = object$engine,
133+
check_missing_spec = FALSE
133134
)
134135
}
135136

R/gen_additive_mod.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ print.gen_additive_mod <- function(x, ...) {
5555
cat("GAM Specification (", x$mode, ")\n\n", sep = "")
5656
model_printer(x, ...)
5757

58-
if(!is.null(x$method$fit$args)) {
58+
if (is_printable_spec(x)) {
5959
cat("Model fit template:\n")
6060
print(show_call(x))
6161
}

R/linear_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ print.linear_reg <- function(x, ...) {
6363
cat("Linear Regression Model Specification (", x$mode, ")\n\n", sep = "")
6464
model_printer(x, ...)
6565

66-
if (!is.null(x$method$fit$args)) {
66+
if (is_printable_spec(x)) {
6767
cat("Model fit template:\n")
6868
print(show_call(x))
6969
}

R/logistic_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ print.logistic_reg <- function(x, ...) {
7070
cat("Logistic Regression Model Specification (", x$mode, ")\n\n", sep = "")
7171
model_printer(x, ...)
7272

73-
if(!is.null(x$method$fit$args)) {
73+
if (is_printable_spec(x)) {
7474
cat("Model fit template:\n")
7575
print(show_call(x))
7676
}

R/mars.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ print.mars <- function(x, ...) {
5454
cat("MARS Model Specification (", x$mode, ")\n\n", sep = "")
5555
model_printer(x, ...)
5656

57-
if(!is.null(x$method$fit$args)) {
57+
if (is_printable_spec(x)) {
5858
cat("Model fit template:\n")
5959
print(show_call(x))
6060
}

R/misc.R

Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,78 @@ model_printer <- function(x, ...) {
7979
is_missing_arg <- function(x)
8080
identical(x, quote(missing_arg()))
8181

82+
model_info_table <-
83+
utils::read.delim(system.file("models.tsv", package = "parsnip"))
84+
85+
# given a model object, return TRUE if:
86+
# * the model is supported without extensions
87+
# * the model needs an extension and it is loaded
88+
#
89+
# return FALSE if:
90+
# * the model needs an extension and it is _not_ loaded
91+
has_loaded_implementation <- function(spec_, engine_, mode_) {
92+
if (isFALSE(mode_ %in% c("regression", "censored regression", "classification"))) {
93+
mode_ <- c("regression", "censored regression", "classification")
94+
}
95+
eng_cond <- if (is.null(engine_)) {TRUE} else {quote(engine == engine_)}
96+
97+
avail <-
98+
get_from_env(spec_) %>%
99+
dplyr::filter(mode %in% mode_, !!eng_cond)
100+
pars <-
101+
model_info_table %>%
102+
dplyr::filter(model == spec_, !!eng_cond, mode %in% mode_, is.na(pkg))
103+
104+
if (nrow(pars) > 0 || nrow(avail) > 0) {
105+
return(TRUE)
106+
}
107+
108+
FALSE
109+
}
110+
111+
is_printable_spec <- function(x) {
112+
!is.null(x$method$fit$args) &&
113+
has_loaded_implementation(class(x)[1], x$engine, x$mode)
114+
}
115+
116+
# construct a message informing the user that there are no
117+
# implementations for the current model spec / mode / engine.
118+
#
119+
# if there's a "pre-registered" extension supporting that setup,
120+
# nudge the user to install/load it.
121+
inform_missing_implementation <- function(spec_, engine_, mode_) {
122+
avail <-
123+
show_engines(spec_) %>%
124+
dplyr::filter(mode == mode_, engine == engine_)
125+
all <-
126+
model_info_table %>%
127+
dplyr::filter(model == spec_, mode == mode_, engine == engine_, !is.na(pkg)) %>%
128+
dplyr::select(-model)
129+
130+
if (identical(mode_, "unknown")) {
131+
mode_ <- ""
132+
}
133+
134+
msg <-
135+
glue::glue(
136+
"parsnip could not locate an implementation for `{spec_}` {mode_} model \\
137+
specifications using the `{engine_}` engine."
138+
)
139+
140+
if (nrow(avail) == 0 && nrow(all) > 0) {
141+
msg <-
142+
c(
143+
msg,
144+
i = paste0("The parsnip extension package ", all$pkg[[1]],
145+
" implements support for this specification."),
146+
i = "Please install (if needed) and load to continue.",
147+
""
148+
)
149+
}
150+
151+
msg
152+
}
153+
82154

83155
#' Print the model call
84156
#'
@@ -89,18 +161,10 @@ is_missing_arg <- function(x)
89161
show_call <- function(object) {
90162
object$method$fit$args <-
91163
map(object$method$fit$args, convert_arg)
92-
if (
93-
is.null(object$method$fit$func["pkg"]) ||
94-
is.na(object$method$fit$func["pkg"])
95-
) {
96-
res <- call2(object$method$fit$func["fun"], !!!object$method$fit$args)
97-
} else {
98-
res <-
99-
call2(object$method$fit$func["fun"],
100-
!!!object$method$fit$args,
101-
.ns = object$method$fit$func["pkg"])
102-
}
103-
res
164+
165+
call2(object$method$fit$func["fun"],
166+
!!!object$method$fit$args,
167+
.ns = object$method$fit$func["pkg"])
104168
}
105169

106170
convert_arg <- function(x) {
@@ -110,7 +174,6 @@ convert_arg <- function(x) {
110174
x
111175
}
112176

113-
114177
levels_from_formula <- function(f, dat) {
115178
if (inherits(dat, "tbl_spark"))
116179
res <- NULL
@@ -181,10 +244,15 @@ update_dot_check <- function(...) {
181244
#' @export
182245
#' @keywords internal
183246
#' @rdname add_on_exports
184-
new_model_spec <- function(cls, args, eng_args, mode, method, engine) {
247+
new_model_spec <- function(cls, args, eng_args, mode, method, engine,
248+
check_missing_spec = TRUE) {
185249

186250
check_spec_mode_engine_val(cls, engine, mode)
187251

252+
if ((!has_loaded_implementation(cls, engine, mode)) && check_missing_spec) {
253+
rlang::inform(inform_missing_implementation(cls, engine, mode))
254+
}
255+
188256
out <- list(args = args, eng_args = eng_args,
189257
mode = mode, method = method, engine = engine)
190258
class(out) <- make_classes(cls)

R/mlp.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ print.mlp <- function(x, ...) {
6464
cat("Single Layer Neural Network Specification (", x$mode, ")\n\n", sep = "")
6565
model_printer(x, ...)
6666

67-
if(!is.null(x$method$fit$args)) {
67+
if (is_printable_spec(x)) {
6868
cat("Model fit template:\n")
6969
print(show_call(x))
7070
}

R/multinom_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ print.multinom_reg <- function(x, ...) {
7070
cat("Multinomial Regression Model Specification (", x$mode, ")\n\n", sep = "")
7171
model_printer(x, ...)
7272

73-
if (!is.null(x$method$fit$args)) {
73+
if (is_printable_spec(x)) {
7474
cat("Model fit template:\n")
7575
print(show_call(x))
7676
}

R/naive_Bayes.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ print.naive_Bayes <- function(x, ...) {
4949
cat("Naive Bayes Model Specification (", x$mode, ")\n\n", sep = "")
5050
model_printer(x, ...)
5151

52-
if (!is.null(x$method$fit$args)) {
52+
if (is_printable_spec(x)) {
5353
cat("Model fit template:\n")
5454
print(show_call(x))
5555
}

R/nearest_neighbor.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ print.nearest_neighbor <- function(x, ...) {
6060
cat("K-Nearest Neighbor Model Specification (", x$mode, ")\n\n", sep = "")
6161
model_printer(x, ...)
6262

63-
if(!is.null(x$method$fit$args)) {
63+
if (is_printable_spec(x)) {
6464
cat("Model fit template:\n")
6565
print(show_call(x))
6666
}

R/pls.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ print.pls <- function(x, ...) {
4545
cat("PLS Model Specification (", x$mode, ")\n\n", sep = "")
4646
model_printer(x, ...)
4747

48-
if (!is.null(x$method$fit$args)) {
48+
if (is_printable_spec(x)) {
4949
cat("Model fit template:\n")
5050
print(show_call(x))
5151
}

R/poisson_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ print.poisson_reg <- function(x, ...) {
5656
cat("Poisson Regression Model Specification (", x$mode, ")\n\n", sep = "")
5757
model_printer(x, ...)
5858

59-
if (!is.null(x$method$fit$args)) {
59+
if (is_printable_spec(x)) {
6060
cat("Model fit template:\n")
6161
print(show_call(x))
6262
}

R/proportional_hazards.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ print.proportional_hazards <- function(x, ...) {
5858
cat("Proportional Hazards Model Specification (", x$mode, ")\n\n", sep = "")
5959
model_printer(x, ...)
6060

61-
if (!is.null(x$method$fit$args)) {
61+
if (is_printable_spec(x)) {
6262
cat("Model fit template:\n")
6363
print(show_call(x))
6464
}

R/rand_forest.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ print.rand_forest <- function(x, ...) {
5656
cat("Random Forest Model Specification (", x$mode, ")\n\n", sep = "")
5757
model_printer(x, ...)
5858

59-
if(!is.null(x$method$fit$args)) {
59+
if (is_printable_spec(x)) {
6060
cat("Model fit template:\n")
6161
print(show_call(x))
6262
}

R/rule_fit.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ print.rule_fit <- function(x, ...) {
7070
cat("RuleFit Model Specification (", x$mode, ")\n\n", sep = "")
7171
model_printer(x, ...)
7272

73-
if (!is.null(x$method$fit$args)) {
73+
if (is_printable_spec(x)) {
7474
cat("Model fit template:\n")
7575
print(show_call(x))
7676
}

R/surv_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ print.surv_reg <- function(x, ...) {
5757
cat("Parametric Survival Regression Model Specification (", x$mode, ")\n\n", sep = "")
5858
model_printer(x, ...)
5959

60-
if(!is.null(x$method$fit$args)) {
60+
if (is_printable_spec(x)) {
6161
cat("Model fit template:\n")
6262
print(show_call(x))
6363
}

0 commit comments

Comments
 (0)