Skip to content

Commit 2f39563

Browse files
authored
Merge pull request #23 from Intel-tensorflow/tenglu/bf16/softmax
Add BF16 Softmax/SoftmaxGrad and fix an accuracy issue by using a wider accumulation type
2 parents 3e10354 + fa901fa commit 2f39563

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

tensorflow/core/kernels/mkl_tmp_bf16_ops.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ namespace tensorflow {
5858
REGISTER_KERNEL_BUILDER( \
5959
Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp); \
6060
REGISTER_KERNEL_BUILDER( \
61-
Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp);
61+
Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp); \
62+
REGISTER_KERNEL_BUILDER( \
63+
Name("Softmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp);
6264

6365
TF_CALL_bfloat16(REGISTER_CPU);
6466
#undef REGISTER_CPU

tensorflow/core/kernels/reduction_ops.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ limitations under the License.
1919
// Functor definitions for Reduction ops, must be compilable by nvcc.
2020

2121
#include <iostream>
22-
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
2322
#include "tensorflow/core/framework/op_kernel.h"
2423
#include "tensorflow/core/framework/tensor_types.h"
24+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
2525

2626
namespace tensorflow {
2727
namespace functor {
@@ -58,6 +58,29 @@ struct ReduceEigenImpl {
5858
}
5959
};
6060

61+
// Specialization of ReduceEigenImpl for bfloat16 reductions: accumulate in a
// wider intermediate type (here float) instead of bfloat16 itself, to avoid
// the precision loss of summing many values in a 7-bit-mantissa type.
// TODO: all BF16 reducers should get such a specialization to fix accuracy.
//
// The macro expands to a partial specialization of ReduceEigenImpl keyed on
// Reducer<ScalarType>; the reduction is performed with
// Reducer<IntermediateType> on the input cast up to IntermediateType, and the
// result is cast back down to ScalarType on assignment to `out`.
#define CASTING_SPECIALIZATION(Reducer, ScalarType, IntermediateType)        \
  template <typename Device, typename OUT_T, typename IN_T,                  \
            typename ReductionAxes>                                          \
  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,                 \
                         Reducer<ScalarType>> {                              \
    void operator()(const Device& d, OUT_T out, IN_T in,                     \
                    const ReductionAxes& reduction_axes,                     \
                    const Reducer<ScalarType>& reducer) {                    \
      /* Output scalar type must match ScalarType, since we cast back */     \
      /* to ScalarType below. */                                             \
      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
                    "");                                                     \
      /* Reduce with a higher-precision reducer; `reducer` is unused. */     \
      Reducer<IntermediateType> intermediate_reducer;                        \
      auto in_as_intermediate = in.template cast<IntermediateType>();        \
      out.device(d) =                                                        \
          in_as_intermediate.reduce(reduction_axes, intermediate_reducer)    \
              .template cast<ScalarType>();                                  \
    }                                                                        \
  };

// Sum reductions over bfloat16 accumulate in float.
CASTING_SPECIALIZATION(Eigen::internal::SumReducer, bfloat16, float);
#undef CASTING_SPECIALIZATION
83+
6184
template <typename Device, typename OUT_T, typename IN_T,
6285
typename ReductionAxes, typename Scalar>
6386
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,

tensorflow/core/ops/nn_grad.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ Status SoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) {
3131
// Ret val defs
3232
{"grad_x: T"},
3333
// Attr defs
34+
#if defined(INTEL_MKL) && defined(ENABLE_INTEL_MKL_BFLOAT16)
35+
{{"T: {float, double, bfloat16}"}},
36+
#else
3437
{{"T: {float, double}"}},
38+
#endif
3539
// Nodes
3640
// Based on _SoftmaxGrad in nn_grad.py.
3741
{

0 commit comments

Comments
 (0)