Skip to content

Commit 059e21f

Browse files
authored
Merge pull request #2806 from stan-dev/feature/log_inv_logit
adds varmat for log_inv_logit
2 parents b3da945 + 6cd1dbe commit 059e21f

File tree

3 files changed

+41
-3
lines changed

3 files changed

+41
-3
lines changed

stan/math/prim/fun/log_inv_logit.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ struct log_inv_logit_fun {
7979
* @return elementwise log_inv_logit of members of container
8080
*/
8181
template <typename T,
82-
require_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
82+
require_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
83+
require_not_var_matrix_t<T>* = nullptr>
8384
inline auto log_inv_logit(const T& x) {
8485
return apply_scalar_unary<log_inv_logit_fun, T>::apply(x);
8586
}

stan/math/rev/fun/log_inv_logit.hpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,32 @@ namespace math {
1313
* Return the natural logarithm of the inverse logit of the
1414
* specified argument.
1515
*
16-
* @param u argument
16+
* @tparam T An arithmetic type
17+
* @param u `var_value` with inner arithmetic type
1718
* @return log inverse logit of the argument
1819
*/
19-
inline var log_inv_logit(const var& u) {
20+
template <typename T, require_arithmetic_t<T>* = nullptr>
21+
inline auto log_inv_logit(const var_value<T>& u) {
2022
return make_callback_var(log_inv_logit(u.val()), [u](auto& vi) mutable {
2123
u.adj() += vi.adj() * inv_logit(-u.val());
2224
});
2325
}
2426

27+
/**
28+
* Return the natural logarithm of the inverse logit of the
29+
* specified argument.
30+
*
31+
* @tparam T A type derived from `Eigen::EigenBase`
32+
* @param u `var_value` with inner Eigen type
33+
* @return log inverse logit of the argument
34+
*/
35+
template <typename T, require_eigen_t<T>* = nullptr>
36+
inline auto log_inv_logit(const var_value<T>& u) {
37+
return make_callback_var(log_inv_logit(u.val()), [u](auto& vi) mutable {
38+
u.adj().array() += vi.adj().array() * inv_logit(-u.val()).array();
39+
});
40+
}
41+
2542
} // namespace math
2643
} // namespace stan
2744
#endif

test/unit/math/mix/fun/log_inv_logit_test.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,23 @@ TEST(mathMixMatFun, logInvLogit) {
55
stan::test::expect_common_unary_vectorized(f);
66
stan::test::expect_unary_vectorized(f, -1.0, -0.5, 0.5, 1.3, 5);
77
}
8+
9+
TEST(mathMixMatFun, logInvLogitVarMat) {
10+
using stan::math::vec_concat;
11+
using stan::test::expect_ad_vector_matvar;
12+
using stan::test::internal::common_nonzero_args;
13+
auto f = [](const auto& x1) {
14+
using stan::math::acos;
15+
return stan::math::log_inv_logit(x1);
16+
};
17+
std::vector<double> com_args = common_nonzero_args();
18+
std::vector<double> args{
19+
-2.2, -0.8, 0.5, 1 + std::numeric_limits<double>::epsilon(),
20+
1.5, 3, 3.4, 4};
21+
auto all_args = vec_concat(com_args, args);
22+
Eigen::VectorXd A(all_args.size());
23+
for (int i = 0; i < all_args.size(); ++i) {
24+
A(i) = all_args[i];
25+
}
26+
expect_ad_vector_matvar(f, A);
27+
}

0 commit comments

Comments
 (0)