Skip to content

Commit a9aadfb

Browse files
Merge pull request #1167 from tidymodels/sparse-matrix-predict
Make sure sparse matrices can be used with `predict()`
2 parents bc27e2b + f606403 commit a9aadfb

File tree

6 files changed

+92
-0
lines changed

6 files changed

+92
-0
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
* `fit()` and `fit_xy()` can now take sparse tibbles as data values (#1165).
66

7+
* `predict()` can now take dgCMatrix and sparse tibble input for `new_data` argument, and error informatively when model doesn't support it (#1167).
8+
79
* Transitioned package errors and warnings to use cli (#1147 and #1148 by
810
@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
911
#1161, #1081).

R/fit.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,10 @@ check_xy_interface <- function(x, y, cl, model) {
444444
}
445445

446446
allow_sparse <- function(x) {
447+
if (inherits(x, "model_fit")) {
448+
x <- x$spec
449+
}
450+
447451
res <- get_from_env(paste0(class(x)[1], "_encoding"))
448452
all(res$allow_sparse_x[res$engine == x$engine])
449453
}

R/predict.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
160160
}
161161
check_pred_type_dots(object, type, ...)
162162

163+
new_data <- to_sparse_data_frame(new_data, object)
164+
163165
res <- switch(
164166
type,
165167
numeric = predict_numeric(object = object, new_data = new_data, ...),

R/sparsevctrs.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ to_sparse_data_frame <- function(x, object) {
33
if (allow_sparse(object)) {
44
x <- sparsevctrs::coerce_to_sparse_data_frame(x)
55
} else {
6+
if (inherits(object, "model_fit")) {
7+
object <- object$spec
8+
}
9+
610
cli::cli_abort(
711
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
812
engine {.code {object$engine}} doesn't accept that.")
@@ -19,6 +23,10 @@ is_sparse_tibble <- function(x) {
1923

2024
materialize_sparse_tibble <- function(x, object, input) {
2125
if (is_sparse_tibble(x) && (!allow_sparse(object))) {
26+
if (inherits(object, "model_fit")) {
27+
object <- object$spec
28+
}
29+
2230
cli::cli_warn(
2331
"{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with
2432
engine {.code {object$engine}} doesn't accept that. Converting to

tests/testthat/_snaps/sparsevctrs.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,22 @@
2222
Error in `to_sparse_data_frame()`:
2323
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
2424

25+
# sparse tibble can be passed to `predict()
26+
27+
Code
28+
preds <- predict(lm_fit, sparse_mtcars)
29+
Condition
30+
Warning:
31+
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
32+
33+
# sparse matrices can be passed to `predict()
34+
35+
Code
36+
predict(lm_fit, sparse_mtcars)
37+
Condition
38+
Error in `to_sparse_data_frame()`:
39+
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
40+
2541
# to_sparse_data_frame() is used correctly
2642

2743
Code

tests/testthat/test-sparsevctrs.R

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,66 @@ test_that("sparse matrices can be passed to `fit_xy()", {
6767
)
6868
})
6969

70+
test_that("sparse tibble can be passed to `predict()", {
71+
skip_if_not_installed("ranger")
72+
73+
hotel_data <- sparse_hotel_rates()
74+
hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data)
75+
76+
spec <- rand_forest(trees = 10) %>%
77+
set_mode("regression") %>%
78+
set_engine("ranger")
79+
80+
tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
81+
82+
expect_no_error(
83+
predict(tree_fit, hotel_data)
84+
)
85+
86+
spec <- linear_reg() %>%
87+
set_mode("regression") %>%
88+
set_engine("lm")
89+
90+
lm_fit <- fit(spec, mpg ~ ., data = mtcars)
91+
92+
sparse_mtcars <- mtcars %>%
93+
sparsevctrs::coerce_to_sparse_matrix() %>%
94+
sparsevctrs::coerce_to_sparse_tibble()
95+
96+
expect_snapshot(
97+
preds <- predict(lm_fit, sparse_mtcars)
98+
)
99+
})
100+
101+
test_that("sparse matrices can be passed to `predict()", {
102+
skip_if_not_installed("ranger")
103+
104+
hotel_data <- sparse_hotel_rates()
105+
106+
spec <- rand_forest(trees = 10) %>%
107+
set_mode("regression") %>%
108+
set_engine("ranger")
109+
110+
tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
111+
112+
expect_no_error(
113+
predict(tree_fit, hotel_data)
114+
)
115+
116+
spec <- linear_reg() %>%
117+
set_mode("regression") %>%
118+
set_engine("lm")
119+
120+
lm_fit <- fit(spec, mpg ~ ., data = mtcars)
121+
122+
sparse_mtcars <- sparsevctrs::coerce_to_sparse_matrix(mtcars)
123+
124+
expect_snapshot(
125+
error = TRUE,
126+
predict(lm_fit, sparse_mtcars)
127+
)
128+
})
129+
70130
test_that("to_sparse_data_frame() is used correctly", {
71131
skip_if_not_installed("xgboost")
72132

0 commit comments

Comments
 (0)