Skip to content

prompt on unavailable extension packages #732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
May 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4b04aa6
add machinery to prompt on unavailable extension package
simonpcouch May 22, 2022
3bd57a8
adjust conditions to `show_call`
simonpcouch May 22, 2022
e703371
add tests + update snapshots
simonpcouch May 22, 2022
d33c2eb
update NEWS
simonpcouch May 22, 2022
17b8203
pkgdown fix
simonpcouch May 22, 2022
31db1aa
Revert "adjust conditions to `show_call`"
simonpcouch May 22, 2022
6174fa1
adjust conditions to `show_call`
simonpcouch May 22, 2022
f3ff5f1
`print` -> `rlang::inform`
simonpcouch May 22, 2022
8c4ea8e
adjust `has_loaded_implementation` interface
simonpcouch May 22, 2022
7070827
update tests
simonpcouch May 22, 2022
8843bcf
address check warnings/notes
simonpcouch May 22, 2022
7eb956c
revise prompt, update snapshots
simonpcouch May 23, 2022
1ca9e60
one more `new = FALSE` in `new_model_spec` calls
simonpcouch May 23, 2022
ff84494
address unstated examples dependencies check
simonpcouch May 23, 2022
50180e7
rephrase NEWS bullet
simonpcouch May 23, 2022
8978637
better utilize `glue::glue` and `rlang::inform`
simonpcouch May 25, 2022
f9c7e77
refactor printing conditional into helper
simonpcouch May 25, 2022
51ef493
append vertical spacing to message
simonpcouch May 25, 2022
4dd0fb9
rename `new_model_spec` arg: `new` -> `check_missing_spec`
simonpcouch May 27, 2022
6c06315
`prompt_missing_implementation` -> `inform_missing_implementation`
simonpcouch May 27, 2022
9918043
assign `models.tsv` to internal object
simonpcouch May 27, 2022
51ffc03
add test for message on unknown implementation
simonpcouch May 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

* `predict(type = "prob")` will now provide an error if the outcome variable has a level called `"class"` (#720).

* Model type functions will now message informatively if a needed parsnip extension package is not loaded (#731).

# parsnip 0.2.1

* Fixed a major bug in spark models induced in the previous version (#671).
Expand Down
3 changes: 2 additions & 1 deletion R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ set_args <- function(object, ...) {
eng_args = object$eng_args,
mode = object$mode,
method = NULL,
engine = object$engine
engine = object$engine,
check_missing_spec = FALSE
)
}

Expand Down
2 changes: 1 addition & 1 deletion R/bag_mars.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ print.bag_mars <- function(x, ...) {
cat("Bagged MARS Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/bag_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ print.bag_tree <- function(x, ...) {
cat("Bagged Decision Tree Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ print.bart <- function(x, ...) {
cat("BART Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if(!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ print.boost_tree <- function(x, ...) {
cat("Boosted Tree Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/c5_rules.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ print.C5_rules <- function(x, ...) {
cat("C5.0 Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/cubist_rules.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ print.cubist_rules <- function(x, ...) {
cat("Cubist Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/decision_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ print.decision_tree <- function(x, ...) {
cat("Decision Tree Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/discrim_flexible.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ print.discrim_flexible <- function(x, ...) {
cat("Flexible Discriminant Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/discrim_linear.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ print.discrim_linear <- function(x, ...) {
cat("Linear Discriminant Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/discrim_quad.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ print.discrim_quad <- function(x, ...) {
cat("Quadratic Discriminant Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/discrim_regularized.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ print.discrim_regularized <- function(x, ...) {
cat("Regularized Discriminant Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
5 changes: 4 additions & 1 deletion R/engine_docs.R
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,10 @@ make_engine_list <- function(mod) {

get_default_engine <- function(mod, pkg = "parsnip") {
cl <- rlang::call2(mod, .ns = pkg)
rlang::eval_tidy(cl)$engine
suppressMessages(
res <- rlang::eval_tidy(cl)$engine
)
res
}

#' @export
Expand Down
3 changes: 2 additions & 1 deletion R/engines.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ set_engine <- function(object, engine, ...) {
eng_args = enquos(...),
mode = object$mode,
method = NULL,
engine = object$engine
engine = object$engine,
check_missing_spec = FALSE
)
}

Expand Down
2 changes: 1 addition & 1 deletion R/gen_additive_mod.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ print.gen_additive_mod <- function(x, ...) {
cat("GAM Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if(!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ print.linear_reg <- function(x, ...) {
cat("Linear Regression Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ print.logistic_reg <- function(x, ...) {
cat("Logistic Regression Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if(!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/mars.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ print.mars <- function(x, ...) {
cat("MARS Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if(!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
96 changes: 82 additions & 14 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,78 @@ model_printer <- function(x, ...) {
is_missing_arg <- function(x)
identical(x, quote(missing_arg()))

model_info_table <-
utils::read.delim(system.file("models.tsv", package = "parsnip"))

# given a model object, return TRUE if:
# * the model is supported without extensions
# * the model needs an extension and it is loaded
#
# return FALSE if:
# * the model needs an extension and it is _not_ loaded
has_loaded_implementation <- function(spec_, engine_, mode_) {
if (isFALSE(mode_ %in% c("regression", "censored regression", "classification"))) {
mode_ <- c("regression", "censored regression", "classification")
}
eng_cond <- if (is.null(engine_)) {TRUE} else {quote(engine == engine_)}

avail <-
get_from_env(spec_) %>%
dplyr::filter(mode %in% mode_, !!eng_cond)
pars <-
model_info_table %>%
dplyr::filter(model == spec_, !!eng_cond, mode %in% mode_, is.na(pkg))

if (nrow(pars) > 0 || nrow(avail) > 0) {
return(TRUE)
}

FALSE
}

is_printable_spec <- function(x) {
!is.null(x$method$fit$args) &&
has_loaded_implementation(class(x)[1], x$engine, x$mode)
}

# construct a message informing the user that there are no
# implementations for the current model spec / mode / engine.
#
# if there's a "pre-registered" extension supporting that setup,
# nudge the user to install/load it.
inform_missing_implementation <- function(spec_, engine_, mode_) {
avail <-
show_engines(spec_) %>%
dplyr::filter(mode == mode_, engine == engine_)
all <-
model_info_table %>%
dplyr::filter(model == spec_, mode == mode_, engine == engine_, !is.na(pkg)) %>%
dplyr::select(-model)

if (identical(mode_, "unknown")) {
mode_ <- ""
}

msg <-
glue::glue(
"parsnip could not locate an implementation for `{spec_}` {mode_} model \\
specifications using the `{engine_}` engine."
)

if (nrow(avail) == 0 && nrow(all) > 0) {
msg <-
c(
msg,
i = paste0("The parsnip extension package ", all$pkg[[1]],
" implements support for this specification."),
i = "Please install (if needed) and load to continue.",
""
)
}

msg
}


#' Print the model call
#'
Expand All @@ -89,18 +161,10 @@ is_missing_arg <- function(x)
show_call <- function(object) {
object$method$fit$args <-
map(object$method$fit$args, convert_arg)
if (
is.null(object$method$fit$func["pkg"]) ||
is.na(object$method$fit$func["pkg"])
) {
res <- call2(object$method$fit$func["fun"], !!!object$method$fit$args)
} else {
res <-
call2(object$method$fit$func["fun"],
!!!object$method$fit$args,
.ns = object$method$fit$func["pkg"])
}
res

call2(object$method$fit$func["fun"],
!!!object$method$fit$args,
.ns = object$method$fit$func["pkg"])
}

convert_arg <- function(x) {
Expand All @@ -110,7 +174,6 @@ convert_arg <- function(x) {
x
}


levels_from_formula <- function(f, dat) {
if (inherits(dat, "tbl_spark"))
res <- NULL
Expand Down Expand Up @@ -181,10 +244,15 @@ update_dot_check <- function(...) {
#' @export
#' @keywords internal
#' @rdname add_on_exports
new_model_spec <- function(cls, args, eng_args, mode, method, engine) {
new_model_spec <- function(cls, args, eng_args, mode, method, engine,
check_missing_spec = TRUE) {

check_spec_mode_engine_val(cls, engine, mode)

if ((!has_loaded_implementation(cls, engine, mode)) && check_missing_spec) {
rlang::inform(inform_missing_implementation(cls, engine, mode))
}

out <- list(args = args, eng_args = eng_args,
mode = mode, method = method, engine = engine)
class(out) <- make_classes(cls)
Expand Down
2 changes: 1 addition & 1 deletion R/mlp.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ print.mlp <- function(x, ...) {
cat("Single Layer Neural Network Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if(!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/multinom_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ print.multinom_reg <- function(x, ...) {
cat("Multinomial Regression Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/naive_Bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ print.naive_Bayes <- function(x, ...) {
cat("Naive Bayes Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/nearest_neighbor.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ print.nearest_neighbor <- function(x, ...) {
cat("K-Nearest Neighbor Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if(!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/pls.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ print.pls <- function(x, ...) {
cat("PLS Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/poisson_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ print.poisson_reg <- function(x, ...) {
cat("Poisson Regression Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/proportional_hazards.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ print.proportional_hazards <- function(x, ...) {
cat("Proportional Hazards Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/rand_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ print.rand_forest <- function(x, ...) {
cat("Random Forest Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if(!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/rule_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ print.rule_fit <- function(x, ...) {
cat("RuleFit Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if (!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
2 changes: 1 addition & 1 deletion R/surv_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ print.surv_reg <- function(x, ...) {
cat("Parametric Survival Regression Model Specification (", x$mode, ")\n\n", sep = "")
model_printer(x, ...)

if(!is.null(x$method$fit$args)) {
if (is_printable_spec(x)) {
cat("Model fit template:\n")
print(show_call(x))
}
Expand Down
Loading