Skip to content

Commit 2347eeb

Browse files
committed
make predict work with sparse tibbles
1 parent 69cd225 commit 2347eeb

File tree

4 files changed

+44
-1
lines changed

4 files changed

+44
-1
lines changed

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
* `fit()` and `fit_xy()` can now take sparse tibbles as data values (#1165).
66

7-
* `predict()` can now take dgCMatrix input for `new_data` argument, and error informatively when model doesn't support it (#1167).
7+
* `predict()` can now take dgCMatrix and sparse tibble input for `new_data` argument, and error informatively when model doesn't support it (#1167).
88

99

1010
* Transitioned package errors and warnings to use cli (#1147 and #1148 by

R/sparsevctrs.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ is_sparse_tibble <- function(x) {
2323

2424
materialize_sparse_tibble <- function(x, object, input) {
2525
if ((!allow_sparse(object)) && is_sparse_tibble(x)) {
26+
if (inherits(object, "model_fit")) {
27+
object <- object$spec
28+
}
29+
2630
cli::cli_warn(
2731
"{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with
2832
engine {.code {object$engine}} doesn't accept that. Converting to

tests/testthat/_snaps/sparsevctrs.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@
2222
Error in `to_sparse_data_frame()`:
2323
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
2424

25+
# sparse tibble can be passed to `predict()
26+
27+
Code
28+
preds <- predict(lm_fit, sparse_mtcars)
29+
Condition
30+
Warning:
31+
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
32+
2533
# sparse matrices can be passed to `predict()
2634

2735
Code

tests/testthat/test-sparsevctrs.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,37 @@ test_that("sparse matrices can be passed to `fit_xy()", {
6767
)
6868
})
6969

70+
test_that("sparse tibble can be passed to `predict()", {
71+
skip_if_not_installed("ranger")
72+
73+
hotel_data <- sparse_hotel_rates()
74+
hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data)
75+
76+
spec <- rand_forest(trees = 10) %>%
77+
set_mode("regression") %>%
78+
set_engine("ranger")
79+
80+
tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
81+
82+
expect_no_error(
83+
predict(tree_fit, hotel_data)
84+
)
85+
86+
spec <- linear_reg() %>%
87+
set_mode("regression") %>%
88+
set_engine("lm")
89+
90+
lm_fit <- fit(spec, mpg ~ ., data = mtcars)
91+
92+
sparse_mtcars <- mtcars %>%
93+
sparsevctrs::coerce_to_sparse_matrix() %>%
94+
sparsevctrs::coerce_to_sparse_tibble()
95+
96+
expect_snapshot(
97+
preds <- predict(lm_fit, sparse_mtcars)
98+
)
99+
})
100+
70101
test_that("sparse matrices can be passed to `predict()", {
71102
skip_if_not_installed("ranger")
72103

0 commit comments

Comments
 (0)