@@ -69,11 +69,16 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
69
69
const sycl::vec<argT2, vec_sz> &in2)
70
70
{
71
71
sycl::vec<resT, vec_sz> res;
72
- auto diff = in1 - in2;
72
+ auto diff = in1 - in2; // take advantange of faster vec arithmetic
73
73
74
74
#pragma unroll
75
75
for (int i = 0 ; i < vec_sz; ++i) {
76
- res[i] = impl<resT>(in1[i], in2[i]);
76
+ if (std::isfinite (diff[i])) {
77
+ res[i] = in2[i] + impl_finite<resT>(diff[i]);
78
+ }
79
+ else {
80
+ res[i] = impl<resT>(in1[i], in2[i]);
81
+ }
77
82
}
78
83
79
84
return res;
@@ -82,19 +87,28 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
82
87
private:
83
88
template <typename T> T impl (T const &in1, T const &in2)
84
89
{
85
- T max = std::max<T> (in1, in2);
86
- if ( std::isnan (max)) {
87
- return std::numeric_limits<T>:: quiet_NaN () ;
90
+ if (in1 == in2) { // handle signed infinities
91
+ const T log2 = std::log ( T ( 2 ));
92
+ return in1 + log2 ;
88
93
}
89
94
else {
90
- if (std::isinf (max)) {
91
- // if both args are -inf, and hence max is -inf
92
- // the result is -inf as well
93
- return max;
95
+ const T tmp = in1 - in2;
96
+ if (tmp > 0 ) {
97
+ return in1 + std::log1p (std::exp (-tmp));
98
+ }
99
+ else if (tmp <= 0 ) {
100
+ return in2 + std::log1p (std::exp (tmp));
101
+ }
102
+ else {
103
+ return std::numeric_limits<T>::quiet_NaN ();
94
104
}
95
105
}
96
- T min = std::min<T>(in1, in2);
97
- return max + std::log1p (std::exp (min - max));
106
+ }
107
+
108
+ template <typename T> T impl_finite (T const &in)
109
+ {
110
+ return (in > 0 ) ? (in + std::log1p (std::exp (-in)))
111
+ : std::log1p (std::exp (in));
98
112
}
99
113
};
100
114
0 commit comments