Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@
#include <stan/math/prim/fun/log_rising_factorial.hpp>
#include <stan/math/prim/fun/log_softmax.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/fun/log_sum_exp_signed.hpp>
#include <stan/math/prim/fun/logical_and.hpp>
#include <stan/math/prim/fun/logical_eq.hpp>
#include <stan/math/prim/fun/logical_gt.hpp>
Expand Down
42 changes: 42 additions & 0 deletions stan/math/prim/fun/log_sum_exp_signed.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#ifndef STAN_MATH_PRIM_FUN_LOG_SUM_EXP_SIGNED_HPP
#define STAN_MATH_PRIM_FUN_LOG_SUM_EXP_SIGNED_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <cmath>
#include <vector>

namespace stan {
namespace math {

/**
* Calculates the log sum of exponentials without overflow,
* accounting for the signs of the inputs
*
* @tparam T1 type of the first variable
* @tparam T2 type of the second variable
* @param a the first variable
* @param a_sign sign of the first variable
* @param b the second variable
* @param b_sign sign of the second variable
*/
template <typename T1, typename T2,
require_all_stan_scalar_t<T1, T2>* = nullptr>
inline std::tuple<return_type_t<T1, T2>, int> log_sum_exp_signed(const T1& a,
int a_sign,
const T2& b,
int b_sign) {
if (a_sign == b_sign) {
return std::make_tuple(log_sum_exp(a, b), a_sign);
}
bool a_larger = (a > b);
return std::make_tuple(a_larger ? log_diff_exp(a, b) : log_diff_exp(b, a),
a_larger ? a_sign : b_sign);
}

} // namespace math
} // namespace stan

#endif
26 changes: 26 additions & 0 deletions test/unit/math/mix/fun/log_sum_exp_signed_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include <test/unit/math/test_ad.hpp>
#include <limits>
#include <vector>

TEST(mathMixScalFun, logSumExp_signed) {
auto f = [](const int x1_sign, const int x2_sign) {
return [=](const auto& x1, const auto& x2) {
stan::return_type_t<decltype(x1), decltype(x2)> ret_val;
int ret_val_sign;
std::forward_as_tuple(ret_val, ret_val_sign)
= stan::math::log_sum_exp_signed(x1, x1_sign, x2, x2_sign);
return ret_val_sign * stan::math::exp(ret_val);
};
};
std::vector<double> a{0.15, 0.35, 0.51, 0.65, 0.89, 1.0};
std::vector<double> b{1.4, 1.2, 2.0, 3.0, 3.21, 3.4};

for (auto&& a_val : a) {
for (auto&& b_val : b) {
stan::test::expect_ad(f(1, 1), a_val, b_val);
stan::test::expect_ad(f(1, -1), a_val, b_val);
stan::test::expect_ad(f(-1, 1), a_val, b_val);
stan::test::expect_ad(f(-1, -1), a_val, b_val);
}
}
}
37 changes: 37 additions & 0 deletions test/unit/math/prim/fun/log_sum_exp_signed_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include <stan/math/prim.hpp>
#include <gtest/gtest.h>
#include <cmath>
#include <tuple>
#include <vector>

TEST(MathFunctions, log_sum_exp_signed) {
using stan::math::exp;
using stan::math::log_diff_exp;
using stan::math::log_sum_exp;
using stan::math::log_sum_exp_signed;
using stan::math::sign;

double a = 2.5;
double b = 76.2;
double exp_a = exp(a);
double exp_b = exp(b);

double val;
int val_sign;

std::forward_as_tuple(val, val_sign) = log_sum_exp_signed(a, 1, b, 1);

EXPECT_FLOAT_EQ(exp_a + exp_b, val_sign * exp(val));

std::forward_as_tuple(val, val_sign) = log_sum_exp_signed(a, 1, b, -1);

EXPECT_FLOAT_EQ(exp_a - exp_b, val_sign * exp(val));

std::forward_as_tuple(val, val_sign) = log_sum_exp_signed(a, -1, b, 1);

EXPECT_FLOAT_EQ(-exp_a + exp_b, val_sign * exp(val));

std::forward_as_tuple(val, val_sign) = log_sum_exp_signed(a, -1, b, -1);

EXPECT_FLOAT_EQ(-exp_a - exp_b, val_sign * exp(val));
}