Skip to content

Commit d888a0f

Browse files
committed
get vector working for prim and reverse mode
1 parent 0890920 commit d888a0f

File tree

4 files changed

+28
-17
lines changed

4 files changed

+28
-17
lines changed

stan/math/prim/prob/normal_log.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace math {
2929
* @tparam T_loc Type of location parameter.
3030
*/
3131
template <bool propto, typename T_y, typename T_loc, typename T_scale>
32-
inline return_type_t<T_y, T_loc, T_scale> normal_log(const T_y& y,
32+
inline auto normal_log(const T_y& y,
3333
const T_loc& mu,
3434
const T_scale& sigma) {
3535
return normal_lpdf<propto>(y, mu, sigma);
@@ -39,7 +39,7 @@ inline return_type_t<T_y, T_loc, T_scale> normal_log(const T_y& y,
3939
* @deprecated use <code>normal_lpdf</code>
4040
*/
4141
template <typename T_y, typename T_loc, typename T_scale>
42-
inline return_type_t<T_y, T_loc, T_scale> normal_log(const T_y& y,
42+
inline auto normal_log(const T_y& y,
4343
const T_loc& mu,
4444
const T_scale& sigma) {
4545
return normal_lpdf<false>(y, mu, sigma);

stan/math/prim/prob/normal_lpdf.hpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,12 @@ struct prob_broadcaster<T, require_stan_scalar_t<T>> {
6969
return *this;
7070
}
7171
inline auto ret() noexcept {
72+
static_assert(!is_var<T>::value, "NOOO");
7273
return ret_;
7374
}
74-
template <typename T1>
75-
static auto zero(T1&& /* */) {
76-
return T(0);
75+
template <typename... Types>
76+
static auto zero(int /* */) {
77+
return return_type_t<Types...>(0);
7778
}
7879

7980
};
@@ -86,7 +87,7 @@ struct prob_broadcaster<T, require_eigen_t<T>> {
8687

8788
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
8889
inline auto operator=(EigArr&& x) {
89-
ret_ = sum(x);
90+
ret_ = x;
9091
return *this;
9192
}
9293

@@ -124,9 +125,9 @@ struct prob_broadcaster<T, require_eigen_t<T>> {
124125
return std::move(ret_);
125126
}
126127

127-
template <typename T1>
128-
static auto zero(T1&& size) {
129-
return Eigen::Array<value_type_t<T>, -1, 1>::Constant(0, size).eval();
128+
template <typename... Types>
129+
static auto zero(int size) {
130+
return Eigen::Array<return_type_t<Types...>, -1, 1>::Constant(0, size).eval();
130131
}
131132

132133
};
@@ -182,10 +183,10 @@ inline auto normal_lpdf(const T_y& y,
182183
using ret_t = prob_return_t<RetType, T_partials_return>;
183184
const size_t N = max_size(y, mu, sigma);
184185
if (size_zero(y, mu, sigma)) {
185-
return ret_t::zero(N);
186+
return ret_t::template zero<T_y, T_loc, T_scale>(N);
186187
}
187188
if (!include_summand<propto, T_y, T_loc, T_scale>::value) {
188-
return ret_t::zero(N);
189+
return ret_t::template zero<T_y, T_loc, T_scale>(N);
189190
}
190191

191192
operands_and_partials<T_y_ref, T_mu_ref, T_sigma_ref> ops_partials(

stan/math/rev/functor/operands_and_partials.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ template <typename Scalar1, typename Scalar2, typename T3, require_var_t<Scalar1
8787
inline void update_adjoints(Scalar1 x, Scalar2 y, const T3& z) noexcept {
8888
x.adj() += sum(z.adj() * y);
8989
}
90-
template <typename Matrix1, typename Matrix2, typename T3
90+
template <typename Matrix1, typename Matrix2, typename T3,
9191
require_rev_matrix_t<Matrix1>* = nullptr,
9292
require_st_arithmetic<Matrix2>* = nullptr,
9393
require_eigen_t<T3>* = nullptr>
@@ -214,19 +214,19 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, var> {
214214
operand5 = edge5_.operand(),
215215
partial5 = edge5_.partial()]() mutable {
216216
if (!is_constant<Op1>::value) {
217-
internal::update_adjoints(operand1, partial1, vi);
217+
internal::update_adjoints(operand1, partial1, ret);
218218
}
219219
if (!is_constant<Op2>::value) {
220-
internal::update_adjoints(operand2, partial2, vi);
220+
internal::update_adjoints(operand2, partial2, ret);
221221
}
222222
if (!is_constant<Op3>::value) {
223-
internal::update_adjoints(operand3, partial3, vi);
223+
internal::update_adjoints(operand3, partial3, ret);
224224
}
225225
if (!is_constant<Op4>::value) {
226-
internal::update_adjoints(operand4, partial4, vi);
226+
internal::update_adjoints(operand4, partial4, ret);
227227
}
228228
if (!is_constant<Op5>::value) {
229-
internal::update_adjoints(operand5, partial5, vi);
229+
internal::update_adjoints(operand5, partial5, ret);
230230
}
231231
});
232232
return plain_type_t<promote_scalar_t<var, T>>(ret);

test/unit/math/rev/prob/normal_log_test.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,13 @@ TEST(ProbDistributionsNormal, intVsDouble) {
2121
EXPECT_FLOAT_EQ(lp1adj, lp2adj);
2222
}
2323
}
24+
25+
26+
TEST(ProbNormal, test_vlpdf) {
27+
using stan::math::var;
28+
Eigen::Matrix<var, -1, 1> Y = Eigen::Matrix<double, -1, 1>::Random(5);
29+
Eigen::Matrix<var, -1, 1> Mu = Eigen::Matrix<double, -1, 1>::Random(5);
30+
Eigen::Matrix<var, -1, 1> Sigma = stan::math::abs(Eigen::Matrix<double, -1, 1>::Random(5));
31+
Eigen::Matrix<var, -1, 1> A = stan::math::normal_lpdf<false, stan::math::ProbReturnType::Vector>(Y, Mu, Sigma);
32+
33+
}

0 commit comments

Comments
 (0)