Skip to content

Commit

Permalink
added copying results back to opencl copy benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
t4c1 committed Dec 30, 2020
1 parent 89bd34d commit 6aa4fa3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
13 changes: 9 additions & 4 deletions benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,14 @@ def benchmark(
benchmark_name = function_name
setup = ""
var_conversions = ""
code = " auto res = stan::math::eval(stan::math::{}(".format(
function_name
)
if opencl in ("copy", "copy_rev") and return_type not in scalar_stan_types:
code = " auto res = stan::math::from_matrix_cl(stan::math::{}(".format(
function_name
)
else:
code = " auto res = stan::math::eval(stan::math::{}(".format(
function_name
)
for (
n,
(arg_overload, cpp_arg_template, stan_arg),
Expand All @@ -402,7 +407,7 @@ def benchmark(
var_name = "arg" + str(n)
make_arg_function = "make_arg"
is_argument_autodiff = "var" in arg_type
is_argument_scalar = stan_arg in ("int", "real", "rng", "ostream_ptr")
is_argument_scalar = stan_arg in scalar_stan_types
value = 0.4
if function_name in special_arg_values:
if isinstance(special_arg_values[function_name][n], str):
Expand Down
2 changes: 2 additions & 0 deletions test/sig_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
"ostream_ptr": "std::ostream*",
}

scalar_stan_types = ("int", "real", "rng", "ostream_ptr")

def get_cpp_type(stan_type):
n_vec = 0
if stan_type.endswith("]"):
Expand Down

0 comments on commit 6aa4fa3

Please sign in to comment.