File tree Expand file tree Collapse file tree 1 file changed +19
-4
lines changed Expand file tree Collapse file tree 1 file changed +19
-4
lines changed Original file line number Diff line number Diff line change @@ -13,14 +13,29 @@ namespace math {
1313 * Return the natural logarithm of the inverse logit of the
1414 * specified argument.
1515 *
16- * @param u argument
16+ * @tparam T An arithmetic type
17+ * @param u `var_value` with inner arithmetic type
1718 * @return log inverse logit of the argument
1819 */
19- template <typename T>
20+ template <typename T, require_arithmetic_t <T>* = nullptr >
2021inline auto log_inv_logit (const var_value<T>& u) {
2122 return make_callback_var (log_inv_logit (u.val ()), [u](auto & vi) mutable {
22- as_array_or_scalar (u.adj ()) += as_array_or_scalar (vi.adj ())
23- * as_array_or_scalar (inv_logit (-u.val ()));
23+ u.adj () += vi.adj () * inv_logit (-u.val ());
24+ });
25+ }
26+
27+ /* *
28+ * Return the natural logarithm of the inverse logit of the
29+ * specified argument.
30+ *
31+ * @tparam T A type derived from `Eigen::EigenBase`
32+ * @param u `var_value` with inner Eigen type
33+ * @return log inverse logit of the argument
34+ */
35+ template <typename T, require_eigen_t <T>* = nullptr >
36+ inline auto log_inv_logit (const var_value<T>& u) {
37+ return make_callback_var (log_inv_logit (u.val ()), [u](auto & vi) mutable {
38+ u.adj ().array () += vi.adj ().array () * inv_logit (-u.val ()).array ();
2439 });
2540}
2641
You can’t perform that action at this time.
0 commit comments