Skip to content

Commit ff1081a

Browse files
committed
Fixed logaddexp for mixed nan and number operands
1 parent 80eae6e commit ff1081a

File tree

1 file changed

+13
-6
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+13
-6
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <CL/sycl.hpp>
2929
#include <cstddef>
3030
#include <cstdint>
31+
#include <limits>
3132
#include <type_traits>
3233

3334
#include "utils/offset_utils.hpp"
@@ -55,13 +56,14 @@ using dpctl::tensor::type_utils::vec_cast;
5556

5657
template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
5758
{
58-
using supports_sg_loadstore = typename std::negation<
59-
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
60-
using supports_vec = typename std::negation<
61-
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
59+
using supports_sg_loadstore = std::true_type;
60+
using supports_vec = std::true_type;
6261

6362
resT operator()(const argT1 &in1, const argT2 &in2)
6463
{
64+
if (std::isnan(in1) || std::isnan(in2)) {
65+
return std::numeric_limits<resT>::quiet_NaN();
66+
}
6567
resT max = std::max<resT>(in1, in2);
6668
resT min = std::min<resT>(in1, in2);
6769
return max + std::log1p(std::exp(min - max));
@@ -76,8 +78,13 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
7678

7779
#pragma unroll
7880
for (int i = 0; i < vec_sz; ++i) {
79-
resT max = std::max<resT>(in1[i], in2[i]);
80-
res[i] = max + std::log1p(std::exp(std::abs(diff[i])));
81+
if (std::isnan(in1[i]) || std::isnan(in2[i])) {
82+
res[i] = std::numeric_limits<resT>::quiet_NaN();
83+
}
84+
else {
85+
resT max = std::max<resT>(in1[i], in2[i]);
86+
res[i] = max + std::log1p(std::exp(std::abs(diff[i])));
87+
}
8188
}
8289

8390
return res;

0 commit comments

Comments
 (0)