Skip to content

Commit

Permalink
Test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Aug 21, 2024
1 parent 0444cae commit 55c6554
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tests/testthat/test-opencl.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,40 +74,41 @@ test_that("all methods run with valid opencl_ids", {
fit <- mod$sample(data = testing_data("bernoulli"), opencl_ids = c(0, 0), chains = 1)
)
expect_false(is.null(fit$metadata()$opencl_platform_name))
expect_false(is.null(fit$metadata()$opencl_ids_name))
expect_false(is.null(fit$metadata()$opencl_device_name))

stan_file_gq <- testing_stan_file("bernoulli_ppc")
mod_gq <- cmdstan_model(stan_file = stan_file_gq, cpp_options = list(stan_opencl = TRUE))
expect_gq_output(
fit <- mod_gq$generate_quantities(fitted_params = fit, data = testing_data("bernoulli"), opencl_ids = c(0, 0)),
)
expect_false(is.null(fit$metadata()$opencl_platform_name))
expect_false(is.null(fit$metadata()$opencl_ids_name))
expect_false(is.null(fit$metadata()$opencl_device_name))

expect_sample_output(
fit <- mod$sample(data = testing_data("bernoulli"), opencl_ids = c(0, 0))
)
expect_false(is.null(fit$metadata()$opencl_platform_name))
expect_false(is.null(fit$metadata()$opencl_ids_name))
expect_false(is.null(fit$metadata()$opencl_device_name))

expect_optim_output(
fit <- mod$optimize(data = testing_data("bernoulli"), opencl_ids = c(0, 0))
)
expect_false(is.null(fit$metadata()$opencl_platform_name))
expect_false(is.null(fit$metadata()$opencl_ids_name))
expect_false(is.null(fit$metadata()$opencl_device_name))

expect_vb_output(
fit <- mod$variational(data = testing_data("bernoulli"), opencl_ids = c(0, 0))
)
expect_false(is.null(fit$metadata()$opencl_platform_name))
expect_false(is.null(fit$metadata()$opencl_ids_name))
expect_false(is.null(fit$metadata()$opencl_device_name))
})

test_that("error for runtime selection of OpenCL devices if version less than 2.26", {
skip_if_not(Sys.getenv("CMDSTANR_OPENCL_TESTS") %in% c("1", "true"))
fake_cmdstan_version("2.25.0")

stan_file <- testing_stan_file("bernoulli")
data_list <- testing_data("bernoulli")
mod <- cmdstan_model(stan_file = stan_file, cpp_options = list(stan_opencl = TRUE),
force_recompile = TRUE)
expect_error(
Expand Down

0 comments on commit 55c6554

Please sign in to comment.