Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 46 additions & 24 deletions r/tests/testthat/helper-expectation.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,30 +70,52 @@ verify_output <- function(...) {
testthat::verify_output(...)
}

expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its start
tbl, # A tbl/df as reference, will make RB/Table with
skip_record_batch = NULL, # Msg, if should skip RB test
skip_table = NULL, # Msg, if should skip Table test
#' @param expr A dplyr pipeline with `input` as its start
#' @param tbl A tbl/df used as the reference; a RecordBatch and a Table are created from it
#' @param skip_record_batch string skip message, if should skip RB test
#' @param skip_table string skip message, if should skip Table test
#' @param warning string expected warning from the RecordBatch and Table paths,
#' passed to `expect_warning()`. Special values:
#' * `NA` (the default) for ensuring no warning message
#' * `TRUE` is a shorthand for checking for the
#' "not supported in Arrow; pulling data into R" message.
#' @param ... additional arguments, passed to `expect_equivalent()`
expect_dplyr_equal <- function(expr,
tbl,
skip_record_batch = NULL,
skip_table = NULL,
warning = NA,
...) {
expr <- rlang::enquo(expr)
expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl)))

if (isTRUE(warning)) {
# Special-case the simple warning:
warning <- "not supported in Arrow; pulling data into R"
}

skip_msg <- NULL

if (is.null(skip_record_batch)) {
via_batch <- rlang::eval_tidy(
expr,
rlang::new_data_mask(rlang::env(input = record_batch(tbl)))
expect_warning(
via_batch <- rlang::eval_tidy(
expr,
rlang::new_data_mask(rlang::env(input = record_batch(tbl)))
),
warning
)
expect_equivalent(via_batch, expected, ...)
} else {
skip_msg <- c(skip_msg, skip_record_batch)
}

if (is.null(skip_table)) {
via_table <- rlang::eval_tidy(
expr,
rlang::new_data_mask(rlang::env(input = Table$create(tbl)))
expect_warning(
via_table <- rlang::eval_tidy(
expr,
rlang::new_data_mask(rlang::env(input = Table$create(tbl)))
),
warning
)
expect_equivalent(via_table, expected, ...)
} else {
Expand All @@ -110,7 +132,7 @@ expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its star
...) {
# ensure we have supplied tbl
force(tbl)

expr <- rlang::enquo(expr)
msg <- tryCatch(
rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))),
Expand All @@ -126,7 +148,7 @@ expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its star
# but what we really care about is the `x` block
# so (temporarily) let's pull those blocks out when we find them
pattern <- i18ize_error_messages()

if (grepl(pattern, msg)) {
msg <- sub(paste0("^.*(", pattern, ").*$"), "\\1", msg)
}
Expand Down Expand Up @@ -179,7 +201,7 @@ expect_vector_equal <- function(expr, # A vectorized R expression containing `in
if (is.null(skip_chunked_array)) {
# split input vector into two to exercise ChunkedArray with >1 chunk
split_vector <- split_vector_as_list(vec)

via_chunked <- rlang::eval_tidy(
expr,
rlang::new_data_mask(rlang::env(input = ChunkedArray$create(split_vector[[1]], split_vector[[2]])))
Expand All @@ -199,29 +221,29 @@ expect_vector_error <- function(expr, # A vectorized R expression containing `in
skip_array = NULL, # Msg, if should skip Array test
skip_chunked_array = NULL, # Msg, if should skip ChunkedArray test
...) {

expr <- rlang::enquo(expr)

msg <- tryCatch(
rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = vec))),
error = function (e) {
msg <- conditionMessage(e)

pattern <- i18ize_error_messages()

if (grepl(pattern, msg)) {
msg <- sub(paste0("^.*(", pattern, ").*$"), "\\1", msg)
}
msg
}
)

expect_true(identical(typeof(msg), "character"), label = "vector errored")

skip_msg <- NULL

if (is.null(skip_array)) {

expect_error(
rlang::eval_tidy(
expr,
Expand All @@ -233,11 +255,11 @@ expect_vector_error <- function(expr, # A vectorized R expression containing `in
} else {
skip_msg <- c(skip_msg, skip_array)
}

if (is.null(skip_chunked_array)) {
# split input vector into two to exercise ChunkedArray with >1 chunk
split_vector <- split_vector_as_list(vec)

expect_error(
rlang::eval_tidy(
expr,
Expand All @@ -249,7 +271,7 @@ expect_vector_error <- function(expr, # A vectorized R expression containing `in
} else {
skip_msg <- c(skip_msg, skip_chunked_array)
}

# Report every collected skip reason as a single skip message, one reason per
# line. The argument must be spelled `collapse`: any other name is swallowed by
# paste()'s `...` and pasted onto each element as data, which leaves a
# character vector instead of one joined string.
if (!is.null(skip_msg)) {
  skip(paste(skip_msg, collapse = "\n"))
}
Expand Down
45 changes: 26 additions & 19 deletions r/tests/testthat/test-compute-aggregate.R
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,15 @@ test_that("Edge cases", {
for (type in c(int32(), float64(), bool())) {
expect_equal(as.vector(sum(a$cast(type), na.rm = TRUE)), sum(NA, na.rm = TRUE))
expect_equal(as.vector(mean(a$cast(type), na.rm = TRUE)), mean(NA, na.rm = TRUE))
expect_equal(as.vector(min(a$cast(type), na.rm = TRUE)), min(NA, na.rm = TRUE))
expect_equal(as.vector(max(a$cast(type), na.rm = TRUE)), max(NA, na.rm = TRUE))
expect_equal(
as.vector(min(a$cast(type), na.rm = TRUE)),
# Suppress the base R warning about no non-missing arguments
suppressWarnings(min(NA, na.rm = TRUE))
)
expect_equal(
as.vector(max(a$cast(type), na.rm = TRUE)),
suppressWarnings(max(NA, na.rm = TRUE))
)
}
})

Expand Down Expand Up @@ -342,29 +349,29 @@ test_that("match_arrow", {

ca <- ChunkedArray$create(c(1, 4, 3, 1, 1, 3, 4))
expect_equal(match_arrow(ca, tab), ChunkedArray$create(c(3L, 0L, 1L, 3L, 3L, 1L, 0L)))

sc <- Scalar$create(3)
expect_equal(match_arrow(sc, tab), Scalar$create(1L))

vec <- c(1,2)
expect_equal(match_arrow(vec, tab), Array$create(c(3L, 2L)))

})

test_that("is_in", {
  # Lookup set shared by every input flavour exercised below.
  lookup <- c(4, 3, 2, 1)

  # Array input is compared against an Array of membership flags.
  arr <- Array$create(c(9, 4, 3))
  expect_equal(is_in(arr, lookup), Array$create(c(FALSE, TRUE, TRUE)))

  # ChunkedArray input is compared against a ChunkedArray of flags.
  chunked <- ChunkedArray$create(c(9, 4, 3))
  expect_equal(is_in(chunked, lookup), ChunkedArray$create(c(FALSE, TRUE, TRUE)))

  # Scalar input is compared against a Scalar flag.
  scalar <- Scalar$create(3)
  expect_equal(is_in(scalar, lookup), Scalar$create(TRUE))

  # A plain R vector is compared against an Array of flags.
  r_vec <- c(1, 9)
  expect_equal(is_in(r_vec, lookup), Array$create(c(TRUE, FALSE)))
})

test_that("value_counts", {
Expand All @@ -383,40 +390,40 @@ test_that("value_counts", {
})

test_that("any.Array and any.ChunkedArray", {
  # Numeric input that ends in missing values; `input` inside each expression
  # is the placeholder that expect_vector_equal() substitutes.
  numbers <- c(1:10, NA, NA)
  expect_vector_equal(any(input > 5), numbers)
  expect_vector_equal(any(input < 1), numbers)
  expect_vector_equal(any(input < 1, na.rm = TRUE), numbers)

  # Logical input containing a missing value, with and without na.rm.
  flags <- c(TRUE, FALSE, TRUE, NA, FALSE)
  expect_vector_equal(any(input), flags)
  expect_vector_equal(any(input, na.rm = TRUE), flags)
})

test_that("all.Array and all.ChunkedArray", {
  # Numeric input that ends in missing values; `input` inside each expression
  # is the placeholder that expect_vector_equal() substitutes.
  numbers <- c(1:10, NA, NA)
  expect_vector_equal(all(input > 5), numbers)
  expect_vector_equal(all(input < 11), numbers)
  expect_vector_equal(all(input < 11, na.rm = TRUE), numbers)

  # Logical input containing a missing value, with and without na.rm.
  flags <- c(TRUE, TRUE, NA)
  expect_vector_equal(all(input), flags)
  expect_vector_equal(all(input, na.rm = TRUE), flags)
})

test_that("variance", {
  # Sample values including one NA, wrapped as both an Array and a ChunkedArray.
  values <- c(-37, 267, 88, -120, 9, 101, -65, -23, NA)
  as_array <- Array$create(values)
  as_chunked <- ChunkedArray$create(values)

  # The same kernel options are used for both container types.
  variance_opts <- list(ddof = 5)

  expect_equal(
    call_function("variance", as_array, options = variance_opts),
    Scalar$create(34596)
  )
  expect_equal(
    call_function("variance", as_chunked, options = variance_opts),
    Scalar$create(34596)
  )
})
Expand All @@ -425,7 +432,7 @@ test_that("stddev", {
data <- c(-37, 267, 88, -120, 9, 101, -65, -23, NA)
arr <- Array$create(data)
chunked_arr <- ChunkedArray$create(data)

expect_equal(call_function("stddev", arr, options = list(ddof = 5)), Scalar$create(186))
expect_equal(call_function("stddev", chunked_arr, options = list(ddof = 5)), Scalar$create(186))
})
44 changes: 19 additions & 25 deletions r/tests/testthat/test-dplyr-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -315,31 +315,25 @@ test_that("Filtering on a column that doesn't exist errors correctly", {
})

test_that("Filtering with unsupported functions", {
expect_warning(
expect_dplyr_equal(
input %>%
filter(int > 2, pnorm(dbl) > .99) %>%
collect(),
tbl
),
'Expression pnorm(dbl) > 0.99 not supported in Arrow; pulling data into R',
fixed = TRUE
)
expect_warning(
expect_dplyr_equal(
input %>%
filter(
nchar(chr, type = "bytes", allowNA = TRUE) == 1, # bad, Arrow msg
int > 2, # good
pnorm(dbl) > .99 # bad, opaque
) %>%
collect(),
tbl
),
'* In nchar(chr, type = "bytes", allowNA = TRUE) == 1, allowNA = TRUE not supported by Arrow
* Expression pnorm(dbl) > 0.99 not supported in Arrow
pulling data into R',
fixed = TRUE
expect_dplyr_equal(
input %>%
filter(int > 2, pnorm(dbl) > .99) %>%
collect(),
tbl,
warning = 'Expression pnorm\\(dbl\\) > 0.99 not supported in Arrow; pulling data into R'
)
expect_dplyr_equal(
input %>%
filter(
nchar(chr, type = "bytes", allowNA = TRUE) == 1, # bad, Arrow msg
int > 2, # good
pnorm(dbl) > .99 # bad, opaque
) %>%
collect(),
tbl,
warning = '\\* In nchar\\(chr, type = "bytes", allowNA = TRUE\\) == 1, allowNA = TRUE not supported by Arrow
\\* Expression pnorm\\(dbl\\) > 0.99 not supported in Arrow
pulling data into R'
)
})

Expand Down
Loading