Skip to content

Commit f66a8f9

Browse files
Merge pull request #1165 from tidymodels/sparse-tibbles-fit
Allow sparse tibbles in `fit()` and `fit_xy()`
2 parents eeaf82f + 60edee7 commit f66a8f9

File tree

8 files changed

+114
-4
lines changed

8 files changed

+114
-4
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
* `fit_xy()` can now take dgCMatrix input for `x` argument (#1121).
44

5+
* `fit()` and `fit_xy()` can now take sparse tibbles as data values (#1165).
6+
57
* Transitioned package errors and warnings to use cli (#1147 and #1148 by
68
@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
79
#1161, #1081).

R/arguments.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ make_xy_call <- function(object, target, env) {
264264
none = rlang::expr(x),
265265
data.frame = rlang::expr(maybe_data_frame(x)),
266266
matrix = rlang::expr(maybe_matrix(x)),
267+
dgCMatrix = rlang::expr(maybe_sparse_matrix(x)),
267268
cli::cli_abort("Invalid data type target: {target}.")
268269
)
269270
if (uses_weights) {

R/convert_data.R

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@
4141
indicators = "traditional",
4242
composition = "data.frame",
4343
remove_intercept = TRUE) {
44-
if (!(composition %in% c("data.frame", "matrix"))) {
44+
if (!(composition %in% c("data.frame", "matrix", "dgCMatrix"))) {
4545
cli::cli_abort(
46-
"{.arg composition} should be either {.val data.frame} or {.val matrix}."
46+
"{.arg composition} should be either {.val data.frame}, {.val matrix}, or
47+
{.val dgCMatrix}."
4748
)
4849
}
4950

@@ -122,6 +123,18 @@
122123
xlevels = .getXlevels(mod_terms, mod_frame),
123124
options = options
124125
)
126+
} else if (composition == "dgCMatrix") {
127+
x <- sparsevctrs::coerce_to_sparse_matrix(data)
128+
res <-
129+
list(
130+
x = x,
131+
y = y,
132+
weights = w,
133+
offset = offset,
134+
terms = mod_terms,
135+
xlevels = .getXlevels(mod_terms, mod_frame),
136+
options = options
137+
)
125138
} else {
126139
# Since a matrix is requested, try to convert y but check
127140
# to see if it is possible
@@ -389,7 +402,11 @@ maybe_matrix <- function(x) {
389402
}
390403

391404
maybe_sparse_matrix <- function(x) {
392-
if (any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))) {
405+
if (methods::is(x, "sparseMatrix")) {
406+
return(x)
407+
}
408+
409+
if (is_sparse_tibble(x)) {
393410
res <- sparsevctrs::coerce_to_sparse_matrix(x)
394411
} else {
395412
res <- as.matrix(x)

R/fit.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ fit.model_spec <-
174174
eval_env$formula <- formula
175175
eval_env$weights <- wts
176176

177+
data <- materialize_sparse_tibble(data, object, "data")
178+
177179
fit_interface <-
178180
check_interface(eval_env$formula, eval_env$data, cl, object)
179181

R/fit_helpers.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ form_xy <- function(object, control, env,
129129

130130
indicators <- encoding_info %>% dplyr::pull(predictor_indicators)
131131
remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept)
132+
allow_sparse_x <- encoding_info %>% dplyr::pull(allow_sparse_x)
133+
134+
if (allow_sparse_x && is_sparse_tibble(env$data)) {
135+
target <- "dgCMatrix"
136+
}
132137

133138
data_obj <- .convert_form_to_xy_fit(
134139
formula = env$formula,

R/sparsevctrs.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,27 @@ to_sparse_data_frame <- function(x, object) {
77
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
88
engine {.code {object$engine}} doesn't accept that.")
99
}
10+
} else if (is.data.frame(x)) {
11+
x <- materialize_sparse_tibble(x, object, "x")
12+
}
13+
x
14+
}
15+
16+
# Report whether any element (column) of `x` is a sparse vector.
#
# `x` is iterated element-wise with `vapply()`, so for a data frame or tibble
# each column is checked in turn. Returns a single logical: `TRUE` as soon as
# at least one column satisfies `sparsevctrs::is_sparse_vector()`, `FALSE`
# otherwise (including when `x` has no columns).
is_sparse_tibble <- function(x) {
  sparse_flags <- vapply(
    x,
    sparsevctrs::is_sparse_vector,
    FUN.VALUE = logical(1)
  )
  any(sparse_flags)
}
19+
20+
# Convert a sparse tibble to a regular (dense) one when the model cannot use
# sparse data.
#
# Arguments:
#   x      A data frame / tibble that may contain sparse vector columns.
#   object The model specification; `allow_sparse(object)` decides whether
#          sparse input can be kept as is. `class(object)[1]` and
#          `object$engine` are only used to build the warning message.
#   input  Name of the user-facing argument that supplied `x` (e.g. "data" or
#          "x"), interpolated into the warning.
#
# Returns `x`. If the model does not allow sparse data and `x` has at least one
# sparse column, a warning is raised and every column is materialized first;
# otherwise `x` is returned untouched.
materialize_sparse_tibble <- function(x, object, input) {
  if ((!allow_sparse(object)) && is_sparse_tibble(x)) {
    cli::cli_warn(
      "{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with
      engine {.code {object$engine}} doesn't accept that. Converting to
      non-sparse."
    )
    # Iterate over *all* column positions. `seq_along(ncol(x))` would only
    # ever yield 1 because `ncol(x)` is a length-one vector, leaving every
    # column after the first one sparse.
    for (i in seq_len(ncol(x))) {
      # materialize with []
      x[[i]] <- x[[i]][]
    }
  }
  x
}

tests/testthat/_snaps/sparsevctrs.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# sparse tibble can be passed to `fit()
2+
3+
Code
4+
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
5+
Condition
6+
Warning:
7+
`data` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
8+
9+
# sparse tibble can be passed to `fit_xy()
10+
11+
Code
12+
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
13+
Condition
14+
Warning:
15+
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
16+
117
# sparse matrices can be passed to `fit_xy()
218

319
Code

tests/testthat/test-sparsevctrs.R

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,49 @@
1+
test_that("sparse tibble can be passed to `fit()", {
  skip_if_not_installed("xgboost")

  # Sparse-tibble version of the hotel rates data, shared by both checks below.
  hotel_data <- sparsevctrs::coerce_to_sparse_tibble(sparse_hotel_rates())

  # Specification whose fit on the sparse tibble must not raise an error.
  xgb_spec <- boost_tree() %>%
    set_mode("regression") %>%
    set_engine("xgboost")

  # Specification whose fit is recorded as a snapshot. It must be called
  # `spec` because the snapshot stores the code verbatim.
  spec <- linear_reg() %>%
    set_mode("regression") %>%
    set_engine("lm")

  expect_no_error(
    xgb_fit <- fit(xgb_spec, avg_price_per_room ~ ., data = hotel_data)
  )

  expect_snapshot(
    lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
  )
})
23+
24+
test_that("sparse tibble can be passed to `fit_xy()", {
  skip_if_not_installed("xgboost")

  # Sparse-tibble version of the hotel rates data; column 1 is used as the
  # outcome and the remaining columns as predictors.
  hotel_data <- sparsevctrs::coerce_to_sparse_tibble(sparse_hotel_rates())

  # Specification whose fit on the sparse tibble must not raise an error.
  xgb_spec <- boost_tree() %>%
    set_mode("regression") %>%
    set_engine("xgboost")

  # Specification whose fit is recorded as a snapshot. It must be called
  # `spec` because the snapshot stores the code verbatim.
  spec <- linear_reg() %>%
    set_mode("regression") %>%
    set_engine("lm")

  expect_no_error(
    xgb_fit <- fit_xy(xgb_spec, x = hotel_data[, -1], y = hotel_data[, 1])
  )

  expect_snapshot(
    lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
  )
})
46+
147
test_that("sparse matrices can be passed to `fit_xy()", {
248
skip_if_not_installed("xgboost")
349

@@ -66,7 +112,7 @@ test_that("maybe_sparse_matrix() is used correctly", {
66112

67113
local_mocked_bindings(
68114
maybe_sparse_matrix = function(x) {
69-
if (any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))) {
115+
if (is_sparse_tibble(x)) {
70116
stop("sparse vectors detected")
71117
} else {
72118
stop("no sparse vectors detected")

0 commit comments

Comments
 (0)