Skip to content

Commit f364b81

Browse files
authored
Merge pull request #489 from tidymodels/pred-type-error
Harmonize error for wrong prediction `type`
2 parents a546656 + 2f00678 commit f364b81

13 files changed

+30
-46
lines changed

R/aaa_models.R

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,12 +286,27 @@ check_pred_info <- function(pred_obj, type) {
286286
invisible(NULL)
287287
}
288288

289+
check_spec_pred_type <- function(object, type) {
290+
possible_preds <- names(object$spec$method$pred)
291+
if (!any(possible_preds == type)) {
292+
rlang::abort(c(
293+
glue::glue("No {type} prediction method available for this model."),
294+
glue::glue("Value for `type` should be one of: ",
295+
glue::glue_collapse(glue::glue("'{possible_preds}'"), sep = ", "))
296+
))
297+
}
298+
invisible(NULL)
299+
}
300+
301+
289302
check_pkg_val <- function(pkg) {
290-
if (rlang::is_missing(pkg) || length(pkg) != 1 || !is.character(pkg))
303+
if (rlang::is_missing(pkg) || length(pkg) != 1 || !is.character(pkg)) {
291304
rlang::abort("Please supply a single character value for the package name.")
305+
}
292306
invisible(NULL)
293307
}
294308

309+
295310
check_interface_val <- function(x) {
296311
exp_interf <- c("data.frame", "formula", "matrix")
297312
if (length(x) != 1 || !(x %in% exp_interf)) {

R/predict_class.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ predict_class.model_fit <- function(object, new_data, ...) {
1212
if (object$spec$mode != "classification")
1313
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
1414

15-
if (!any(names(object$spec$method$pred) == "class"))
16-
rlang::abort("No class prediction module defined for this model.")
15+
check_spec_pred_type(object, "class")
1716

1817
if (inherits(object$fit, "try-error")) {
1918
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_classprob.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
99
if (object$spec$mode != "classification")
1010
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
1111

12-
if (!any(names(object$spec$method$pred) == "prob"))
13-
rlang::abort("No class probability module defined for this model.")
12+
check_spec_pred_type(object, "prob")
13+
1414

1515
if (inherits(object$fit, "try-error")) {
1616
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_hazard.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
predict_hazard.model_fit <-
88
function(object, new_data, .time, ...) {
99

10-
if (is.null(object$spec$method$pred$hazard))
11-
rlang::abort("No hazard prediction method defined for this engine.")
10+
check_spec_pred_type(object, "hazard")
1211

1312
if (inherits(object$fit, "try-error")) {
1413
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_interval.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
#' @export
1111
predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) {
1212

13-
if (is.null(object$spec$method$pred$conf_int))
14-
rlang::abort("No confidence interval method defined for this engine.")
13+
check_spec_pred_type(object, "conf_int")
1514

1615
if (inherits(object$fit, "try-error")) {
1716
rlang::warn("Model fit failed; cannot make predictions.")
@@ -58,8 +57,7 @@ predict_confint <- function(object, ...)
5857
# @export
5958
predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) {
6059

61-
if (is.null(object$spec$method$pred$pred_int))
62-
rlang::abort("No prediction interval method defined for this engine.")
60+
check_spec_pred_type(object, "pred_int")
6361

6462
if (inherits(object$fit, "try-error")) {
6563
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_linear_pred.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
#' @export
77
predict_linear_pred.model_fit <- function(object, new_data, ...) {
88

9-
if (!any(names(object$spec$method$pred) == "linear_pred"))
10-
rlang::abort("No prediction module defined for this model.")
9+
check_spec_pred_type(object, "linear_pred")
1110

1211
if (inherits(object$fit, "try-error")) {
1312
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_numeric.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
1010
"Use `predict_class()` or `predict_classprob()` for ",
1111
"classification models."))
1212

13-
if (!any(names(object$spec$method$pred) == "numeric"))
14-
rlang::abort("No prediction module defined for this model.")
13+
check_spec_pred_type(object, "numeric")
1514

1615
if (inherits(object$fit, "try-error")) {
1716
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_quantile.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
predict_quantile.model_fit <-
1010
function(object, new_data, quantile = (1:9)/10, ...) {
1111

12-
if (is.null(object$spec$method$pred$quantile))
13-
rlang::abort("No quantile prediction method defined for this engine.")
12+
check_spec_pred_type(object, "quantile")
1413

1514
if (inherits(object$fit, "try-error")) {
1615
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_raw.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ predict_raw.model_fit <- function(object, new_data, opts = list(), ...) {
1313
c(object$spec$method$pred$raw$args, opts)
1414
}
1515

16-
if (!any(names(object$spec$method$pred) == "raw"))
17-
rlang::abort("No raw prediction module defined for this model.")
16+
check_spec_pred_type(object, "raw")
1817

1918
if (inherits(object$fit, "try-error")) {
2019
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_survival.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
predict_survival.model_fit <-
88
function(object, new_data, .time, ...) {
99

10-
if (is.null(object$spec$method$pred$survival))
11-
rlang::abort("No survival prediction method defined for this engine.")
10+
check_spec_pred_type(object, "survival")
1211

1312
if (inherits(object$fit, "try-error")) {
1413
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_time.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ predict_time.model_fit <- function(object, new_data, ...) {
1010
"Use `predict_class()` or `predict_classprob()` for ",
1111
"classification models."))
1212

13-
if (!any(names(object$spec$method$pred) == "time"))
14-
rlang::abort("No prediction module defined for this model.")
13+
check_spec_pred_type(object, "time")
1514

1615
if (inherits(object$fit, "try-error")) {
1716
rlang::warn("Model fit failed; cannot make predictions.")

R/svm_linear_data.R

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -123,27 +123,6 @@ set_pred(
123123
)
124124
)
125125
)
126-
set_pred(
127-
model = "svm_linear",
128-
eng = "LiblineaR",
129-
mode = "classification",
130-
type = "prob",
131-
value = list(
132-
pre = function(x, object) {
133-
rlang::abort(
134-
paste0("The LiblineaR engine does not support class probabilities ",
135-
"for any `svm` models.")
136-
)
137-
},
138-
post = NULL,
139-
func = c(fun = "predict"),
140-
args =
141-
list(
142-
object = quote(object$fit),
143-
newx = expr(as.matrix(new_data))
144-
)
145-
)
146-
)
147126
set_pred(
148127
model = "svm_linear",
149128
eng = "LiblineaR",

tests/testthat/test_svm_linear.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,12 @@ test_that('linear svm classification prediction: LiblineaR', {
280280

281281
expect_error(
282282
predict(cls_form, hpc_no_m[ind, -5], type = "prob"),
283-
"The LiblineaR engine does not support class probabilities"
283+
"No prob prediction method available for this model"
284284
)
285285

286286
expect_error(
287287
predict(cls_xy_form, hpc_no_m[ind, -5], type = "prob"),
288-
"The LiblineaR engine does not support class probabilities"
288+
"No prob prediction method available for this model"
289289
)
290290

291291
})

0 commit comments

Comments
 (0)