@@ -19,9 +19,9 @@ limitations under the License.
19
19
// Functor definitions for Reduction ops, must be compilable by nvcc.
20
20
21
21
#include < iostream>
22
- #include " third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23
22
#include " tensorflow/core/framework/op_kernel.h"
24
23
#include " tensorflow/core/framework/tensor_types.h"
24
+ #include " third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25
25
26
26
namespace tensorflow {
27
27
namespace functor {
@@ -58,6 +58,29 @@ struct ReduceEigenImpl {
58
58
}
59
59
};
60
60
61
+ // Specialization for BF16 Reducer to fix accuracy.
62
+ // TODO: all BF16 Reducer should have specialization to fix accuracy.
63
+ #define CASTING_SPECIALIZATION (Reducer, ScalarType, IntermediateType ) \
64
+ template <typename Device, typename OUT_T, typename IN_T, \
65
+ typename ReductionAxes> \
66
+ struct ReduceEigenImpl <Device, OUT_T, IN_T, ReductionAxes, \
67
+ Reducer<ScalarType>> { \
68
+ void operator ()(const Device& d, OUT_T out, IN_T in, \
69
+ const ReductionAxes& reduction_axes, \
70
+ const Reducer<ScalarType>& reducer) { \
71
+ static_assert (std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
72
+ " " ); \
73
+ Reducer<IntermediateType> intermediate_reducer; \
74
+ auto in_as_intermediate = in.template cast <IntermediateType>(); \
75
+ out.device (d) = \
76
+ in_as_intermediate.reduce (reduction_axes, intermediate_reducer) \
77
+ .template cast <ScalarType>(); \
78
+ } \
79
+ };
80
+
81
+ CASTING_SPECIALIZATION (Eigen::internal::SumReducer, bfloat16, float );
82
+ #undef CASTING_SPECIALIZATION
83
+
61
84
template <typename Device, typename OUT_T, typename IN_T,
62
85
typename ReductionAxes, typename Scalar>
63
86
struct ReduceEigenImpl <Device, OUT_T, IN_T, ReductionAxes,
0 commit comments