Skip to content

Commit 864f65f

Browse files
committed
Update doc, use log_sum_exp_signed
1 parent 53a02e3 commit 864f65f

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

stan/math/prim/fun/hypergeometric_3F2.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <stan/math/prim/fun/sum.hpp>
1313
#include <stan/math/prim/fun/sign.hpp>
1414
#include <stan/math/prim/fun/value_of_rec.hpp>
15+
#include <stan/math/prim/fun/log_sum_exp_signed.hpp>
1516

1617
namespace stan {
1718
namespace math {
@@ -30,7 +31,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
3031
check_3F2_converges("hypergeometric_3F2", a_array[0], a_array[1], a_array[2],
3132
b_array[0], b_array[1], z);
3233

33-
T_return t_acc = 1.0;
34+
T_return t_acc = 0.0;
3435
T_return log_t = 0.0;
3536
T_return log_z = log(fabs(z));
3637
Eigen::ArrayXi a_signs = sign(value_of_rec(a_array));
@@ -39,7 +40,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
3940
plain_type_t<decltype(b_array)> bpk = b_array;
4041
int z_sign = sign(value_of_rec(z));
4142
int t_sign = z_sign * a_signs.prod() * b_signs.prod();
42-
43+
int acc_sign = 1;
4344
int k = 0;
4445
while (k <= max_steps && log_t >= log(precision)) {
4546
// Replace zero values with 1 prior to taking the log so that we accumulate
@@ -52,7 +53,8 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
5253
}
5354

5455
log_t += p + log_z;
55-
t_acc += t_sign * exp(log_t);
56+
std::forward_as_tuple(t_acc, acc_sign)
57+
= log_sum_exp_signed(t_acc, acc_sign, log_t, t_sign);
5658

5759
if (is_inf(t_acc)) {
5860
throw_domain_error("hypergeometric_3F2", "sum (output)", t_acc,
@@ -70,7 +72,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
7072
"exceeded iterations, hypergeometric function did not ",
7173
"converge.");
7274
}
73-
return t_acc;
75+
return acc_sign * exp(t_acc);
7476
}
7577
} // namespace internal
7678

@@ -109,7 +111,8 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
109111
* @param[in] z z (is always called with 1 from beta binomial cdfs)
110112
* @param[in] precision precision of the infinite sum. defaults to 1e-6
111113
* @param[in] max_steps number of steps to take. defaults to 1e5
112-
* @return Generalized hypergeometric function applied to the inputs
114+
* The 3F2 generalized hypergeometric function applied to the
115+
* arguments {a1, a2, a3}, {b1, b2}
113116
*/
114117
template <typename Ta, typename Tb, typename Tz,
115118
require_all_vector_t<Ta, Tb>* = nullptr,

stan/math/prim/fun/log_sum_exp_signed.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/constants.hpp>
66
#include <stan/math/prim/fun/Eigen.hpp>
7-
#include <stan/math/prim/fun/log1p_exp.hpp>
7+
#include <stan/math/prim/fun/log_diff_exp.hpp>
8+
#include <stan/math/prim/fun/log_sum_exp.hpp>
89
#include <cmath>
910
#include <vector>
1011

0 commit comments

Comments
 (0)