Skip to content

Commit 0890920

Browse files
committed
adds internals needed for supporting vector returning lpdfs
1 parent dfa8653 commit 0890920

File tree

5 files changed

+211
-19
lines changed

5 files changed

+211
-19
lines changed

stan/math/prim/functor/operands_and_partials.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ class operands_and_partials {
139139
* @param value the return value of the function we are compressing
140140
* @return the value with its derivative
141141
*/
142-
inline double build(double value) const noexcept { return value; }
142+
template <typename T>
143+
inline auto build(T&& value) const noexcept { return std::forward<T>(value); }
143144

144145
// These will always be 0 size base template instantiations (above).
145146
internal::ops_partials_edge<double, std::decay_t<Op1>> edge1_;

stan/math/prim/prob/normal_log.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ template <bool propto, typename T_y, typename T_loc, typename T_scale>
3232
inline return_type_t<T_y, T_loc, T_scale> normal_log(const T_y& y,
3333
const T_loc& mu,
3434
const T_scale& sigma) {
35-
return normal_lpdf<propto, T_y, T_loc, T_scale>(y, mu, sigma);
35+
return normal_lpdf<propto>(y, mu, sigma);
3636
}
3737

3838
/** \ingroup prob_dists
@@ -42,7 +42,7 @@ template <typename T_y, typename T_loc, typename T_scale>
4242
inline return_type_t<T_y, T_loc, T_scale> normal_log(const T_y& y,
4343
const T_loc& mu,
4444
const T_scale& sigma) {
45-
return normal_lpdf<T_y, T_loc, T_scale>(y, mu, sigma);
45+
return normal_lpdf<false>(y, mu, sigma);
4646
}
4747

4848
} // namespace math

stan/math/prim/prob/normal_lpdf.hpp

Lines changed: 131 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,123 @@
1919
namespace stan {
2020
namespace math {
2121

22+
enum class ProbReturnType {Scalar, Vector};
23+
24+
template <typename T, typename = void>
25+
struct prob_broadcaster;
26+
27+
template <typename T>
28+
struct prob_broadcaster<T, require_stan_scalar_t<T>> {
29+
T ret_;
30+
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
31+
prob_broadcaster(EigArr&& x) : ret_(sum(std::forward<EigArr>(x))) {}
32+
33+
template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
34+
prob_broadcaster(Scalar&& x) : ret_(x) {}
35+
36+
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
37+
inline auto operator=(EigArr&& x) {
38+
ret_ = sum(x);
39+
return *this;
40+
}
41+
42+
template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
43+
inline auto operator=(Scalar x) {
44+
ret_ = x;
45+
return *this;
46+
}
47+
48+
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
49+
inline auto operator+=(EigArr&& x) {
50+
ret_ += sum(x);
51+
return *this;
52+
}
53+
54+
template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
55+
inline auto operator+=(Scalar&& x) {
56+
ret_ += x;
57+
return *this;
58+
}
59+
60+
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
61+
inline auto operator-=(EigArr&& x) {
62+
ret_ -= sum(x);
63+
return *this;
64+
}
65+
66+
template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
67+
inline auto operator-=(Scalar&& x) {
68+
ret_ -= x;
69+
return *this;
70+
}
71+
inline auto ret() noexcept {
72+
return ret_;
73+
}
74+
template <typename T1>
75+
static auto zero(T1&& /* */) {
76+
return T(0);
77+
}
78+
79+
};
80+
81+
template <typename T>
82+
struct prob_broadcaster<T, require_eigen_t<T>> {
83+
T ret_;
84+
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
85+
prob_broadcaster(EigArr&& x) : ret_(std::forward<EigArr>(x)) {}
86+
87+
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
88+
inline auto operator=(EigArr&& x) {
89+
ret_ = sum(x);
90+
return *this;
91+
}
92+
93+
template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
94+
inline auto operator=(Scalar x) {
95+
ret_ = Eigen::Array<value_type_t<T>, -1, 1>::Constant(x, ret_.size());
96+
return *this;
97+
}
98+
99+
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
100+
inline auto operator+=(EigArr&& x) {
101+
ret_ += x;
102+
return *this;
103+
}
104+
105+
template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
106+
inline auto operator+=(Scalar&& x) {
107+
ret_ += x;
108+
return *this;
109+
}
110+
111+
template <typename EigArr, require_eigen_t<EigArr>* = nullptr>
112+
inline auto operator-=(EigArr&& x) {
113+
ret_ -= x;
114+
return *this;
115+
}
116+
117+
template <typename Scalar, require_stan_scalar_t<Scalar>* = nullptr>
118+
inline auto operator-=(Scalar&& x) {
119+
ret_ -= x;
120+
return *this;
121+
}
122+
123+
inline auto&& ret() noexcept {
124+
return std::move(ret_);
125+
}
126+
127+
template <typename T1>
128+
static auto zero(T1&& size) {
129+
return Eigen::Array<value_type_t<T>, -1, 1>::Constant(0, size).eval();
130+
}
131+
132+
};
133+
134+
135+
136+
template <ProbReturnType ReturnType, typename... Types>
137+
using prob_return_t = prob_broadcaster<std::conditional_t<ReturnType == ProbReturnType::Scalar, return_type_t<Types...>, Eigen::Array<return_type_t<Types...>, -1, 1>>>;
138+
22139
/** \ingroup prob_dists
23140
* The log of the normal density for the specified scalar(s) given
24141
* the specified mean(s) and deviation(s). y, mu, or sigma can
@@ -38,10 +155,10 @@ namespace math {
38155
* @return The log of the product of the densities.
39156
* @throw std::domain_error if the scale is not positive.
40157
*/
41-
template <bool propto, typename T_y, typename T_loc, typename T_scale,
158+
template <bool propto, ProbReturnType RetType = ProbReturnType::Scalar, typename T_y, typename T_loc, typename T_scale,
42159
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
43160
T_y, T_loc, T_scale>* = nullptr>
44-
inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
161+
inline auto normal_lpdf(const T_y& y,
45162
const T_loc& mu,
46163
const T_scale& sigma) {
47164
using T_partials_return = partials_return_t<T_y, T_loc, T_scale>;
@@ -62,12 +179,13 @@ inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
62179
check_not_nan(function, "Random variable", y_val);
63180
check_finite(function, "Location parameter", mu_val);
64181
check_positive(function, "Scale parameter", sigma_val);
65-
182+
using ret_t = prob_return_t<RetType, T_partials_return>;
183+
const size_t N = max_size(y, mu, sigma);
66184
if (size_zero(y, mu, sigma)) {
67-
return 0.0;
185+
return ret_t::zero(N);
68186
}
69187
if (!include_summand<propto, T_y, T_loc, T_scale>::value) {
70-
return 0.0;
188+
return ret_t::zero(N);
71189
}
72190

73191
operands_and_partials<T_y_ref, T_mu_ref, T_sigma_ref> ops_partials(
@@ -79,13 +197,16 @@ inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
79197
const auto& y_scaled_sq
80198
= to_ref_if<!is_constant_all<T_scale>::value>(y_scaled * y_scaled);
81199

82-
size_t N = max_size(y, mu, sigma);
83-
T_partials_return logp = -0.5 * sum(y_scaled_sq);
200+
prob_return_t<RetType, T_partials_return> logp = -0.5 * y_scaled_sq;
84201
if (include_summand<propto>::value) {
85202
logp += NEG_LOG_SQRT_TWO_PI * N;
86203
}
87204
if (include_summand<propto, T_scale>::value) {
88-
logp -= sum(log(sigma_val)) * N / size(sigma);
205+
if (RetType == ProbReturnType::Scalar) {
206+
logp -= sum(log(sigma_val)) * N / size(sigma);
207+
} else {
208+
logp -= log(sigma_val);
209+
}
89210
}
90211

91212
if (!is_constant_all<T_y, T_scale, T_loc>::value) {
@@ -103,11 +224,11 @@ inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
103224
ops_partials.edge2_.partials_ = std::move(scaled_diff);
104225
}
105226
}
106-
return ops_partials.build(logp);
227+
return ops_partials.build(logp.ret());
107228
}
108229

109230
template <typename T_y, typename T_loc, typename T_scale>
110-
inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
231+
inline auto normal_lpdf(const T_y& y,
111232
const T_loc& mu,
112233
const T_scale& sigma) {
113234
return normal_lpdf<false>(y, mu, sigma);

stan/math/rev/functor/operands_and_partials.hpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,41 @@ inline void update_adjoints(StdVec1& x, const Vec2& y, const vari& z) {
7474
}
7575
}
7676

77+
template <typename T1, typename T2, typename T3,
78+
require_all_kernel_expressions_and_none_scalar_t<T1, T2, T3>* = nullptr>
79+
inline void update_adjoints(var_value<T1>& x, const T2& y, const T3& z) {
80+
x.adj() += z.adj() * y;
81+
}
82+
83+
template <typename Scalar1, typename Scalar2, typename T3, require_var_t<Scalar1>* = nullptr,
84+
require_not_var_matrix_t<Scalar1>* = nullptr,
85+
require_arithmetic_t<Scalar2>* = nullptr,
86+
require_eigen_t<T3>* = nullptr>
87+
inline void update_adjoints(Scalar1 x, Scalar2 y, const T3& z) noexcept {
88+
x.adj() += sum(z.adj() * y);
89+
}
90+
template <typename Matrix1, typename Matrix2, typename T3
91+
require_rev_matrix_t<Matrix1>* = nullptr,
92+
require_st_arithmetic<Matrix2>* = nullptr,
93+
require_eigen_t<T3>* = nullptr>
94+
inline void update_adjoints(Matrix1& x, const Matrix2& y, const T3& z) {
95+
x.adj().array() += z.adj() * y.array();
96+
}
97+
98+
template <typename Arith, typename Alt, require_st_arithmetic<Arith>* = nullptr, typename T3, require_eigen_t<T3>* = nullptr>
99+
inline constexpr void update_adjoints(Arith&& /* x */, Alt&& /* y */,
100+
const T3& /* z */) noexcept {}
101+
102+
template <typename StdVec1, typename Vec2, typename T3,
103+
require_std_vector_t<StdVec1>* = nullptr,
104+
require_st_arithmetic<Vec2>* = nullptr,
105+
require_eigen_t<T3>* = nullptr>
106+
inline void update_adjoints(StdVec1& x, const Vec2& y, const T3& z) {
107+
for (size_t i = 0; i < x.size(); ++i) {
108+
update_adjoints(x[i], y[i], z[i]);
109+
}
110+
}
111+
77112
} // namespace internal
78113

79114
/** \ingroup type_trait
@@ -169,6 +204,33 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, var> {
169204
}
170205
});
171206
}
207+
template <typename T, require_eigen_t<T>* = nullptr>
208+
auto build(T&& value) {
209+
arena_t<promote_scalar_t<var, T>> ret = value;
210+
reverse_pass_callback([ret, operand1 = edge1_.operand(), partial1 = edge1_.partial(),
211+
operand2 = edge2_.operand(), partial2 = edge2_.partial(),
212+
operand3 = edge3_.operand(), partial3 = edge3_.partial(),
213+
operand4 = edge4_.operand(), partial4 = edge4_.partial(),
214+
operand5 = edge5_.operand(),
215+
partial5 = edge5_.partial()]() mutable {
216+
if (!is_constant<Op1>::value) {
217+
internal::update_adjoints(operand1, partial1, vi);
218+
}
219+
if (!is_constant<Op2>::value) {
220+
internal::update_adjoints(operand2, partial2, vi);
221+
}
222+
if (!is_constant<Op3>::value) {
223+
internal::update_adjoints(operand3, partial3, vi);
224+
}
225+
if (!is_constant<Op4>::value) {
226+
internal::update_adjoints(operand4, partial4, vi);
227+
}
228+
if (!is_constant<Op5>::value) {
229+
internal::update_adjoints(operand5, partial5, vi);
230+
}
231+
});
232+
return plain_type_t<promote_scalar_t<var, T>>(ret);
233+
}
172234
};
173235

174236
namespace internal {

test/unit/math/prim/prob/normal_log_test.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,20 @@ TEST(ProbNormal, log_matches_lpdf) {
1313
EXPECT_FLOAT_EQ((stan::math::normal_lpdf<false>(y, mu, sigma)),
1414
(stan::math::normal_log<false>(y, mu, sigma)));
1515
EXPECT_FLOAT_EQ(
16-
(stan::math::normal_lpdf<true, double, double, double>(y, mu, sigma)),
17-
(stan::math::normal_log<true, double, double, double>(y, mu, sigma)));
16+
(stan::math::normal_lpdf<true>(y, mu, sigma)),
17+
(stan::math::normal_log<true>(y, mu, sigma)));
1818
EXPECT_FLOAT_EQ(
19-
(stan::math::normal_lpdf<false, double, double, double>(y, mu, sigma)),
20-
(stan::math::normal_log<false, double, double, double>(y, mu, sigma)));
19+
(stan::math::normal_lpdf<false>(y, mu, sigma)),
20+
(stan::math::normal_log<false>(y, mu, sigma)));
2121
EXPECT_FLOAT_EQ(
22-
(stan::math::normal_lpdf<double, double, double>(y, mu, sigma)),
23-
(stan::math::normal_log<double, double, double>(y, mu, sigma)));
22+
(stan::math::normal_lpdf(y, mu, sigma)),
23+
(stan::math::normal_log(y, mu, sigma)));
24+
}
25+
26+
TEST(ProbNormal, test_vlpdf) {
27+
Eigen::Matrix<double, -1, 1> Y = Eigen::Matrix<double, -1, 1>::Random(5);
28+
Eigen::Matrix<double, -1, 1> Mu = Eigen::Matrix<double, -1, 1>::Random(5);
29+
Eigen::Matrix<double, -1, 1> Sigma = stan::math::abs(Eigen::Matrix<double, -1, 1>::Random(5));
30+
Eigen::Matrix<double, -1, 1> A = stan::math::normal_lpdf<false, stan::math::ProbReturnType::Vector>(Y, Mu, Sigma);
31+
2432
}

0 commit comments

Comments
 (0)