@@ -69,11 +69,17 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
69
69
const sycl::vec<argT2, vec_sz> &in2)
70
70
{
71
71
sycl::vec<resT, vec_sz> res;
72
- auto diff = in1 - in2;
72
+ auto diff = in1 - in2; // take advantange of faster vec arithmetic
73
73
74
74
#pragma unroll
75
75
for (int i = 0 ; i < vec_sz; ++i) {
76
- res[i] = impl<resT>(in1[i], in2[i]);
76
+ if (std::isfinite (diff[i])) {
77
+ res[i] = std::max<resT>(in1[i], in2[i]) +
78
+ impl_finite<resT>(-std::abs (diff[i]));
79
+ }
80
+ else {
81
+ res[i] = impl<resT>(in1[i], in2[i]);
82
+ }
77
83
}
78
84
79
85
return res;
@@ -82,19 +88,28 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
82
88
private:
83
89
template <typename T> T impl (T const &in1, T const &in2)
84
90
{
85
- T max = std::max<T> (in1, in2);
86
- if ( std::isnan (max)) {
87
- return std::numeric_limits<T>:: quiet_NaN () ;
91
+ if (in1 == in2) { // handle signed infinities
92
+ const T log2 = std::log ( T ( 2 ));
93
+ return in1 + log2 ;
88
94
}
89
95
else {
90
- if (std::isinf (max)) {
91
- // if both args are -inf, and hence max is -inf
92
- // the result is -inf as well
93
- return max;
96
+ const T tmp = in1 - in2;
97
+ if (tmp > 0 ) {
98
+ return in1 + std::log1p (std::exp (-tmp));
99
+ }
100
+ else if (tmp <= 0 ) {
101
+ return in2 + std::log1p (std::exp (tmp));
102
+ }
103
+ else {
104
+ return std::numeric_limits<T>::quiet_NaN ();
94
105
}
95
106
}
96
- T min = std::min<T>(in1, in2);
97
- return max + std::log1p (std::exp (min - max));
107
+ }
108
+
109
+ template <typename T> T impl_finite (T const &in)
110
+ {
111
+ return (in > 0 ) ? (in + std::log1p (std::exp (-in)))
112
+ : std::log1p (std::exp (in));
98
113
}
99
114
};
100
115
0 commit comments