Skip to content

Commit 4d2b936

Browse files
authored
Merge pull request #2829 from andrjohns/log_sum_exp-signed
Log sum exp signed
2 parents 059e21f + 85a2ea1 commit 4d2b936

File tree

4 files changed

+106
-0
lines changed

4 files changed

+106
-0
lines changed

stan/math/prim/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@
197197
#include <stan/math/prim/fun/log_rising_factorial.hpp>
198198
#include <stan/math/prim/fun/log_softmax.hpp>
199199
#include <stan/math/prim/fun/log_sum_exp.hpp>
200+
#include <stan/math/prim/fun/log_sum_exp_signed.hpp>
200201
#include <stan/math/prim/fun/logical_and.hpp>
201202
#include <stan/math/prim/fun/logical_eq.hpp>
202203
#include <stan/math/prim/fun/logical_gt.hpp>
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#ifndef STAN_MATH_PRIM_FUN_LOG_SUM_EXP_SIGNED_HPP
2+
#define STAN_MATH_PRIM_FUN_LOG_SUM_EXP_SIGNED_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/fun/constants.hpp>
6+
#include <stan/math/prim/fun/Eigen.hpp>
7+
#include <stan/math/prim/fun/log1p_exp.hpp>
8+
#include <cmath>
9+
#include <vector>
10+
11+
namespace stan {
12+
namespace math {
13+
14+
/**
15+
* Calculates the log sum of exponentials without overflow,
16+
* accounting for the signs of the inputs
17+
*
18+
* @tparam T1 type of the first variable
19+
* @tparam T2 type of the second variable
20+
* @param a the first variable
21+
* @param a_sign sign of the first variable
22+
* @param b the second variable
23+
* @param b_sign sign of the second variable
24+
*/
25+
template <typename T1, typename T2,
26+
require_all_stan_scalar_t<T1, T2>* = nullptr>
27+
inline std::tuple<return_type_t<T1, T2>, int> log_sum_exp_signed(const T1& a,
28+
int a_sign,
29+
const T2& b,
30+
int b_sign) {
31+
if (a_sign == b_sign) {
32+
return std::make_tuple(log_sum_exp(a, b), a_sign);
33+
}
34+
bool a_larger = (a > b);
35+
return std::make_tuple(a_larger ? log_diff_exp(a, b) : log_diff_exp(b, a),
36+
a_larger ? a_sign : b_sign);
37+
}
38+
39+
} // namespace math
40+
} // namespace stan
41+
42+
#endif
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
#include <limits>
3+
#include <vector>
4+
5+
TEST(mathMixScalFun, logSumExp_signed) {
6+
auto f = [](const int x1_sign, const int x2_sign) {
7+
return [=](const auto& x1, const auto& x2) {
8+
stan::return_type_t<decltype(x1), decltype(x2)> ret_val;
9+
int ret_val_sign;
10+
std::forward_as_tuple(ret_val, ret_val_sign)
11+
= stan::math::log_sum_exp_signed(x1, x1_sign, x2, x2_sign);
12+
return ret_val_sign * stan::math::exp(ret_val);
13+
};
14+
};
15+
std::vector<double> a{0.15, 0.35, 0.51, 0.65, 0.89, 1.0};
16+
std::vector<double> b{1.4, 1.2, 2.0, 3.0, 3.21, 3.4};
17+
18+
for (auto&& a_val : a) {
19+
for (auto&& b_val : b) {
20+
stan::test::expect_ad(f(1, 1), a_val, b_val);
21+
stan::test::expect_ad(f(1, -1), a_val, b_val);
22+
stan::test::expect_ad(f(-1, 1), a_val, b_val);
23+
stan::test::expect_ad(f(-1, -1), a_val, b_val);
24+
}
25+
}
26+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#include <stan/math/prim.hpp>
2+
#include <gtest/gtest.h>
3+
#include <cmath>
4+
#include <tuple>
5+
#include <vector>
6+
7+
TEST(MathFunctions, log_sum_exp_signed) {
8+
using stan::math::exp;
9+
using stan::math::log_diff_exp;
10+
using stan::math::log_sum_exp;
11+
using stan::math::log_sum_exp_signed;
12+
using stan::math::sign;
13+
14+
double a = 2.5;
15+
double b = 76.2;
16+
double exp_a = exp(a);
17+
double exp_b = exp(b);
18+
19+
double val;
20+
int val_sign;
21+
22+
std::forward_as_tuple(val, val_sign) = log_sum_exp_signed(a, 1, b, 1);
23+
24+
EXPECT_FLOAT_EQ(exp_a + exp_b, val_sign * exp(val));
25+
26+
std::forward_as_tuple(val, val_sign) = log_sum_exp_signed(a, 1, b, -1);
27+
28+
EXPECT_FLOAT_EQ(exp_a - exp_b, val_sign * exp(val));
29+
30+
std::forward_as_tuple(val, val_sign) = log_sum_exp_signed(a, -1, b, 1);
31+
32+
EXPECT_FLOAT_EQ(-exp_a + exp_b, val_sign * exp(val));
33+
34+
std::forward_as_tuple(val, val_sign) = log_sum_exp_signed(a, -1, b, -1);
35+
36+
EXPECT_FLOAT_EQ(-exp_a - exp_b, val_sign * exp(val));
37+
}

0 commit comments

Comments
 (0)