Skip to content

Commit fc56834

Browse files
committed
Update test suite to support vectorization on complex-only functions
1 parent 8658a0e commit fc56834

20 files changed

+200
-90
lines changed

test/unit/math/mix/fun/acos_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ TEST(mathMixMatFun, acos) {
1212
// can't autodiff acos through integers
1313
for (auto x : stan::test::internal::common_nonzero_args())
1414
stan::test::expect_unary_vectorized(f, x);
15-
expect_unary_vectorized<stan::test::PromoteToComplex::No>(
15+
expect_unary_vectorized<stan::test::ScalarSupport::Real>(
1616
f, -2.2, -0.8, 0.5, 1 + std::numeric_limits<double>::epsilon(), 1.5, 3,
1717
3.4, 4);
1818
for (double re : std::vector<double>{-0.2, 0, 0.3}) {

test/unit/math/mix/fun/acosh_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ TEST(mathMixMatFun, acosh) {
88
};
99
for (double x : stan::test::internal::common_args())
1010
stan::test::expect_unary_vectorized(f, x);
11-
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
12-
f, 1.5, 3.2, 5, 10, 12.9);
11+
stan::test::expect_unary_vectorized<
12+
stan::test::ScalarSupport::RealAndComplex>(f, 1.5, 3.2, 5, 10, 12.9);
1313
// avoid pole at complex zero that can't be autodiffed
1414
for (double re : std::vector<double>{-0.2, 0, 0.3}) {
1515
for (double im : std::vector<double>{-0.3, 0.2}) {

test/unit/math/mix/fun/atan_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ TEST(mathMixMatFun, atan) {
88
return atan(x);
99
};
1010
stan::test::expect_common_nonzero_unary_vectorized<
11-
stan::test::PromoteToComplex::No>(f);
11+
stan::test::ScalarSupport::Real>(f);
1212
stan::test::expect_unary_vectorized(f, -2.6, -2, -0.2, 0.5, 1, 1.3, 1.5, 3);
1313
// avoid 0 imaginary component where autodiff doesn't work
1414
for (double re : std::vector<double>{-0.2, 0, 0.3}) {

test/unit/math/mix/fun/cbrt_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
TEST(mathMixMatFun, cbrt) {
44
auto f = [](const auto& x1) { return stan::math::cbrt(x1); };
55
stan::test::expect_common_nonzero_unary_vectorized<
6-
stan::test::PromoteToComplex::No>(f);
6+
stan::test::ScalarSupport::Real>(f);
77
stan::test::expect_unary_vectorized(f, -2.6, -2, 1, 1.3, 3);
88
}
99

test/unit/math/mix/fun/conj_test.cpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,11 @@
33
#include <complex>
44
#include <vector>
55

6-
namespace stan {
7-
namespace math {
8-
/**
9-
* Dummy overload to allow expect_unary_vectorized. It requires real-valued
10-
* overloads.
11-
*/
12-
template <typename V, require_not_st_complex<V>* = nullptr>
13-
inline auto conj(const V& z) {
14-
return z;
15-
}
16-
} // namespace math
17-
} // namespace stan
18-
196
TEST(mathMixMatFun, conj) {
207
auto f = [](const auto& x) { return stan::math::conj(x); };
218
stan::test::expect_complex_common(f);
22-
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(f);
9+
stan::test::expect_unary_vectorized<stan::test::ScalarSupport::ComplexOnly>(
10+
f);
2311
}
2412

2513
template <typename T>

test/unit/math/mix/fun/cos_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ TEST(mathMixMatFun, cos) {
66
return cos(x);
77
};
88
stan::test::expect_common_nonzero_unary_vectorized(f);
9-
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
10-
f, -2.6, -2, -0.2, -0.5, 0, 1.5, 3, 5, 5.3);
9+
stan::test::expect_unary_vectorized<
10+
stan::test::ScalarSupport::RealAndComplex>(f, -2.6, -2, -0.2, -0.5, 0,
11+
1.5, 3, 5, 5.3);
1112
stan::test::expect_complex_common(f);
1213
}
1314

test/unit/math/mix/fun/cosh_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ TEST(mathMixMatFun, cosh) {
66
return cosh(x1);
77
};
88
stan::test::expect_common_unary_vectorized(f);
9-
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
10-
f, -2.6, -2, -1.2, -0.2, 0.5, 1, 1.3, 1.5);
9+
stan::test::expect_unary_vectorized<
10+
stan::test::ScalarSupport::RealAndComplex>(f, -2.6, -2, -1.2, -0.2, 0.5,
11+
1, 1.3, 1.5);
1112
stan::test::expect_complex_common(f);
1213
}
1314

test/unit/math/mix/fun/digamma_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
TEST(mathMixMatFun, digamma) {
44
auto f = [](const auto& x1) { return stan::math::digamma(x1); };
55
stan::test::expect_common_nonzero_unary_vectorized<
6-
stan::test::PromoteToComplex::No>(f);
6+
stan::test::ScalarSupport::Real>(f);
77
stan::test::expect_unary_vectorized(f, -25, -10.2, -1.2, -1, 2.3, 5.7);
88
}
99

test/unit/math/mix/fun/exp_test.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ TEST(mathMixMatFun, exp) {
55
using stan::math::exp;
66
return exp(x);
77
};
8-
stan::test::expect_common_unary_vectorized<stan::test::PromoteToComplex::Yes>(
9-
f);
10-
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
11-
f, -15.2, -10, -0.5, 0.5, 1, 1.0, 1.3, 5, 10);
8+
stan::test::expect_common_unary_vectorized<
9+
stan::test::ScalarSupport::RealAndComplex>(f);
10+
stan::test::expect_unary_vectorized<
11+
stan::test::ScalarSupport::RealAndComplex>(f, -15.2, -10, -0.5, 0.5, 1,
12+
1.0, 1.3, 5, 10);
1213
stan::test::expect_complex_common(f);
1314

1415
std::vector<double> com_args = stan::test::internal::common_nonzero_args();

test/unit/math/mix/fun/log1m_exp_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
TEST(mathMixMatFun, log1m_exp) {
44
auto f = [](const auto& x1) { return stan::math::log1m_exp(x1); };
55
stan::test::expect_common_nonzero_unary_vectorized<
6-
stan::test::PromoteToComplex::No>(f);
6+
stan::test::ScalarSupport::Real>(f);
77
stan::test::expect_unary_vectorized(f, -14, -12.6, -2, -1, -0.2, -0.5, 1.3,
88
3);
99
}

0 commit comments

Comments
 (0)