Skip to content

Commit 2e174d2

Browse files
authored
Merge pull request #2792 from andrjohns/add-hyper-2f1
Expose hypergeometric_2F1 function
2 parents 231cbd0 + 1475b99 commit 2e174d2

16 files changed

+973
-399
lines changed

stan/math/fwd/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include <stan/math/fwd/fun/gamma_p.hpp>
4242
#include <stan/math/fwd/fun/gamma_q.hpp>
4343
#include <stan/math/fwd/fun/grad_inc_beta.hpp>
44+
#include <stan/math/fwd/fun/hypergeometric_2F1.hpp>
4445
#include <stan/math/fwd/fun/hypergeometric_pFq.hpp>
4546
#include <stan/math/fwd/fun/hypot.hpp>
4647
#include <stan/math/fwd/fun/inc_beta.hpp>

stan/math/fwd/fun/grad_inc_beta.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@ void grad_inc_beta(fvar<T>& g1, fvar<T>& g2, fvar<T> a, fvar<T> b, fvar<T> z) {
4141
fvar<T> dF1 = 0;
4242
fvar<T> dF2 = 0;
4343
fvar<T> dF3 = 0;
44+
fvar<T> dFz = 0;
4445

45-
if (value_of(value_of(C))) {
46-
grad_2F1(dF1, dF2, dF3, a + b, fvar<T>(1.0), a + 1, z);
46+
if (value_of_rec(C)) {
47+
std::forward_as_tuple(dF1, dF2, dF3, dFz)
48+
= grad_2F1<true>(a + b, fvar<T>(1.0), a + 1, z);
4749
}
4850

4951
g1 = (c1 - 1.0 / a) * c3 + C * (dF1 + dF3);
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#ifndef STAN_MATH_FWD_FUN_HYPERGEOMETRIC_2F1_HPP
2+
#define STAN_MATH_FWD_FUN_HYPERGEOMETRIC_2F1_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/fwd/core.hpp>
6+
#include <stan/math/prim/fun/hypergeometric_2F1.hpp>
7+
#include <stan/math/prim/fun/grad_2F1.hpp>
8+
9+
namespace stan {
10+
namespace math {
11+
12+
/**
13+
* Returns the Gauss hypergeometric function applied to the
14+
* input arguments:
15+
* \f$_2F_1(a_1,a_2;b;z)\f$
16+
*
17+
* See 'grad_2F1.hpp' for the derivatives wrt each parameter
18+
*
19+
* @tparam Ta1 Type of scalar first 'a' argument
20+
* @tparam Ta2 Type of scalar second 'a' argument
21+
* @tparam Tb Type of scalar 'b' argument
22+
* @tparam Tz Type of scalar 'z' argument
23+
* @param[in] a1 First of 'a' arguments to function
24+
* @param[in] a2 Second of 'a' arguments to function
25+
* @param[in] b 'b' argument to function
26+
* @param[in] z Scalar z argument
27+
* @return Gauss hypergeometric function
28+
*/
29+
template <typename Ta1, typename Ta2, typename Tb, typename Tz,
30+
require_all_stan_scalar_t<Ta1, Ta2, Tb, Tz>* = nullptr,
31+
require_any_fvar_t<Ta1, Ta2, Tb, Tz>* = nullptr>
32+
inline return_type_t<Ta1, Ta1, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
33+
const Ta2& a2,
34+
const Tb& b,
35+
const Tz& z) {
36+
using fvar_t = return_type_t<Ta1, Ta1, Tb, Tz>;
37+
38+
auto a1_val = value_of(a1);
39+
auto a2_val = value_of(a2);
40+
auto b_val = value_of(b);
41+
auto z_val = value_of(z);
42+
43+
auto grad_tuple = grad_2F1(a1, a2, b, z);
44+
45+
typename fvar_t::Scalar grad = 0;
46+
47+
if (!is_constant<Ta1>::value) {
48+
grad += forward_as<fvar_t>(a1).d() * std::get<0>(grad_tuple);
49+
}
50+
if (!is_constant<Ta2>::value) {
51+
grad += forward_as<fvar_t>(a2).d() * std::get<1>(grad_tuple);
52+
}
53+
if (!is_constant<Tb>::value) {
54+
grad += forward_as<fvar_t>(b).d() * std::get<2>(grad_tuple);
55+
}
56+
if (!is_constant<Tz>::value) {
57+
grad += forward_as<fvar_t>(z).d() * std::get<3>(grad_tuple);
58+
}
59+
60+
return fvar_t(hypergeometric_2F1(a1_val, a2_val, b_val, z_val), grad);
61+
}
62+
63+
} // namespace math
64+
} // namespace stan
65+
#endif

stan/math/prim/fun.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,9 @@
129129
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
130130
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
131131
#include <stan/math/prim/fun/head.hpp>
132-
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
132+
#include <stan/math/prim/fun/hypergeometric_2F1.hpp>
133133
#include <stan/math/prim/fun/hypergeometric_2F2.hpp>
134+
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
134135
#include <stan/math/prim/fun/hypot.hpp>
135136
#include <stan/math/prim/fun/identity_constrain.hpp>
136137
#include <stan/math/prim/fun/identity_free.hpp>

0 commit comments

Comments
 (0)