Skip to content

Commit fbae633

Browse files
committed
do not filter when mode is unknown
1 parent f8505bd commit fbae633

File tree

3 files changed

+93
-54
lines changed

3 files changed

+93
-54
lines changed

R/bag_tree.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ bag_tree <-
4242
eng_args = NULL,
4343
mode = mode,
4444
method = NULL,
45-
engine = engine
45+
engine = engine,
46+
check_missing_spec = FALSE
4647
)
4748
}
4849

R/engines.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ set_engine.model_spec <- function(object, engine, ...) {
132132
mode = object$mode,
133133
method = NULL,
134134
engine = object$engine,
135-
check_missing_spec = FALSE
135+
check_missing_spec = TRUE
136136
)
137137
}
138138

R/misc.R

Lines changed: 90 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@ make_classes <- function(prefix) {
1515
#' @return If an error is not thrown (from non-empty ellipses), a NULL list.
1616
#' @keywords internal
1717
#' @export
18-
check_empty_ellipse <- function (...) {
18+
check_empty_ellipse <- function(...) {
1919
terms <- quos(...)
20-
if (!is_empty(terms))
20+
if (!is_empty(terms)) {
2121
rlang::abort("Please pass other arguments to the model function via `set_engine()`.")
22+
}
2223
terms
2324
}
2425

25-
is_missing_arg <- function(x)
26+
is_missing_arg <- function(x) {
2627
identical(x, quote(missing_arg()))
28+
}
2729

2830
model_info_table <-
2931
utils::read.delim(system.file("models.tsv", package = "parsnip"))
@@ -38,7 +40,11 @@ has_loaded_implementation <- function(spec_, engine_, mode_) {
3840
if (isFALSE(mode_ %in% c("regression", "censored regression", "classification"))) {
3941
mode_ <- c("regression", "censored regression", "classification")
4042
}
41-
eng_cond <- if (is.null(engine_)) {TRUE} else {quote(engine == engine_)}
43+
eng_cond <- if (is.null(engine_)) {
44+
TRUE
45+
} else {
46+
quote(engine == engine_)
47+
}
4248

4349
avail <-
4450
get_from_env(spec_) %>%
@@ -56,7 +62,7 @@ has_loaded_implementation <- function(spec_, engine_, mode_) {
5662

5763
is_printable_spec <- function(x) {
5864
!is.null(x$method$fit$args) &&
59-
has_loaded_implementation(class(x)[1], x$engine, x$mode)
65+
has_loaded_implementation(class(x)[1], x$engine, x$mode)
6066
}
6167

6268
# construct a message informing the user that there are no
@@ -67,31 +73,52 @@ is_printable_spec <- function(x) {
6773
inform_missing_implementation <- function(spec_, engine_, mode_) {
6874
avail <-
6975
show_engines(spec_) %>%
70-
dplyr::filter(mode == mode_, engine == engine_)
76+
dplyr::filter(engine == engine_)
7177
all <-
7278
model_info_table %>%
73-
dplyr::filter(model == spec_, mode == mode_, engine == engine_, !is.na(pkg)) %>%
79+
dplyr::filter(model == spec_, engine == engine_, !is.na(pkg)) %>%
7480
dplyr::select(-model)
7581

76-
if (identical(mode_, "unknown")) {
77-
mode_ <- ""
82+
if (!identical(mode_, "unknown")) {
83+
avail <- avail %>% dplyr::filter(mode == mode_)
84+
all <- all %>% dplyr::filter(mode == mode_)
85+
msg <- glue::glue(
86+
paste0(
87+
"parsnip could not locate an implementation for {spec_} {mode_} ",
88+
"model specifications using the `{engine_}` engine. "
89+
)
90+
)
91+
} else {
92+
msg <- glue::glue(
93+
paste0(
94+
"parsnip could not locate an implementation for {spec_} ",
95+
"model specifications using the `{engine_}` engine. "
96+
)
97+
)
7898
}
7999

80-
msg <-
81-
glue::glue(
82-
"parsnip could not locate an implementation for `{spec_}` {mode_} model \\
83-
specifications using the `{engine_}` engine."
84-
)
85100

86101
if (nrow(avail) == 0 && nrow(all) > 0) {
87-
msg <-
88-
c(
89-
msg,
90-
i = paste0("The parsnip extension package ", all$pkg[[1]],
91-
" implements support for this specification."),
92-
i = "Please install (if needed) and load to continue.",
93-
""
102+
if (nrow(all) == 1) {
103+
msg <-
104+
c(
105+
i = paste0(
106+
msg,
107+
glue::glue("Please install `{all$pkg[[1]]}` (if needed) and load to continue."),
108+
"\n"
109+
)
110+
)
111+
} else {
112+
msg <- c(
113+
i = paste0(
114+
msg,
115+
"The following parsnip extension packages ",
116+
"implement support for this specification."
117+
),
118+
i = paste0(unique(all$pkg), collapse = ", "),
119+
i = "Please install one of them (if needed) and load to continue.\n"
94120
)
121+
}
95122
}
96123

97124
msg
@@ -109,22 +136,25 @@ show_call <- function(object) {
109136
map(object$method$fit$args, convert_arg)
110137

111138
call2(object$method$fit$func["fun"],
112-
!!!object$method$fit$args,
113-
.ns = object$method$fit$func["pkg"])
139+
!!!object$method$fit$args,
140+
.ns = object$method$fit$func["pkg"]
141+
)
114142
}
115143

116144
convert_arg <- function(x) {
117-
if (is_quosure(x))
145+
if (is_quosure(x)) {
118146
quo_get_expr(x)
119-
else
147+
} else {
120148
x
149+
}
121150
}
122151

123152
levels_from_formula <- function(f, dat) {
124-
if (inherits(dat, "tbl_spark"))
153+
if (inherits(dat, "tbl_spark")) {
125154
res <- NULL
126-
else
155+
} else {
127156
res <- levels(eval_tidy(f[[2]], dat))
157+
}
128158
res
129159
}
130160

@@ -134,7 +164,7 @@ levels_from_formula <- function(f, dat) {
134164
show_fit <- function(model, eng) {
135165
mod <- translate(x = model, engine = eng)
136166
fit_call <- show_call(mod)
137-
call_text <- deparse(fit_call)
167+
call_text <- deparse(fit_call)
138168
call_text <- paste0(call_text, collapse = "\n")
139169
paste0(
140170
"\\preformatted{\n",
@@ -157,9 +187,10 @@ check_args.default <- function(object) {
157187

158188
# copied form recipes
159189

160-
names0 <- function (num, prefix = "x") {
161-
if (num < 1)
190+
names0 <- function(num, prefix = "x") {
191+
if (num < 1) {
162192
rlang::abort("`num` should be > 0.")
193+
}
163194
ind <- format(1:num)
164195
ind <- gsub(" ", "0", ind)
165196
paste0(prefix, ind)
@@ -172,16 +203,16 @@ names0 <- function (num, prefix = "x") {
172203
#' @keywords internal
173204
#' @rdname add_on_exports
174205
update_dot_check <- function(...) {
175-
176206
dots <- enquos(...)
177207

178-
if (length(dots) > 0)
208+
if (length(dots) > 0) {
179209
rlang::abort(
180210
glue::glue(
181211
"Extra arguments will be ignored: ",
182212
glue::glue_collapse(glue::glue("`{names(dots)}`"), sep = ", ")
183213
)
184214
)
215+
}
185216
invisible(NULL)
186217
}
187218

@@ -192,15 +223,16 @@ update_dot_check <- function(...) {
192223
#' @rdname add_on_exports
193224
new_model_spec <- function(cls, args, eng_args, mode, method, engine,
194225
check_missing_spec = TRUE) {
195-
196226
check_spec_mode_engine_val(cls, engine, mode)
197227

198228
if ((!has_loaded_implementation(cls, engine, mode)) && check_missing_spec) {
199229
rlang::inform(inform_missing_implementation(cls, engine, mode))
200230
}
201231

202-
out <- list(args = args, eng_args = eng_args,
203-
mode = mode, method = method, engine = engine)
232+
out <- list(
233+
args = args, eng_args = eng_args,
234+
mode = mode, method = method, engine = engine
235+
)
204236
class(out) <- make_classes(cls)
205237
out
206238
}
@@ -211,8 +243,9 @@ check_outcome <- function(y, spec) {
211243
if (spec$mode == "unknown") {
212244
return(invisible(NULL))
213245
} else if (spec$mode == "regression") {
214-
if (!all(map_lgl(y, is.numeric)))
246+
if (!all(map_lgl(y, is.numeric))) {
215247
rlang::abort("For a regression model, the outcome should be numeric.")
248+
}
216249
} else if (spec$mode == "classification") {
217250
if (!all(map_lgl(y, is.factor))) {
218251
rlang::abort("For a classification model, the outcome should be a factor.")
@@ -250,7 +283,6 @@ check_final_param <- function(x) {
250283
#' @keywords internal
251284
#' @rdname add_on_exports
252285
update_main_parameters <- function(args, param) {
253-
254286
if (length(param) == 0) {
255287
return(args)
256288
}
@@ -263,8 +295,10 @@ update_main_parameters <- function(args, param) {
263295
extra_args <- names(param)[has_extra_args]
264296
if (any(has_extra_args)) {
265297
rlang::abort(
266-
paste("At least one argument is not a main argument:",
267-
paste0("`", extra_args, "`", collapse = ", "))
298+
paste(
299+
"At least one argument is not a main argument:",
300+
paste0("`", extra_args, "`", collapse = ", ")
301+
)
268302
)
269303
}
270304
param <- param[!has_extra_args]
@@ -276,7 +310,6 @@ update_main_parameters <- function(args, param) {
276310
#' @keywords internal
277311
#' @rdname add_on_exports
278312
update_engine_parameters <- function(eng_args, fresh, ...) {
279-
280313
dots <- enquos(...)
281314

282315
## only update from dots when there are eng args in original model spec
@@ -303,16 +336,20 @@ update_engine_parameters <- function(eng_args, fresh, ...) {
303336
stan_conf_int <- function(object, newdata) {
304337
check_installs(list(method = list(libs = "rstanarm")))
305338
if (utils::packageVersion("rstanarm") >= "2.21.1") {
306-
fn <- rlang::call2("posterior_epred", .ns = "rstanarm",
307-
object = expr(object),
308-
newdata = expr(newdata),
309-
seed = expr(sample.int(10^5, 1)))
339+
fn <- rlang::call2("posterior_epred",
340+
.ns = "rstanarm",
341+
object = expr(object),
342+
newdata = expr(newdata),
343+
seed = expr(sample.int(10^5, 1))
344+
)
310345
} else {
311-
fn <- rlang::call2("posterior_linpred", .ns = "rstanarm",
312-
object = expr(object),
313-
newdata = expr(newdata),
314-
transform = TRUE,
315-
seed = expr(sample.int(10^5, 1)))
346+
fn <- rlang::call2("posterior_linpred",
347+
.ns = "rstanarm",
348+
object = expr(object),
349+
newdata = expr(newdata),
350+
transform = TRUE,
351+
seed = expr(sample.int(10^5, 1))
352+
)
316353
}
317354
rlang::eval_tidy(fn)
318355
}
@@ -357,30 +394,31 @@ stan_conf_int <- function(object, newdata) {
357394
#' @keywords internal
358395
#' @export
359396
.check_glmnet_penalty_predict <- function(penalty = NULL, object, multi = FALSE) {
360-
361397
if (is.null(penalty)) {
362398
penalty <- object$fit$lambda
363399
}
364400

365401
# when using `predict()`, allow for a single lambda
366402
if (!multi) {
367-
if (length(penalty) != 1)
403+
if (length(penalty) != 1) {
368404
rlang::abort(
369405
glue::glue(
370406
"`penalty` should be a single numeric value. `multi_predict()` ",
371407
"can be used to get multiple predictions per row of data.",
372408
)
373409
)
410+
}
374411
}
375412

376-
if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda)
413+
if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda) {
377414
rlang::abort(
378415
glue::glue(
379416
"The glmnet model was fit with a single penalty value of ",
380417
"{object$fit$lambda}. Predicting with a value of {penalty} ",
381418
"will give incorrect results from `glmnet()`."
382419
)
383420
)
421+
}
384422

385423
penalty
386424
}

0 commit comments

Comments
 (0)