Skip to content

Commit 2ead20c

Browse files
Merge pull request #919 from tidymodels/check-outcome
More verbose check_outcome()
2 parents 00f3cdb + 3f924d8 commit 2ead20c

File tree

5 files changed

+87
-5
lines changed

5 files changed

+87
-5
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
* `logistic_reg()` will now warn at `fit()` when the outcome has more than two levels (#545).
2222

23+
* Functions now indicate what class the outcome was if the outcome is the wrong class (#887).
2324

2425
# parsnip 1.0.4
2526

R/misc.R

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,14 +336,22 @@ check_outcome <- function(y, spec) {
336336
if (spec$mode == "regression") {
337337
outcome_is_numeric <- if (is.atomic(y)) {is.numeric(y)} else {all(map_lgl(y, is.numeric))}
338338
if (!outcome_is_numeric) {
339-
rlang::abort("For a regression model, the outcome should be numeric.")
339+
cls <- class(y)[[1]]
340+
abort(paste0(
341+
"For a regression model, the outcome should be `numeric`, ",
342+
"not a `", cls, "`."
343+
))
340344
}
341345
}
342346

343347
if (spec$mode == "classification") {
344348
outcome_is_factor <- if (is.atomic(y)) {is.factor(y)} else {all(map_lgl(y, is.factor))}
345349
if (!outcome_is_factor) {
346-
rlang::abort("For a classification model, the outcome should be a factor.")
350+
cls <- class(y)[[1]]
351+
abort(paste0(
352+
"For a classification model, the outcome should be a `factor`, ",
353+
"not a `", cls, "`."
354+
))
347355
}
348356

349357
if (inherits(spec, "logistic_reg") && is.atomic(y) && length(levels(y)) > 2) {
@@ -361,7 +369,11 @@ check_outcome <- function(y, spec) {
361369
if (spec$mode == "censored regression") {
362370
outcome_is_surv <- inherits(y, "Surv")
363371
if (!outcome_is_surv) {
364-
rlang::abort("For a censored regression model, the outcome should be a `Surv` object.")
372+
cls <- class(y)[[1]]
373+
abort(paste0(
374+
"For a censored regression model, the outcome should be a `Surv` object, ",
375+
"not a `", cls, "`."
376+
))
365377
}
366378
}
367379

tests/testthat/_snaps/misc.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,27 @@
132132
Error in `fn()`:
133133
! Please use `new_data` instead of `newdata`.
134134

135+
# check_outcome works as expected
136+
137+
Code
138+
check_outcome(factor(1:2), reg_spec)
139+
Condition
140+
Error in `check_outcome()`:
141+
! For a regression model, the outcome should be `numeric`, not a `factor`.
142+
143+
---
144+
145+
Code
146+
check_outcome(1:2, class_spec)
147+
Condition
148+
Error in `check_outcome()`:
149+
! For a classification model, the outcome should be a `factor`, not a `integer`.
150+
151+
---
152+
153+
Code
154+
check_outcome(1:2, cens_spec)
155+
Condition
156+
Error in `check_outcome()`:
157+
! For a censored regression model, the outcome should be a `Surv` object, not a `integer`.
158+

tests/testthat/test_misc.R

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,51 @@ test_that('set_engine works as a generic', {
185185
test_that('check_for_newdata points out correct context', {
186186
fn <- function(...) {check_for_newdata(...); invisible()}
187187
expect_snapshot(error = TRUE,
188-
fn(newdata = "boop!")
188+
fn(newdata = "boop!")
189+
)
190+
})
191+
192+
test_that('check_outcome works as expected', {
193+
reg_spec <- linear_reg()
194+
195+
expect_no_error(
196+
check_outcome(1:2, reg_spec)
197+
)
198+
199+
expect_no_error(
200+
check_outcome(mtcars, reg_spec)
201+
)
202+
203+
expect_snapshot(
204+
error = TRUE,
205+
check_outcome(factor(1:2), reg_spec)
206+
)
207+
208+
class_spec <- logistic_reg()
209+
210+
expect_no_error(
211+
check_outcome(factor(1:2), class_spec)
212+
)
213+
214+
expect_no_error(
215+
check_outcome(lapply(mtcars, as.factor), class_spec)
216+
)
217+
218+
expect_snapshot(
219+
error = TRUE,
220+
check_outcome(1:2, class_spec)
221+
)
222+
223+
# Fake specification to avoid having to load {censored}
224+
cens_spec <- logistic_reg()
225+
cens_spec$mode <- "censored regression"
226+
227+
expect_no_error(
228+
check_outcome(survival::Surv(1, 1), cens_spec)
229+
)
230+
231+
expect_snapshot(
232+
error = TRUE,
233+
check_outcome(1:2, cens_spec)
189234
)
190235
})

tests/testthat/test_nearest_neighbor_kknn.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ test_that('kknn execution', {
2424
x = hpc[, num_pred],
2525
y = hpc$input_fields
2626
),
27-
regexp = "outcome should be a factor"
27+
regexp = "outcome should be a `factor`"
2828
)
2929

3030
# nominal

0 commit comments

Comments
 (0)