Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions stan/math/prim/fun/conj.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_PRIM_FUN_CONJ_HPP

#include <complex>
#include <stan/math/prim/meta.hpp>

namespace stan {
namespace math {
Expand All @@ -18,6 +19,33 @@ inline std::complex<V> conj(const std::complex<V>& z) {
return std::conj(z);
}

/**
* Return the complex conjugate the Eigen object.
*
* @tparam Eig A type derived from `Eigen::EigenBase`
* @param[in] z argument
* @return complex conjugate of the argument
*/
template <typename Eig, require_eigen_vt<is_complex, Eig>* = nullptr>
inline auto conj(const Eig& z) {
return z.conjugate();
}

/**
* Return the complex conjugate the vector with complex scalar components.
*
* @tparam StdVec A `std::vector` type with complex scalar type
* @param[in] z argument
* @return complex conjugate of the argument
*/
template <typename StdVec, require_std_vector_st<is_complex, StdVec>* = nullptr>
inline auto conj(const StdVec& z) {
promote_scalar_t<scalar_type_t<StdVec>, StdVec> result(z.size());
std::transform(z.begin(), z.end(), result.begin(),
[](auto&& x) { return stan::math::conj(x); });
return result;
}

namespace internal {
/**
* Return the complex conjugate the complex argument.
Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/mix/fun/acos_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ TEST(mathMixMatFun, acos) {
// can't autodiff acos through integers
for (auto x : stan::test::internal::common_nonzero_args())
stan::test::expect_unary_vectorized(f, x);
expect_unary_vectorized<stan::test::PromoteToComplex::No>(
expect_unary_vectorized<stan::test::ScalarSupport::Real>(
f, -2.2, -0.8, 0.5, 1 + std::numeric_limits<double>::epsilon(), 1.5, 3,
3.4, 4);
for (double re : std::vector<double>{-0.2, 0, 0.3}) {
Expand Down
4 changes: 2 additions & 2 deletions test/unit/math/mix/fun/acosh_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ TEST(mathMixMatFun, acosh) {
};
for (double x : stan::test::internal::common_args())
stan::test::expect_unary_vectorized(f, x);
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
f, 1.5, 3.2, 5, 10, 12.9);
stan::test::expect_unary_vectorized<
stan::test::ScalarSupport::RealAndComplex>(f, 1.5, 3.2, 5, 10, 12.9);
// avoid pole at complex zero that can't be autodiffed
for (double re : std::vector<double>{-0.2, 0, 0.3}) {
for (double im : std::vector<double>{-0.3, 0.2}) {
Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/mix/fun/atan_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ TEST(mathMixMatFun, atan) {
return atan(x);
};
stan::test::expect_common_nonzero_unary_vectorized<
stan::test::PromoteToComplex::No>(f);
stan::test::ScalarSupport::Real>(f);
stan::test::expect_unary_vectorized(f, -2.6, -2, -0.2, 0.5, 1, 1.3, 1.5, 3);
// avoid 0 imaginary component where autodiff doesn't work
for (double re : std::vector<double>{-0.2, 0, 0.3}) {
Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/mix/fun/cbrt_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
TEST(mathMixMatFun, cbrt) {
auto f = [](const auto& x1) { return stan::math::cbrt(x1); };
stan::test::expect_common_nonzero_unary_vectorized<
stan::test::PromoteToComplex::No>(f);
stan::test::ScalarSupport::Real>(f);
stan::test::expect_unary_vectorized(f, -2.6, -2, 1, 1.3, 3);
}

Expand Down
46 changes: 44 additions & 2 deletions test/unit/math/mix/fun/conj_test.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,50 @@
#include <stan/math/prim.hpp>
#include <test/unit/math/test_ad.hpp>
#include <complex>
#include <vector>

TEST(mixScalFun, conj) {
auto f = [](const auto& x) { return conj(x); };
TEST(mathMixMatFun, conj) {
auto f = [](const auto& x) { return stan::math::conj(x); };
stan::test::expect_complex_common(f);
stan::test::expect_unary_vectorized<stan::test::ScalarSupport::ComplexOnly>(
f);
}

template <typename T>
void test_vectorized_conj() {
using stan::math::value_of_rec;
using complex_t = std::complex<T>;
using complex_matrix = Eigen::Matrix<complex_t, -1, -1>;
complex_matrix A(2, 2);
A << complex_t(T(0), T(1)), complex_t(T(2), T(3)), complex_t(T(4), T(5)),
complex_t(T(6), T(7));
auto A_conj = stan::math::conj(A);
EXPECT_MATRIX_COMPLEX_FLOAT_EQ(value_of_rec(A_conj),
value_of_rec(A.conjugate()))
std::vector<complex_t> v{complex_t(T(0), T(1)), complex_t(T(2), T(3)),
complex_t(T(4), T(5)), complex_t(T(6), T(7))};

std::vector<complex_t> v_conj = stan::math::conj(v);
for (int i = 0; i < v.size(); ++i) {
std::complex<double> vi = value_of_rec(v_conj[i]);
std::complex<double> ci = value_of_rec(stan::math::conj(v[i]));
EXPECT_FLOAT_EQ(vi.real(), ci.real());
EXPECT_FLOAT_EQ(vi.imag(), ci.imag());
}
}

TEST(mathMixMatFun, conj_vectorized) {
using d_t = double;
using v_t = stan::math::var;
using fd_t = stan::math::fvar<d_t>;
using ffd_t = stan::math::fvar<fd_t>;
using fv_t = stan::math::fvar<v_t>;
using ffv_t = stan::math::fvar<fv_t>;

test_vectorized_conj<d_t>();
test_vectorized_conj<v_t>();
test_vectorized_conj<fd_t>();
test_vectorized_conj<ffd_t>();
test_vectorized_conj<fv_t>();
test_vectorized_conj<ffv_t>();
}
5 changes: 3 additions & 2 deletions test/unit/math/mix/fun/cos_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ TEST(mathMixMatFun, cos) {
return cos(x);
};
stan::test::expect_common_nonzero_unary_vectorized(f);
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
f, -2.6, -2, -0.2, -0.5, 0, 1.5, 3, 5, 5.3);
stan::test::expect_unary_vectorized<
stan::test::ScalarSupport::RealAndComplex>(f, -2.6, -2, -0.2, -0.5, 0,
1.5, 3, 5, 5.3);
stan::test::expect_complex_common(f);
}

Expand Down
5 changes: 3 additions & 2 deletions test/unit/math/mix/fun/cosh_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ TEST(mathMixMatFun, cosh) {
return cosh(x1);
};
stan::test::expect_common_unary_vectorized(f);
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
f, -2.6, -2, -1.2, -0.2, 0.5, 1, 1.3, 1.5);
stan::test::expect_unary_vectorized<
stan::test::ScalarSupport::RealAndComplex>(f, -2.6, -2, -1.2, -0.2, 0.5,
1, 1.3, 1.5);
stan::test::expect_complex_common(f);
}

Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/mix/fun/digamma_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
TEST(mathMixMatFun, digamma) {
auto f = [](const auto& x1) { return stan::math::digamma(x1); };
stan::test::expect_common_nonzero_unary_vectorized<
stan::test::PromoteToComplex::No>(f);
stan::test::ScalarSupport::Real>(f);
stan::test::expect_unary_vectorized(f, -25, -10.2, -1.2, -1, 2.3, 5.7);
}

Expand Down
9 changes: 5 additions & 4 deletions test/unit/math/mix/fun/exp_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ TEST(mathMixMatFun, exp) {
using stan::math::exp;
return exp(x);
};
stan::test::expect_common_unary_vectorized<stan::test::PromoteToComplex::Yes>(
f);
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
f, -15.2, -10, -0.5, 0.5, 1, 1.0, 1.3, 5, 10);
stan::test::expect_common_unary_vectorized<
stan::test::ScalarSupport::RealAndComplex>(f);
stan::test::expect_unary_vectorized<
stan::test::ScalarSupport::RealAndComplex>(f, -15.2, -10, -0.5, 0.5, 1,
1.0, 1.3, 5, 10);
stan::test::expect_complex_common(f);

std::vector<double> com_args = stan::test::internal::common_nonzero_args();
Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/mix/fun/log1m_exp_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
TEST(mathMixMatFun, log1m_exp) {
auto f = [](const auto& x1) { return stan::math::log1m_exp(x1); };
stan::test::expect_common_nonzero_unary_vectorized<
stan::test::PromoteToComplex::No>(f);
stan::test::ScalarSupport::Real>(f);
stan::test::expect_unary_vectorized(f, -14, -12.6, -2, -1, -0.2, -0.5, 1.3,
3);
}
Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/mix/fun/log1m_inv_logit_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
TEST(mathMixMatFun, log1mInvLogit) {
auto f = [](const auto& x1) { return stan::math::log1m_inv_logit(x1); };
stan::test::expect_common_nonzero_unary_vectorized<
stan::test::PromoteToComplex::No>(f);
stan::test::ScalarSupport::Real>(f);
stan::test::expect_unary_vectorized(f, -2.6, -2, -1, -0.5, -0.2, 0.5, 1, 1.3,
3, 5);
}
Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/mix/fun/log1p_exp_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
TEST(mathMixMatFun, log1pExp) {
auto f = [](const auto& x1) { return stan::math::log1p_exp(x1); };
stan::test::expect_common_nonzero_unary_vectorized<
stan::test::PromoteToComplex::No>(f);
stan::test::ScalarSupport::Real>(f);
stan::test::expect_unary_vectorized(f, -2.6, -2, -1, -0.5, -0.2, 0.5, 1.0,
1.3, 2, 3);

Expand Down
6 changes: 3 additions & 3 deletions test/unit/math/mix/fun/sin_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ TEST(mathMixMatFun, sin) {
return sin(x1);
};
stan::test::expect_common_nonzero_unary_vectorized<
stan::test::PromoteToComplex::No>(f);
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
f, -2.6, -2, -0.2, 3, 5, 5.3);
stan::test::ScalarSupport::Real>(f);
stan::test::expect_unary_vectorized<
stan::test::ScalarSupport::RealAndComplex>(f, -2.6, -2, -0.2, 3, 5, 5.3);
stan::test::expect_complex_common(f);
}

Expand Down
7 changes: 4 additions & 3 deletions test/unit/math/mix/fun/sinh_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ TEST(mathMixMatFun, sinh) {
return sinh(x);
};
stan::test::expect_common_nonzero_unary_vectorized<
stan::test::PromoteToComplex::No>(f);
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
f, -2, -1.2, -0.5, -0.2, 0.5, 1.3, 1.5, 3);
stan::test::ScalarSupport::Real>(f);
stan::test::expect_unary_vectorized<
stan::test::ScalarSupport::RealAndComplex>(f, -2, -1.2, -0.5, -0.2, 0.5,
1.3, 1.5, 3);
stan::test::expect_complex_common(f);
}

Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/mix/fun/sqrt_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ TEST(mathMixMatFun, sqrt) {
return sqrt(x1);
};
stan::test::expect_common_nonzero_unary_vectorized<
stan::test::PromoteToComplex::No>(f);
stan::test::ScalarSupport::Real>(f);
stan::test::expect_unary_vectorized(f, -6, -5.2, 1.3, 7, 10.7, 36, 1e6);

// undefined with 0 in denominator
Expand Down
6 changes: 3 additions & 3 deletions test/unit/math/mix/fun/tan_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ TEST(mathMixMatFun, tan) {
return tan(x);
};
stan::test::expect_common_nonzero_unary_vectorized<
stan::test::PromoteToComplex::No>(f);
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
f, -2, -0.5, 0.5, 1.5, 3, 4.4);
stan::test::ScalarSupport::Real>(f);
stan::test::expect_unary_vectorized<
stan::test::ScalarSupport::RealAndComplex>(f, -2, -0.5, 0.5, 1.5, 3, 4.4);
stan::test::expect_complex_common(f);
}

Expand Down
7 changes: 4 additions & 3 deletions test/unit/math/mix/fun/tanh_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ TEST(mathMixMatFun, tanh) {
return tanh(x1);
};
stan::test::expect_common_nonzero_unary_vectorized<
stan::test::PromoteToComplex::No>(f);
stan::test::expect_unary_vectorized<stan::test::PromoteToComplex::Yes>(
f, -2.6, -2, -1.2, -0.5, 0.5, 1.5);
stan::test::ScalarSupport::Real>(f);
stan::test::expect_unary_vectorized<
stan::test::ScalarSupport::RealAndComplex>(f, -2.6, -2, -1.2, -0.5, 0.5,
1.5);
stan::test::expect_complex_common(f);
}

Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/mix/fun/trigamma_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ TEST(mathMixMatFun, trigamma) {
tols.grad_hessian_hessian_ = relative_tolerance(1e-3, 1e-2);
tols.grad_hessian_grad_hessian_ = relative_tolerance(1e-2, 1e-1);

expect_unary_vectorized<stan::test::PromoteToComplex::No>(
expect_unary_vectorized<stan::test::ScalarSupport::Real>(
tols, f, -103.52, -0.9, -0.5, 0, 0.5, 1.3, 5.1, 19.2);

// reduce tol_min for first deriv tests one order, second derivs four orders,
Expand Down
Loading