Skip to content

Commit e0ea0be

Browse files
authored
handle attributes on outcome columns (#1062)
1 parent 8ccf1be commit e0ea0be

File tree

5 files changed

+42
-9
lines changed

5 files changed

+42
-9
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
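* Tightened logic for outcome checking: atomic outcomes are now detected with `is.atomic()` rather than `is.vector()`/`is.factor()`. This resolves issues—some errors and some silent failures—when atomic outcome variables carry an attribute, such as a `label` (#1060, #1061).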
4+
35
* `rpart_train()` has been deprecated in favor of using `decision_tree()` with the `"rpart"` engine or `rpart::rpart()` directly (#1044).
46

57
* Fixed bug in fitting some model types with the `"spark"` engine (#1045).

R/convert_data.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@
234234
if (is.matrix(y)) {
235235
y <- as.data.frame(y)
236236
} else {
237-
if (is.vector(y) | is.factor(y)) {
237+
if (is.atomic(y)) {
238238
y <- data.frame(y)
239239
names(y) <- y_name
240240
}
@@ -328,7 +328,7 @@ make_formula <- function(x, y, short = TRUE) {
328328

329329

330330
will_make_matrix <- function(y) {
331-
if (is.matrix(y) | is.vector(y) | is.factor(y))
331+
if (is.matrix(y) | is.atomic(y))
332332
return(FALSE)
333333
cls <- unique(unlist(lapply(y, class)))
334334
if (length(cls) > 1)

R/fit.R

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ fit_xy.model_spec <-
267267
}
268268
y_var <- colnames(y)
269269

270-
if (object$engine != "spark" & NCOL(y) == 1 & !(is.vector(y) | is.factor(y))) {
270+
if (object$engine != "spark" & NCOL(y) == 1 & !(is.atomic(y))) {
271271
if (is.matrix(y)) {
272272
y <- y[, 1]
273273
} else {
@@ -411,10 +411,8 @@ check_xy_interface <- function(x, y, cl, model) {
411411
inher(x, c("data.frame", "matrix"), cl)
412412
}
413413

414-
# `y` can be a vector (which is not a class), or a factor or
415-
# Surv object (which are not vectors)
416-
if (!is.null(y) && !is.vector(y))
417-
inher(y, c("data.frame", "matrix", "factor", "Surv"), cl)
414+
if (!is.null(y) && !is.atomic(y))
415+
inher(y, c("data.frame", "matrix"), cl)
418416

419417
# rule out spark data sets that don't use the formula interface
420418
if (inherits(x, "tbl_spark") | inherits(y, "tbl_spark"))

R/fit_helpers.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
121121
elapsed <- list(elapsed = NA_real_)
122122
}
123123

124-
if (is.vector(env$y)) {
124+
if (is.atomic(env$y)) {
125125
y_name <- character(0)
126126
} else {
127127
y_name <- colnames(env$y)
@@ -199,7 +199,7 @@ xy_form <- function(object, env, control, ...) {
199199
if (!is.null(env$y_var)) {
200200
data_obj$y_var <- env$y_var
201201
} else {
202-
if (is.vector(env$y)) {
202+
if (is.atomic(env$y)) {
203203
data_obj$y_var <- character(0)
204204
}
205205

tests/testthat/test_fit_interfaces.R

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,36 @@ test_that('No loaded engines', {
121121
expect_snapshot_error({poisson_reg() %>% fit(mpg ~., data = mtcars)})
122122
expect_snapshot_error({cubist_rules(engine = "Cubist") %>% fit(mpg ~., data = mtcars)})
123123
})
124+
125+
test_that("fit_xy() can handle attributes on a data.frame outcome (#1060)", {
126+
lr <- linear_reg()
127+
x <- data.frame(x = 1:5)
128+
y <- c(2:5, 5)
129+
130+
expect_silent(res <-
131+
fit_xy(lr, x = x, y = data.frame(y = structure(y, label = "hi")))
132+
)
133+
expect_equal(res[["fit"]], fit_xy(lr, x, y)[["fit"]], ignore_attr = "label")
134+
})
135+
136+
test_that("fit_xy() can handle attributes on an atomic outcome (#1061)", {
137+
lr <- linear_reg()
138+
x <- data.frame(x = 1:5)
139+
y <- c(2:5, 5)
140+
141+
expect_silent(res <- fit_xy(lr, x = x, y = structure(y, label = "hi")))
142+
expect_equal(res[["fit"]], fit_xy(lr, x, y)[["fit"]], ignore_attr = "label")
143+
})
144+
145+
test_that("fit() can handle attributes on a vector outcome", {
146+
lr <- linear_reg()
147+
dat <- data.frame(x = 1:5, y = c(2:5, 5))
148+
dat_attr <- data.frame(x = 1:5, y = structure(c(2:5, 5), label = "hi"))
149+
150+
expect_silent(res <- fit(lr, y ~ x, dat_attr))
151+
expect_equal(
152+
res[["fit"]],
153+
fit(lr, y ~ x, dat)[["fit"]],
154+
ignore_attr = TRUE
155+
)
156+
})

0 commit comments

Comments
 (0)