28
28
#include < CL/sycl.hpp>
29
29
#include < cstddef>
30
30
#include < cstdint>
31
+ #include < limits>
31
32
#include < type_traits>
32
33
33
34
#include " utils/offset_utils.hpp"
@@ -55,13 +56,14 @@ using dpctl::tensor::type_utils::vec_cast;
55
56
56
57
template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
57
58
{
58
- using supports_sg_loadstore = typename std::negation<
59
- std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
60
- using supports_vec = typename std::negation<
61
- std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
59
+ using supports_sg_loadstore = std::true_type;
60
+ using supports_vec = std::true_type;
62
61
63
62
resT operator ()(const argT1 &in1, const argT2 &in2)
64
63
{
64
+ if (std::isnan (in1) || std::isnan (in2)) {
65
+ return std::numeric_limits<resT>::quiet_NaN ();
66
+ }
65
67
resT max = std::max<resT>(in1, in2);
66
68
resT min = std::min<resT>(in1, in2);
67
69
return max + std::log1p (std::exp (min - max));
@@ -76,8 +78,13 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
76
78
77
79
#pragma unroll
78
80
for (int i = 0 ; i < vec_sz; ++i) {
79
- resT max = std::max<resT>(in1[i], in2[i]);
80
- res[i] = max + std::log1p (std::exp (std::abs (diff[i])));
81
+ if (std::isnan (in1[i]) || std::isnan (in2[i])) {
82
+ res[i] = std::numeric_limits<resT>::quiet_NaN ();
83
+ }
84
+ else {
85
+ resT max = std::max<resT>(in1[i], in2[i]);
86
+ res[i] = max + std::log1p (std::exp (std::abs (diff[i])));
87
+ }
81
88
}
82
89
83
90
return res;
0 commit comments