Skip to content

Commit a202d13

Browse files
committed
logaddexp now handles both NaNs and infinities correctly per array API
1 parent ff1081a commit a202d13

File tree

1 file changed

+7
-8
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+7
-8
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
6161

6262
resT operator()(const argT1 &in1, const argT2 &in2)
6363
{
64-
if (std::isnan(in1) || std::isnan(in2)) {
65-
return std::numeric_limits<resT>::quiet_NaN();
66-
}
6764
resT max = std::max<resT>(in1, in2);
65+
if (std::isnan(max) || std::isinf(max)) {
66+
return max;
67+
}
6868
resT min = std::min<resT>(in1, in2);
6969
return max + std::log1p(std::exp(min - max));
7070
}
@@ -78,11 +78,10 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
7878

7979
#pragma unroll
8080
for (int i = 0; i < vec_sz; ++i) {
81-
if (std::isnan(in1[i]) || std::isnan(in2[i])) {
82-
res[i] = std::numeric_limits<resT>::quiet_NaN();
83-
}
84-
else {
85-
resT max = std::max<resT>(in1[i], in2[i]);
81+
resT max = std::max<resT>(in1[i], in2[i]);
82+
if (std::isnan(max) || std::isinf(max)) {
83+
res[i] = max;
84+
} else {
8685
res[i] = max + std::log1p(std::exp(std::abs(diff[i])));
8786
}
8887
}

0 commit comments

Comments
 (0)