Skip to content

Commit c4b0717

Browse files
authored
Merge pull request #2797 from andrjohns/hyper3f2-expose
Expose `hypergeometric_3F2` function
2 parents 4d2b936 + 4575585 commit c4b0717

File tree

15 files changed

+289
-185
lines changed

15 files changed

+289
-185
lines changed

stan/math/fwd/fun/inv_inc_beta.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include <stan/math/prim/fun/lbeta.hpp>
1212
#include <stan/math/prim/fun/lgamma.hpp>
1313
#include <stan/math/prim/fun/digamma.hpp>
14-
#include <stan/math/prim/fun/F32.hpp>
14+
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
1515

1616
namespace stan {
1717
namespace math {
@@ -53,19 +53,21 @@ inline fvar<partials_return_t<T1, T2, T3>> inv_inc_beta(const T1& a,
5353
T_return inv_d_(0);
5454

5555
if (is_fvar<T1>::value) {
56+
std::vector<T_return> da_a{a_val, a_val, one_m_b};
57+
std::vector<T_return> da_b{ap1, ap1};
5658
auto da1 = exp(one_m_b * log1m_w + one_m_a * log_w);
57-
auto da2
58-
= exp(a_val * log_w + 2 * lgamma(a_val)
59-
+ log(F32(a_val, a_val, one_m_b, ap1, ap1, w)) - 2 * lgamma(ap1));
59+
auto da2 = exp(a_val * log_w + 2 * lgamma(a_val)
60+
+ log(hypergeometric_3F2(da_a, da_b, w)) - 2 * lgamma(ap1));
6061
auto da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab)
6162
* (log_w - digamma(a_val) + digamma_apb);
6263
inv_d_ += forward_as<fvar<T_return>>(a).d_ * da1 * (da2 - da3);
6364
}
6465

6566
if (is_fvar<T2>::value) {
67+
std::vector<T_return> db_a{b_val, b_val, one_m_a};
68+
std::vector<T_return> db_b{bp1, bp1};
6669
auto db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w);
67-
auto db2 = 2 * lgamma(b_val)
68-
+ log(F32(b_val, b_val, one_m_a, bp1, bp1, one_m_w))
70+
auto db2 = 2 * lgamma(b_val) + log(hypergeometric_3F2(db_a, db_b, one_m_w))
6971
- 2 * lgamma(bp1) + b_val * log1m_w;
7072

7173
auto db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab)

stan/math/prim/fun.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@
9494
#include <stan/math/prim/fun/exp.hpp>
9595
#include <stan/math/prim/fun/exp2.hpp>
9696
#include <stan/math/prim/fun/expm1.hpp>
97-
#include <stan/math/prim/fun/F32.hpp>
9897
#include <stan/math/prim/fun/fabs.hpp>
9998
#include <stan/math/prim/fun/factor_U.hpp>
10099
#include <stan/math/prim/fun/factor_cov_matrix.hpp>
@@ -132,6 +131,7 @@
132131
#include <stan/math/prim/fun/head.hpp>
133132
#include <stan/math/prim/fun/hypergeometric_2F1.hpp>
134133
#include <stan/math/prim/fun/hypergeometric_2F2.hpp>
134+
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
135135
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
136136
#include <stan/math/prim/fun/hypot.hpp>
137137
#include <stan/math/prim/fun/identity_constrain.hpp>

stan/math/prim/fun/F32.hpp

Lines changed: 0 additions & 99 deletions
This file was deleted.

stan/math/prim/fun/grad_pFq.hpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,18 @@ void grad_pFq_impl(TupleT&& grad_tuple, const Ta& a, const Tb& b, const Tz& z,
234234
log_phammer_1n += log1p(n);
235235
log_phammer_2_mpn += log(2 + m + n);
236236

237-
log_phammer_ap1_n += log(stan::math::fabs(ap1n));
238-
log_phammer_bp1_n += log(stan::math::fabs(bp1n));
239-
log_phammer_an += log(stan::math::fabs(an));
240-
log_phammer_bn += log(stan::math::fabs(bn));
241-
log_phammer_ap1_mpn += log(stan::math::fabs(ap1mn));
242-
log_phammer_bp1_mpn += log(stan::math::fabs(bp1mn));
237+
log_phammer_ap1_n.array()
238+
+= log(math::fabs((ap1n.array() == 0).select(1.0, ap1n.array())));
239+
log_phammer_bp1_n.array()
240+
+= log(math::fabs((bp1n.array() == 0).select(1.0, bp1n.array())));
241+
log_phammer_an.array()
242+
+= log(math::fabs((an.array() == 0).select(1.0, an.array())));
243+
log_phammer_bn.array()
244+
+= log(math::fabs((bn.array() == 0).select(1.0, bn.array())));
245+
log_phammer_ap1_mpn.array()
246+
+= log(math::fabs((ap1mn.array() == 0).select(1.0, ap1mn.array())));
247+
log_phammer_bp1_mpn.array()
248+
+= log(math::fabs((bp1mn.array() == 0).select(1.0, bp1mn.array())));
243249

244250
z_pow_mn_sign *= z_sign;
245251
log_phammer_ap1n_sign.array() *= sign(value_of_rec(ap1n)).array();
@@ -266,9 +272,9 @@ void grad_pFq_impl(TupleT&& grad_tuple, const Ta& a, const Tb& b, const Tz& z,
266272
log_z_m += log_z;
267273
log_phammer_1m += log1p(m);
268274
log_phammer_2m += log(2 + m);
269-
log_phammer_ap1_m += log(stan::math::fabs(ap1m));
275+
log_phammer_ap1_m += log(math::fabs(ap1m));
270276
log_phammer_ap1m_sign.array() *= sign(value_of_rec(ap1m)).array();
271-
log_phammer_bp1_m += log(stan::math::fabs(bp1m));
277+
log_phammer_bp1_m += log(math::fabs(bp1m));
272278
log_phammer_bp1m_sign.array() *= sign(value_of_rec(bp1m)).array();
273279

274280
m += 1;
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
#ifndef STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_3F2_HPP
2+
#define STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_3F2_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/append_row.hpp>
7+
#include <stan/math/prim/fun/as_array_or_scalar.hpp>
8+
#include <stan/math/prim/fun/to_vector.hpp>
9+
#include <stan/math/prim/fun/constants.hpp>
10+
#include <stan/math/prim/fun/fabs.hpp>
11+
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
12+
#include <stan/math/prim/fun/sum.hpp>
13+
#include <stan/math/prim/fun/sign.hpp>
14+
#include <stan/math/prim/fun/value_of_rec.hpp>
15+
16+
namespace stan {
17+
namespace math {
18+
namespace internal {
19+
template <typename Ta, typename Tb, typename Tz,
20+
typename T_return = return_type_t<Ta, Tb, Tz>,
21+
typename ArrayAT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
22+
typename ArrayBT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
23+
require_all_vector_t<Ta, Tb>* = nullptr,
24+
require_stan_scalar_t<Tz>* = nullptr>
25+
T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
26+
double precision = 1e-6,
27+
int max_steps = 1e5) {
28+
ArrayAT a_array = as_array_or_scalar(a);
29+
ArrayBT b_array = append_row(as_array_or_scalar(b), 1.0);
30+
check_3F2_converges("hypergeometric_3F2", a_array[0], a_array[1], a_array[2],
31+
b_array[0], b_array[1], z);
32+
33+
T_return t_acc = 1.0;
34+
T_return log_t = 0.0;
35+
T_return log_z = log(fabs(z));
36+
Eigen::ArrayXi a_signs = sign(value_of_rec(a_array));
37+
Eigen::ArrayXi b_signs = sign(value_of_rec(b_array));
38+
plain_type_t<decltype(a_array)> apk = a_array;
39+
plain_type_t<decltype(b_array)> bpk = b_array;
40+
int z_sign = sign(value_of_rec(z));
41+
int t_sign = z_sign * a_signs.prod() * b_signs.prod();
42+
43+
int k = 0;
44+
while (k <= max_steps && log_t >= log(precision)) {
45+
// Replace zero values with 1 prior to taking the log so that we accumulate
46+
// 0.0 rather than -inf
47+
const auto& abs_apk = math::fabs((apk == 0).select(1.0, apk));
48+
const auto& abs_bpk = math::fabs((bpk == 0).select(1.0, bpk));
49+
T_return p = sum(log(abs_apk)) - sum(log(abs_bpk));
50+
if (p == NEGATIVE_INFTY) {
51+
return t_acc;
52+
}
53+
54+
log_t += p + log_z;
55+
t_acc += t_sign * exp(log_t);
56+
57+
if (is_inf(t_acc)) {
58+
throw_domain_error("hypergeometric_3F2", "sum (output)", t_acc,
59+
"overflow hypergeometric function did not converge.");
60+
}
61+
k++;
62+
apk.array() += 1.0;
63+
bpk.array() += 1.0;
64+
a_signs = sign(value_of_rec(apk));
65+
b_signs = sign(value_of_rec(bpk));
66+
t_sign = a_signs.prod() * b_signs.prod() * t_sign;
67+
}
68+
if (k == max_steps) {
69+
throw_domain_error("hypergeometric_3F2", "k (internal counter)", max_steps,
70+
"exceeded iterations, hypergeometric function did not ",
71+
"converge.");
72+
}
73+
return t_acc;
74+
}
75+
} // namespace internal
76+
77+
/**
78+
* Hypergeometric function (3F2).
79+
*
80+
* Function reference: http://dlmf.nist.gov/16.2
81+
*
82+
* \f[
83+
* _3F_2 \left(
84+
* \begin{matrix}a_1 a_2 a[2] \\ b_1 b_2\end{matrix}; z
85+
* \right) = \sum_k=0^\infty
86+
* \frac{(a_1)_k(a_2)_k(a_3)_k}{(b_1)_k(b_2)_k}\frac{z^k}{k!} \f]
87+
*
88+
* Where $(a_1)_k$ is an upper shifted factorial.
89+
*
90+
* Calculate the hypergeometric function (3F2) as the power series
91+
* directly to within <code>precision</code> or until
92+
* <code>max_steps</code> terms.
93+
*
94+
* This function does not have a closed form but will converge if:
95+
* - <code>|z|</code> is less than 1
96+
* - <code>|z|</code> is equal to one and <code>b[0] + b[1] < a[0] + a[1] +
97+
* a[2]</code> This function is a rational polynomial if
98+
* - <code>a[0]</code>, <code>a[1]</code>, or <code>a[2]</code> is a
99+
* non-positive integer
100+
* This function can be treated as a rational polynomial if
101+
* - <code>b[0]</code> or <code>b[1]</code> is a non-positive integer
102+
* and the series is terminated prior to the final term.
103+
*
104+
* @tparam Ta type of Eigen/Std vector 'a' arguments
105+
* @tparam Tb type of Eigen/Std vector 'b' arguments
106+
* @tparam Tz type of z argument
107+
* @param[in] a Always called with a[1] > 1, a[2] <= 0
108+
* @param[in] b Always called with int b[0] < |a[2]|, <= 1)
109+
* @param[in] z z (is always called with 1 from beta binomial cdfs)
110+
* @param[in] precision precision of the infinite sum. defaults to 1e-6
111+
* @param[in] max_steps number of steps to take. defaults to 1e5
112+
* @return The 3F2 generalized hypergeometric function applied to the
113+
* arguments {a1, a2, a3}, {b1, b2}
114+
*/
115+
template <typename Ta, typename Tb, typename Tz,
116+
require_all_vector_t<Ta, Tb>* = nullptr,
117+
require_stan_scalar_t<Tz>* = nullptr>
118+
auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
119+
check_3F2_converges("hypergeometric_3F2", a[0], a[1], a[2], b[0], b[1], z);
120+
// Boost's pFq throws convergence errors in some cases, fallback to naive
121+
// infinite-sum approach (tests pass for these)
122+
if (z == 1.0 && (sum(b) - sum(a)) < 0.0) {
123+
return internal::hypergeometric_3F2_infsum(a, b, z);
124+
}
125+
return hypergeometric_pFq(to_vector(a), to_vector(b), z);
126+
}
127+
128+
/**
129+
* Hypergeometric function (3F2).
130+
*
131+
* Overload for initializer_list inputs
132+
*
133+
* @tparam Ta type of scalar 'a' arguments
134+
* @tparam Tb type of scalar 'b' arguments
135+
* @tparam Tz type of z argument
136+
* @param[in] a Always called with a[1] > 1, a[2] <= 0
137+
* @param[in] b Always called with int b[0] < |a[2]|, <= 1)
138+
* @param[in] z z (is always called with 1 from beta binomial cdfs)
139+
* @param[in] precision precision of the infinite sum. defaults to 1e-6
140+
* @param[in] max_steps number of steps to take. defaults to 1e5
141+
* @return The 3F2 generalized hypergeometric function applied to the
142+
* arguments {a1, a2, a3}, {b1, b2}
143+
*/
144+
template <typename Ta, typename Tb, typename Tz,
145+
require_all_stan_scalar_t<Ta, Tb, Tz>* = nullptr>
146+
auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
147+
const std::initializer_list<Tb>& b, const Tz& z) {
148+
return hypergeometric_3F2(std::vector<Ta>(a), std::vector<Tb>(b), z);
149+
}
150+
151+
} // namespace math
152+
} // namespace stan
153+
#endif

stan/math/prim/fun/log_sum_exp_signed.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/constants.hpp>
66
#include <stan/math/prim/fun/Eigen.hpp>
7-
#include <stan/math/prim/fun/log1p_exp.hpp>
7+
#include <stan/math/prim/fun/log_diff_exp.hpp>
8+
#include <stan/math/prim/fun/log_sum_exp.hpp>
89
#include <cmath>
910
#include <vector>
1011

stan/math/prim/prob/beta_binomial_cdf.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include <stan/math/prim/fun/constants.hpp>
88
#include <stan/math/prim/fun/digamma.hpp>
99
#include <stan/math/prim/fun/exp.hpp>
10-
#include <stan/math/prim/fun/F32.hpp>
10+
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
1111
#include <stan/math/prim/fun/grad_F32.hpp>
1212
#include <stan/math/prim/fun/lbeta.hpp>
1313
#include <stan/math/prim/fun/max_size.hpp>
@@ -100,8 +100,8 @@ return_type_t<T_size1, T_size2> beta_binomial_cdf(const T_n& n, const T_N& N,
100100
const T_partials_return nu = beta_dbl + N_minus_n - 1;
101101
const T_partials_return one = 1;
102102

103-
const T_partials_return F
104-
= F32(one, mu, 1 - N_minus_n, n_dbl + 2, 1 - nu, one);
103+
const T_partials_return F = hypergeometric_3F2({one, mu, 1 - N_minus_n},
104+
{n_dbl + 2, 1 - nu}, one);
105105

106106
T_partials_return C = lbeta(nu, mu) - lbeta(alpha_dbl, beta_dbl)
107107
- lbeta(N_minus_n, n_dbl + 2);

stan/math/prim/prob/beta_binomial_lccdf.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include <stan/math/prim/fun/constants.hpp>
88
#include <stan/math/prim/fun/digamma.hpp>
99
#include <stan/math/prim/fun/exp.hpp>
10-
#include <stan/math/prim/fun/F32.hpp>
10+
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
1111
#include <stan/math/prim/fun/grad_F32.hpp>
1212
#include <stan/math/prim/fun/lbeta.hpp>
1313
#include <stan/math/prim/fun/log.hpp>
@@ -101,8 +101,8 @@ return_type_t<T_size1, T_size2> beta_binomial_lccdf(const T_n& n, const T_N& N,
101101
const T_partials_return nu = beta_dbl + N_dbl - n_dbl - 1;
102102
const T_partials_return one = 1;
103103

104-
const T_partials_return F
105-
= F32(one, mu, -N_dbl + n_dbl + 1, n_dbl + 2, 1 - nu, one);
104+
const T_partials_return F = hypergeometric_3F2(
105+
{one, mu, -N_dbl + n_dbl + 1}, {n_dbl + 2, 1 - nu}, one);
106106
T_partials_return C = lbeta(nu, mu) - lbeta(alpha_dbl, beta_dbl)
107107
- lbeta(N_dbl - n_dbl, n_dbl + 2);
108108
C = F * exp(C) / (N_dbl + 1);

0 commit comments

Comments
 (0)