Skip to content

Commit cd60e75

Browse files
authored
Merge pull request #2817 from stan-dev/feature/2815-vectorize-conj
Vectorize conj
2 parents 9b393df + 58631ee commit cd60e75

21 files changed

+268
-80
lines changed

stan/math/prim/fun/conj.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define STAN_MATH_PRIM_FUN_CONJ_HPP
33

44
#include <complex>
5+
#include <stan/math/prim/meta.hpp>
56

67
namespace stan {
78
namespace math {
@@ -18,6 +19,33 @@ inline std::complex<V> conj(const std::complex<V>& z) {
1819
return std::conj(z);
1920
}
2021

22+
/**
23+
* Return the complex conjugate the Eigen object.
24+
*
25+
* @tparam Eig A type derived from `Eigen::EigenBase`
26+
* @param[in] z argument
27+
* @return complex conjugate of the argument
28+
*/
29+
template <typename Eig, require_eigen_vt<is_complex, Eig>* = nullptr>
30+
inline auto conj(const Eig& z) {
31+
return z.conjugate();
32+
}
33+
34+
/**
35+
* Return the complex conjugate the vector with complex scalar components.
36+
*
37+
* @tparam StdVec A `std::vector` type with complex scalar type
38+
* @param[in] z argument
39+
* @return complex conjugate of the argument
40+
*/
41+
template <typename StdVec, require_std_vector_st<is_complex, StdVec>* = nullptr>
42+
inline auto conj(const StdVec& z) {
43+
promote_scalar_t<scalar_type_t<StdVec>, StdVec> result(z.size());
44+
std::transform(z.begin(), z.end(), result.begin(),
45+
[](auto&& x) { return stan::math::conj(x); });
46+
return result;
47+
}
48+
2149
namespace internal {
2250
/**
2351
* Return the complex conjugate the complex argument.

test/unit/math/mix/fun/acos_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ TEST(mathMixMatFun, acos) {
1212
// can't autodiff acos through integers
1313
for (auto x : stan::test::internal::common_nonzero_args())
1414
stan::test::expect_unary_vectorized(f, x);
15-
expect_unary_vectorized<stan::test::PromoteToComplex::No>(
15+
expect_unary_vectorized<stan::test::ScalarSupport::Real>(
1616
f, -2.2, -0.8, 0.5, 1 + std::numeric_limits<double>::epsilon(), 1.5, 3,
1717
3.4, 4);
1818
for (double re : std::vector<double>{-0.2, 0, 0.3}) {

test/unit/math/mix/fun/acosh_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ TEST(mathMixMatFun, acosh) {
88
};
99
for (double x : stan::test::internal::common_args())
1010
stan::test::expect_unary_vectorized(f, x);
11-
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
12-
f, 1.5, 3.2, 5, 10, 12.9);
11+
stan::test::expect_unary_vectorized<
12+
stan::test::ScalarSupport::RealAndComplex>(f, 1.5, 3.2, 5, 10, 12.9);
1313
// avoid pole at complex zero that can't be autodiffed
1414
for (double re : std::vector<double>{-0.2, 0, 0.3}) {
1515
for (double im : std::vector<double>{-0.3, 0.2}) {

test/unit/math/mix/fun/atan_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ TEST(mathMixMatFun, atan) {
88
return atan(x);
99
};
1010
stan::test::expect_common_nonzero_unary_vectorized<
11-
stan::test::PromoteToComplex::No>(f);
11+
stan::test::ScalarSupport::Real>(f);
1212
stan::test::expect_unary_vectorized(f, -2.6, -2, -0.2, 0.5, 1, 1.3, 1.5, 3);
1313
// avoid 0 imaginary component where autodiff doesn't work
1414
for (double re : std::vector<double>{-0.2, 0, 0.3}) {

test/unit/math/mix/fun/cbrt_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
TEST(mathMixMatFun, cbrt) {
44
auto f = [](const auto& x1) { return stan::math::cbrt(x1); };
55
stan::test::expect_common_nonzero_unary_vectorized<
6-
stan::test::PromoteToComplex::No>(f);
6+
stan::test::ScalarSupport::Real>(f);
77
stan::test::expect_unary_vectorized(f, -2.6, -2, 1, 1.3, 3);
88
}
99

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,50 @@
1+
#include <stan/math/prim.hpp>
12
#include <test/unit/math/test_ad.hpp>
23
#include <complex>
34
#include <vector>
45

5-
TEST(mixScalFun, conj) {
6-
auto f = [](const auto& x) { return conj(x); };
6+
TEST(mathMixMatFun, conj) {
7+
auto f = [](const auto& x) { return stan::math::conj(x); };
78
stan::test::expect_complex_common(f);
9+
stan::test::expect_unary_vectorized<stan::test::ScalarSupport::ComplexOnly>(
10+
f);
11+
}
12+
13+
template <typename T>
14+
void test_vectorized_conj() {
15+
using stan::math::value_of_rec;
16+
using complex_t = std::complex<T>;
17+
using complex_matrix = Eigen::Matrix<complex_t, -1, -1>;
18+
complex_matrix A(2, 2);
19+
A << complex_t(T(0), T(1)), complex_t(T(2), T(3)), complex_t(T(4), T(5)),
20+
complex_t(T(6), T(7));
21+
auto A_conj = stan::math::conj(A);
22+
EXPECT_MATRIX_COMPLEX_FLOAT_EQ(value_of_rec(A_conj),
23+
value_of_rec(A.conjugate()))
24+
std::vector<complex_t> v{complex_t(T(0), T(1)), complex_t(T(2), T(3)),
25+
complex_t(T(4), T(5)), complex_t(T(6), T(7))};
26+
27+
std::vector<complex_t> v_conj = stan::math::conj(v);
28+
for (int i = 0; i < v.size(); ++i) {
29+
std::complex<double> vi = value_of_rec(v_conj[i]);
30+
std::complex<double> ci = value_of_rec(stan::math::conj(v[i]));
31+
EXPECT_FLOAT_EQ(vi.real(), ci.real());
32+
EXPECT_FLOAT_EQ(vi.imag(), ci.imag());
33+
}
34+
}
35+
36+
TEST(mathMixMatFun, conj_vectorized) {
37+
using d_t = double;
38+
using v_t = stan::math::var;
39+
using fd_t = stan::math::fvar<d_t>;
40+
using ffd_t = stan::math::fvar<fd_t>;
41+
using fv_t = stan::math::fvar<v_t>;
42+
using ffv_t = stan::math::fvar<fv_t>;
43+
44+
test_vectorized_conj<d_t>();
45+
test_vectorized_conj<v_t>();
46+
test_vectorized_conj<fd_t>();
47+
test_vectorized_conj<ffd_t>();
48+
test_vectorized_conj<fv_t>();
49+
test_vectorized_conj<ffv_t>();
850
}

test/unit/math/mix/fun/cos_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ TEST(mathMixMatFun, cos) {
66
return cos(x);
77
};
88
stan::test::expect_common_nonzero_unary_vectorized(f);
9-
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
10-
f, -2.6, -2, -0.2, -0.5, 0, 1.5, 3, 5, 5.3);
9+
stan::test::expect_unary_vectorized<
10+
stan::test::ScalarSupport::RealAndComplex>(f, -2.6, -2, -0.2, -0.5, 0,
11+
1.5, 3, 5, 5.3);
1112
stan::test::expect_complex_common(f);
1213
}
1314

test/unit/math/mix/fun/cosh_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ TEST(mathMixMatFun, cosh) {
66
return cosh(x1);
77
};
88
stan::test::expect_common_unary_vectorized(f);
9-
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
10-
f, -2.6, -2, -1.2, -0.2, 0.5, 1, 1.3, 1.5);
9+
stan::test::expect_unary_vectorized<
10+
stan::test::ScalarSupport::RealAndComplex>(f, -2.6, -2, -1.2, -0.2, 0.5,
11+
1, 1.3, 1.5);
1112
stan::test::expect_complex_common(f);
1213
}
1314

test/unit/math/mix/fun/digamma_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
TEST(mathMixMatFun, digamma) {
44
auto f = [](const auto& x1) { return stan::math::digamma(x1); };
55
stan::test::expect_common_nonzero_unary_vectorized<
6-
stan::test::PromoteToComplex::No>(f);
6+
stan::test::ScalarSupport::Real>(f);
77
stan::test::expect_unary_vectorized(f, -25, -10.2, -1.2, -1, 2.3, 5.7);
88
}
99

test/unit/math/mix/fun/exp_test.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ TEST(mathMixMatFun, exp) {
55
using stan::math::exp;
66
return exp(x);
77
};
8-
stan::test::expect_common_unary_vectorized<stan::test::PromoteToComplex::Yes>(
9-
f);
10-
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
11-
f, -15.2, -10, -0.5, 0.5, 1, 1.0, 1.3, 5, 10);
8+
stan::test::expect_common_unary_vectorized<
9+
stan::test::ScalarSupport::RealAndComplex>(f);
10+
stan::test::expect_unary_vectorized<
11+
stan::test::ScalarSupport::RealAndComplex>(f, -15.2, -10, -0.5, 0.5, 1,
12+
1.0, 1.3, 5, 10);
1213
stan::test::expect_complex_common(f);
1314

1415
std::vector<double> com_args = stan::test::internal::common_nonzero_args();

0 commit comments

Comments
 (0)