Skip to content

Commit 739475d

Browse files
Fixed log-add-exp per review feedback
1 parent 8343edc commit 739475d

File tree

1 file changed

+25
-11
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+25
-11
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,16 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
6969
const sycl::vec<argT2, vec_sz> &in2)
7070
{
7171
sycl::vec<resT, vec_sz> res;
72-
auto diff = in1 - in2;
72+
auto diff = in1 - in2; // take advantange of faster vec arithmetic
7373

7474
#pragma unroll
7575
for (int i = 0; i < vec_sz; ++i) {
76-
res[i] = impl<resT>(in1[i], in2[i]);
76+
if (std::isfinite(diff[i])) {
77+
res[i] = in2[i] + impl_finite<resT>(diff[i]);
78+
}
79+
else {
80+
res[i] = impl<resT>(in[i], in[2]);
81+
}
7782
}
7883

7984
return res;
@@ -82,19 +87,28 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
8287
private:
8388
template <typename T> T impl(T const &in1, T const &in2)
8489
{
85-
T max = std::max<T>(in1, in2);
86-
if (std::isnan(max)) {
87-
return std::numeric_limits<T>::quiet_NaN();
90+
if (in1 == in2) { // handle signed infinities
91+
const T log2 = std::log(T(2));
92+
return in1 + log2;
8893
}
8994
else {
90-
if (std::isinf(max)) {
91-
// if both args are -inf, and hence max is -inf
92-
// the result is -inf as well
93-
return max;
95+
const T tmp = in1 - in2;
96+
if (tmp > 0) {
97+
return in1 + std::log1p(std::exp(-tmp));
98+
}
99+
else if (tmp <= 0) {
100+
return in2 + std::log1p(std::exp(tmp));
101+
}
102+
else {
103+
return std::numeric_limits<T>::quiet_NaN();
94104
}
95105
}
96-
T min = std::min<T>(in1, in2);
97-
return max + std::log1p(std::exp(min - max));
106+
}
107+
108+
template <typename T> T impl_finite(T const &in)
109+
{
110+
return (in > 0) ? (in + std::log1p(std::exp(-in)))
111+
: std::log1p(std::exp(in));
98112
}
99113
};
100114

0 commit comments

Comments
 (0)