1919namespace stan {
2020namespace math {
2121
22+ enum class ProbReturnType {Scalar, Vector};
23+
24+ template <typename T, typename = void >
25+ struct prob_broadcaster ;
26+
27+ template <typename T>
28+ struct prob_broadcaster <T, require_stan_scalar_t <T>> {
29+ T ret_;
30+ template <typename EigArr, require_eigen_t <EigArr>* = nullptr >
31+ prob_broadcaster (EigArr&& x) : ret_(sum(std::forward<EigArr>(x))) {}
32+
33+ template <typename Scalar, require_stan_scalar_t <Scalar>* = nullptr >
34+ prob_broadcaster (Scalar&& x) : ret_(x) {}
35+
36+ template <typename EigArr, require_eigen_t <EigArr>* = nullptr >
37+ inline auto operator =(EigArr&& x) {
38+ ret_ = sum (x);
39+ return *this ;
40+ }
41+
42+ template <typename Scalar, require_stan_scalar_t <Scalar>* = nullptr >
43+ inline auto operator =(Scalar x) {
44+ ret_ = x;
45+ return *this ;
46+ }
47+
48+ template <typename EigArr, require_eigen_t <EigArr>* = nullptr >
49+ inline auto operator +=(EigArr&& x) {
50+ ret_ += sum (x);
51+ return *this ;
52+ }
53+
54+ template <typename Scalar, require_stan_scalar_t <Scalar>* = nullptr >
55+ inline auto operator +=(Scalar&& x) {
56+ ret_ += x;
57+ return *this ;
58+ }
59+
60+ template <typename EigArr, require_eigen_t <EigArr>* = nullptr >
61+ inline auto operator -=(EigArr&& x) {
62+ ret_ -= sum (x);
63+ return *this ;
64+ }
65+
66+ template <typename Scalar, require_stan_scalar_t <Scalar>* = nullptr >
67+ inline auto operator -=(Scalar&& x) {
68+ ret_ -= x;
69+ return *this ;
70+ }
71+ inline auto ret () noexcept {
72+ return ret_;
73+ }
74+ template <typename T1>
75+ static auto zero (T1&& /* */ ) {
76+ return T (0 );
77+ }
78+
79+ };
80+
81+ template <typename T>
82+ struct prob_broadcaster <T, require_eigen_t <T>> {
83+ T ret_;
84+ template <typename EigArr, require_eigen_t <EigArr>* = nullptr >
85+ prob_broadcaster (EigArr&& x) : ret_(std::forward<EigArr>(x)) {}
86+
87+ template <typename EigArr, require_eigen_t <EigArr>* = nullptr >
88+ inline auto operator =(EigArr&& x) {
89+ ret_ = sum (x);
90+ return *this ;
91+ }
92+
93+ template <typename Scalar, require_stan_scalar_t <Scalar>* = nullptr >
94+ inline auto operator =(Scalar x) {
95+ ret_ = Eigen::Array<value_type_t <T>, -1 , 1 >::Constant (x, ret_.size ());
96+ return *this ;
97+ }
98+
99+ template <typename EigArr, require_eigen_t <EigArr>* = nullptr >
100+ inline auto operator +=(EigArr&& x) {
101+ ret_ += x;
102+ return *this ;
103+ }
104+
105+ template <typename Scalar, require_stan_scalar_t <Scalar>* = nullptr >
106+ inline auto operator +=(Scalar&& x) {
107+ ret_ += x;
108+ return *this ;
109+ }
110+
111+ template <typename EigArr, require_eigen_t <EigArr>* = nullptr >
112+ inline auto operator -=(EigArr&& x) {
113+ ret_ -= x;
114+ return *this ;
115+ }
116+
117+ template <typename Scalar, require_stan_scalar_t <Scalar>* = nullptr >
118+ inline auto operator -=(Scalar&& x) {
119+ ret_ -= x;
120+ return *this ;
121+ }
122+
123+ inline auto && ret() noexcept {
124+ return std::move (ret_);
125+ }
126+
127+ template <typename T1>
128+ static auto zero (T1&& size) {
129+ return Eigen::Array<value_type_t <T>, -1 , 1 >::Constant (0 , size).eval ();
130+ }
131+
132+ };
133+
134+
135+
136+ template <ProbReturnType ReturnType, typename ... Types>
137+ using prob_return_t = prob_broadcaster<std::conditional_t <ReturnType == ProbReturnType::Scalar, return_type_t <Types...>, Eigen::Array<return_type_t <Types...>, -1 , 1 >>>;
138+
22139/* * \ingroup prob_dists
23140 * The log of the normal density for the specified scalar(s) given
24141 * the specified mean(s) and deviation(s). y, mu, or sigma can
@@ -38,10 +155,10 @@ namespace math {
38155 * @return The log of the product of the densities.
39156 * @throw std::domain_error if the scale is not positive.
40157 */
41- template <bool propto, typename T_y, typename T_loc, typename T_scale,
158+ template <bool propto, ProbReturnType RetType = ProbReturnType::Scalar, typename T_y, typename T_loc, typename T_scale,
42159 require_all_not_nonscalar_prim_or_rev_kernel_expression_t <
43160 T_y, T_loc, T_scale>* = nullptr >
44- inline return_type_t <T_y, T_loc, T_scale> normal_lpdf (const T_y& y,
161+ inline auto normal_lpdf (const T_y& y,
45162 const T_loc& mu,
46163 const T_scale& sigma) {
47164 using T_partials_return = partials_return_t <T_y, T_loc, T_scale>;
@@ -62,12 +179,13 @@ inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
62179 check_not_nan (function, " Random variable" , y_val);
63180 check_finite (function, " Location parameter" , mu_val);
64181 check_positive (function, " Scale parameter" , sigma_val);
65-
182+ using ret_t = prob_return_t <RetType, T_partials_return>;
183+ const size_t N = max_size (y, mu, sigma);
66184 if (size_zero (y, mu, sigma)) {
67- return 0.0 ;
185+ return ret_t::zero (N) ;
68186 }
69187 if (!include_summand<propto, T_y, T_loc, T_scale>::value) {
70- return 0.0 ;
188+ return ret_t::zero (N) ;
71189 }
72190
73191 operands_and_partials<T_y_ref, T_mu_ref, T_sigma_ref> ops_partials (
@@ -79,13 +197,16 @@ inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
79197 const auto & y_scaled_sq
80198 = to_ref_if<!is_constant_all<T_scale>::value>(y_scaled * y_scaled);
81199
82- size_t N = max_size (y, mu, sigma);
83- T_partials_return logp = -0.5 * sum (y_scaled_sq);
200+ prob_return_t <RetType, T_partials_return> logp = -0.5 * y_scaled_sq;
84201 if (include_summand<propto>::value) {
85202 logp += NEG_LOG_SQRT_TWO_PI * N;
86203 }
87204 if (include_summand<propto, T_scale>::value) {
88- logp -= sum (log (sigma_val)) * N / size (sigma);
205+ if (RetType == ProbReturnType::Scalar) {
206+ logp -= sum (log (sigma_val)) * N / size (sigma);
207+ } else {
208+ logp -= log (sigma_val);
209+ }
89210 }
90211
91212 if (!is_constant_all<T_y, T_scale, T_loc>::value) {
@@ -103,11 +224,11 @@ inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
103224 ops_partials.edge2_ .partials_ = std::move (scaled_diff);
104225 }
105226 }
106- return ops_partials.build (logp);
227+ return ops_partials.build (logp. ret () );
107228}
108229
109230template <typename T_y, typename T_loc, typename T_scale>
110- inline return_type_t <T_y, T_loc, T_scale> normal_lpdf (const T_y& y,
231+ inline auto normal_lpdf (const T_y& y,
111232 const T_loc& mu,
112233 const T_scale& sigma) {
113234 return normal_lpdf<false >(y, mu, sigma);
0 commit comments