Skip to content

Commit 6cd1dbe

Browse files
committed
update requires to seperate opencl and eigen var_value cases
1 parent 0d111e3 commit 6cd1dbe

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

stan/math/rev/fun/log_inv_logit.hpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,29 @@ namespace math {
1313
* Return the natural logarithm of the inverse logit of the
1414
* specified argument.
1515
*
16-
* @param u argument
16+
* @tparam T An arithmetic type
17+
* @param u `var_value` with inner arithmetic type
1718
* @return log inverse logit of the argument
1819
*/
19-
template <typename T>
20+
template <typename T, require_arithmetic_t<T>* = nullptr>
2021
inline auto log_inv_logit(const var_value<T>& u) {
2122
return make_callback_var(log_inv_logit(u.val()), [u](auto& vi) mutable {
22-
as_array_or_scalar(u.adj()) += as_array_or_scalar(vi.adj())
23-
* as_array_or_scalar(inv_logit(-u.val()));
23+
u.adj() += vi.adj() * inv_logit(-u.val());
24+
});
25+
}
26+
27+
/**
28+
* Return the natural logarithm of the inverse logit of the
29+
* specified argument.
30+
*
31+
* @tparam T A type derived from `Eigen::EigenBase`
32+
* @param u `var_value` with inner Eigen type
33+
* @return log inverse logit of the argument
34+
*/
35+
template <typename T, require_eigen_t<T>* = nullptr>
36+
inline auto log_inv_logit(const var_value<T>& u) {
37+
return make_callback_var(log_inv_logit(u.val()), [u](auto& vi) mutable {
38+
u.adj().array() += vi.adj().array() * inv_logit(-u.val()).array();
2439
});
2540
}
2641

0 commit comments

Comments
 (0)