Skip to content

Commit 348f9fa

Browse files
authored
transition to cli conditions from rlang (#1154)
1 parent ffb7570 commit 348f9fa

File tree

5 files changed

+61
-39
lines changed

5 files changed

+61
-39
lines changed

R/predict_class.R

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,23 @@
99
#' @export predict_class.model_fit
1010
#' @export
1111
predict_class.model_fit <- function(object, new_data, ...) {
12-
if (object$spec$mode != "classification")
13-
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
12+
if (object$spec$mode != "classification") {
13+
cli::cli_abort("{.fun predict.model_fit} is for predicting factor outcomes.")
14+
}
1415

1516
check_spec_pred_type(object, "class")
1617

1718
if (inherits(object$fit, "try-error")) {
18-
rlang::warn("Model fit failed; cannot make predictions.")
19+
cli::cli_warn("Model fit failed; cannot make predictions.")
1920
return(NULL)
2021
}
2122

2223
new_data <- prepare_data(object, new_data)
2324

2425
# preprocess data
25-
if (!is.null(object$spec$method$pred$class$pre))
26+
if (!is.null(object$spec$method$pred$class$pre)) {
2627
new_data <- object$spec$method$pred$class$pre(new_data, object)
28+
}
2729

2830
# create prediction call
2931
pred_call <- make_pred_call(object$spec$method$pred$class)
@@ -56,6 +58,6 @@ predict_class.model_fit <- function(object, new_data, ...) {
5658
# @keywords internal
5759
# @rdname other_predict
5860
# @inheritParams predict.model_fit
59-
predict_class <- function(object, ...)
61+
predict_class <- function(object, ...) {
6062
UseMethod("predict_class")
61-
63+
}

R/predict_classprob.R

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,24 @@
55
#' @export predict_classprob.model_fit
66
#' @export
77
predict_classprob.model_fit <- function(object, new_data, ...) {
8-
if (object$spec$mode != "classification")
9-
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
8+
if (object$spec$mode != "classification") {
9+
cli::cli_abort("{.fun predict.model_fit()} is for predicting factor outcomes.")
10+
}
1011

1112
check_spec_pred_type(object, "prob")
1213
check_spec_levels(object)
1314

1415
if (inherits(object$fit, "try-error")) {
15-
rlang::warn("Model fit failed; cannot make predictions.")
16+
cli::cli_warn("Model fit failed; cannot make predictions.")
1617
return(NULL)
1718
}
1819

1920
new_data <- prepare_data(object, new_data)
2021

2122
# preprocess data
22-
if (!is.null(object$spec$method$pred$prob$pre))
23+
if (!is.null(object$spec$method$pred$prob$pre)) {
2324
new_data <- object$spec$method$pred$prob$pre(new_data, object)
25+
}
2426

2527
# create prediction call
2628
pred_call <- make_pred_call(object$spec$method$pred$prob)
@@ -33,11 +35,13 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
3335
}
3436

3537
# check and sort names
36-
if (!is.data.frame(res) & !inherits(res, "tbl_spark"))
37-
rlang::abort("The was a problem with the probability predictions.")
38+
if (!is.data.frame(res) & !inherits(res, "tbl_spark")) {
39+
cli::cli_abort("The was a problem with the probability predictions.")
40+
}
3841

39-
if (!is_tibble(res) & !inherits(res, "tbl_spark"))
42+
if (!is_tibble(res) & !inherits(res, "tbl_spark")) {
4043
res <- as_tibble(res)
44+
}
4145

4246
res
4347
}
@@ -46,18 +50,19 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
4650
# @keywords internal
4751
# @rdname other_predict
4852
# @inheritParams predict.model_fit
49-
predict_classprob <- function(object, ...)
53+
predict_classprob <- function(object, ...) {
5054
UseMethod("predict_classprob")
55+
}
5156

5257
check_spec_levels <- function(spec) {
5358
if ("class" %in% spec$lvl) {
54-
rlang::abort(
55-
glue::glue(
56-
"The outcome variable `{spec$preproc$y_var}` has a level called 'class'. ",
57-
"This value is reserved for parsnip's classification internals; please ",
58-
"change the levels, perhaps with `forcats::fct_relevel()`."
59-
),
60-
call = NULL
59+
cli::cli_abort(
60+
c(
61+
"The outcome variable {.var {spec$preproc$y_var}} has a level called {.val class}.",
62+
"i" = "This value is reserved for parsnip's classification internals; please
63+
change the levels, perhaps with {.fn forcats::fct_relevel}.",
64+
call = NULL
65+
)
6166
)
6267
}
6368
}

R/predict_numeric.R

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,35 @@
55
#' @export predict_numeric.model_fit
66
#' @export
77
predict_numeric.model_fit <- function(object, new_data, ...) {
8-
if (object$spec$mode != "regression")
9-
rlang::abort(glue::glue("`predict_numeric()` is for predicting numeric outcomes. ",
10-
"Use `predict_class()` or `predict_classprob()` for ",
11-
"classification models."))
8+
if (object$spec$mode != "regression") {
9+
cli::cli_abort(
10+
c(
11+
"{.fun predict_numeric} is for predicting numeric outcomes.",
12+
"i" = "Use {.fun predict_class} or {.fun predict_classprob} for
13+
classification models."
14+
)
15+
)
16+
}
1217

1318
check_spec_pred_type(object, "numeric")
1419

1520
if (inherits(object$fit, "try-error")) {
16-
rlang::warn("Model fit failed; cannot make predictions.")
21+
cli::cli_warn("Model fit failed; cannot make predictions.")
1722
return(NULL)
1823
}
1924

2025
new_data <- prepare_data(object, new_data)
2126

2227
# preprocess data
23-
if (!is.null(object$spec$method$pred$numeric$pre))
28+
if (!is.null(object$spec$method$pred$numeric$pre)) {
2429
new_data <- object$spec$method$pred$numeric$pre(new_data, object)
30+
}
2531

2632
# create prediction call
2733
pred_call <- make_pred_call(object$spec$method$pred$numeric)
2834

2935
res <- eval_tidy(pred_call)
30-
36+
3137
# post-process the predictions
3238
if (!is.null(object$spec$method$pred$numeric$post)) {
3339
res <- object$spec$method$pred$numeric$post(res, object)
@@ -36,8 +42,9 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
3642
if (is.vector(res)) {
3743
res <- unname(res)
3844
} else {
39-
if (!inherits(res, "tbl_spark"))
45+
if (!inherits(res, "tbl_spark")) {
4046
res <- as.data.frame(res)
47+
}
4148
}
4249
res
4350
}
@@ -47,5 +54,6 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
4754
#' @keywords internal
4855
#' @rdname other_predict
4956
#' @inheritParams predict_numeric.model_fit
50-
predict_numeric <- function(object, ...)
57+
predict_numeric <- function(object, ...) {
5158
UseMethod("predict_numeric")
59+
}

R/predict_time.R

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,35 @@
55
#' @export predict_time.model_fit
66
#' @export
77
predict_time.model_fit <- function(object, new_data, ...) {
8-
if (object$spec$mode != "censored regression")
9-
rlang::abort(glue::glue("`predict_time()` is for predicting time outcomes. ",
10-
"Use `predict_class()` or `predict_classprob()` for ",
11-
"classification models."))
8+
if (object$spec$mode != "censored regression") {
9+
cli::cli_abort(
10+
c(
11+
"{.fun predict_time} is for predicting time outcomes.",
12+
"i" = "Use {.fun predict_class} or {.fun predict_classprob} for
13+
classification models."
14+
)
15+
)
16+
}
1217

1318
check_spec_pred_type(object, "time")
1419

1520
if (inherits(object$fit, "try-error")) {
16-
rlang::warn("Model fit failed; cannot make predictions.")
21+
cli::cli_warn("Model fit failed; cannot make predictions.")
1722
return(NULL)
1823
}
1924

2025
new_data <- prepare_data(object, new_data)
2126

2227
# preprocess data
23-
if (!is.null(object$spec$method$pred$time$pre))
28+
if (!is.null(object$spec$method$pred$time$pre)) {
2429
new_data <- object$spec$method$pred$time$pre(new_data, object)
30+
}
2531

2632
# create prediction call
2733
pred_call <- make_pred_call(object$spec$method$pred$time)
2834

2935
res <- eval_tidy(pred_call)
30-
36+
3137
# post-process the predictions
3238
if (!is.null(object$spec$method$pred$time$post)) {
3339
res <- object$spec$method$pred$time$post(res, object)
@@ -45,5 +51,6 @@ predict_time.model_fit <- function(object, new_data, ...) {
4551
#' @keywords internal
4652
#' @rdname other_predict
4753
#' @inheritParams predict_time.model_fit
48-
predict_time <- function(object, ...)
54+
predict_time <- function(object, ...) {
4955
UseMethod("predict_time")
56+
}

tests/testthat/test-predict_formats.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ test_that('predict(type = "prob") with level "class" (see #720)', {
7676
)
7777

7878
expect_error(
79-
regexp = "variable `boop` has a level called 'class'",
79+
regexp = 'variable `boop` has a level called "class"',
8080
predict(mod, type = "prob", new_data = x)
8181
)
8282
})

0 commit comments

Comments
 (0)