Skip to content

Commit ebd1faf

Browse files
Fixed log-add-exp per review feedback
1 parent 8343edc commit ebd1faf

File tree

1 file changed

+26
-11
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+26
-11
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 26 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -69,11 +69,17 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
6969
const sycl::vec<argT2, vec_sz> &in2)
7070
{
7171
sycl::vec<resT, vec_sz> res;
72-
auto diff = in1 - in2;
72+
auto diff = in1 - in2; // take advantage of faster vec arithmetic
7373

7474
#pragma unroll
7575
for (int i = 0; i < vec_sz; ++i) {
76-
res[i] = impl<resT>(in1[i], in2[i]);
76+
if (std::isfinite(diff[i])) {
77+
res[i] = std::max<resT>(in1[i], in2[i]) +
78+
impl_finite<resT>(-std::abs(diff[i]));
79+
}
80+
else {
81+
res[i] = impl<resT>(in1[i], in2[i]);
82+
}
7783
}
7884

7985
return res;
@@ -82,19 +88,28 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
8288
private:
8389
template <typename T> T impl(T const &in1, T const &in2)
8490
{
85-
T max = std::max<T>(in1, in2);
86-
if (std::isnan(max)) {
87-
return std::numeric_limits<T>::quiet_NaN();
91+
if (in1 == in2) { // handle signed infinities
92+
const T log2 = std::log(T(2));
93+
return in1 + log2;
8894
}
8995
else {
90-
if (std::isinf(max)) {
91-
// if both args are -inf, and hence max is -inf
92-
// the result is -inf as well
93-
return max;
96+
const T tmp = in1 - in2;
97+
if (tmp > 0) {
98+
return in1 + std::log1p(std::exp(-tmp));
99+
}
100+
else if (tmp <= 0) {
101+
return in2 + std::log1p(std::exp(tmp));
102+
}
103+
else {
104+
return std::numeric_limits<T>::quiet_NaN();
94105
}
95106
}
96-
T min = std::min<T>(in1, in2);
97-
return max + std::log1p(std::exp(min - max));
107+
}
108+
109+
template <typename T> T impl_finite(T const &in)
110+
{
111+
return (in > 0) ? (in + std::log1p(std::exp(-in)))
112+
: std::log1p(std::exp(in));
98113
}
99114
};
100115

0 commit comments

Comments
 (0)