Move bf16 into eigen #19

Merged: 1 commit, merged on Apr 29, 2020
17 changes: 11 additions & 6 deletions tensorflow/compiler/xla/client/lib/constants.cc
@@ -48,7 +48,9 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
builder,
static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
case BF16:
return ConstantR0<bfloat16>(builder, bfloat16::epsilon());
return ConstantR0<Eigen::bfloat16>(
builder, static_cast<Eigen::bfloat16>(
Eigen::NumTraits<Eigen::bfloat16>::epsilon()));
case F32:
return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
case F64:
@@ -70,7 +72,8 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
return ConstantR0<Eigen::half>(builder,
Eigen::NumTraits<Eigen::half>::lowest());
case BF16:
return ConstantR0<bfloat16>(builder, bfloat16::lowest());
return ConstantR0<Eigen::bfloat16>(
builder, Eigen::NumTraits<Eigen::bfloat16>::lowest());
case F32:
return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
case F64:
@@ -86,7 +89,8 @@ XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
return ConstantR0<Eigen::half>(builder,
std::numeric_limits<Eigen::half>::min());
case BF16:
return ConstantR0<bfloat16>(builder, bfloat16::min_positive_normal());
return ConstantR0<Eigen::bfloat16>(
builder, std::numeric_limits<Eigen::bfloat16>::min());
case F32:
return ConstantR0<float>(builder, std::numeric_limits<float>::min());
case F64:
@@ -108,7 +112,8 @@ XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
return ConstantR0<Eigen::half>(builder,
Eigen::NumTraits<Eigen::half>::highest());
case BF16:
return ConstantR0<bfloat16>(builder, bfloat16::highest());
return ConstantR0<Eigen::bfloat16>(
builder, Eigen::NumTraits<Eigen::bfloat16>::highest());
case F32:
return ConstantR0<float>(builder, std::numeric_limits<float>::max());
case F64:
@@ -125,8 +130,8 @@ XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
return ConstantR0<Eigen::half>(
builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
case BF16:
return ConstantR0<bfloat16>(
builder, bfloat16(std::numeric_limits<float>::quiet_NaN()));
return ConstantR0<Eigen::bfloat16>(
builder, Eigen::NumTraits<Eigen::bfloat16>::quiet_NaN());
case F32:
return ConstantR0<float>(builder,
std::numeric_limits<float>::quiet_NaN());
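Note for reviewers: a minimal standalone sketch of the Eigen-provided BF16 limits the new cases above rely on. The include path and the demo function are illustrative only, assuming the bundled Eigen ships Eigen::bfloat16 with NumTraits and std::numeric_limits specializations.

#include <iostream>
#include <limits>
#include "third_party/eigen3/Eigen/Core"

// Prints the bfloat16 limits that constants.cc now pulls from Eigen instead
// of the old tensorflow::bfloat16 helpers.
void PrintBf16Limits() {
  using BF16 = Eigen::bfloat16;
  std::cout << "epsilon  = " << static_cast<float>(Eigen::NumTraits<BF16>::epsilon()) << "\n"
            << "lowest   = " << static_cast<float>(Eigen::NumTraits<BF16>::lowest()) << "\n"
            << "highest  = " << static_cast<float>(Eigen::NumTraits<BF16>::highest()) << "\n"
            << "min norm = " << static_cast<float>(std::numeric_limits<BF16>::min()) << "\n"
            << "NaN      = " << static_cast<float>(Eigen::NumTraits<BF16>::quiet_NaN()) << "\n";
}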
26 changes: 13 additions & 13 deletions tensorflow/compiler/xla/service/hlo_parser.cc
@@ -2212,8 +2212,8 @@ bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, int64 index,
case F16:
return SetValueInLiteralHelper<Eigen::half>(loc, value, index, literal);
case BF16:
return SetValueInLiteralHelper<tensorflow::bfloat16>(loc, value, index,
literal);
return SetValueInLiteralHelper<Eigen::bfloat16>(loc, value, index,
literal);
case F32:
return SetValueInLiteralHelper<float>(loc, value, index, literal);
case F64:
@@ -2351,11 +2351,10 @@ bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) {
auto get_index_str = [&elems_seen_per_dim](int dim) -> std::string {
std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
elems_seen_per_dim.begin() + dim);
return StrCat("[",
StrJoin(elems_seen_until_dim, ",",
[](std::string* out, const int64 num_elems) {
StrAppend(out, num_elems - 1);
}),
return StrCat("[", StrJoin(elems_seen_until_dim, ",",
[](std::string* out, const int64 num_elems) {
StrAppend(out, num_elems - 1);
}),
"]");
};

@@ -2523,8 +2522,10 @@ struct MinMaxFiniteValue<Eigen::half> {
};

template <>
struct MinMaxFiniteValue<bfloat16> {
static double max() { return static_cast<double>(bfloat16::highest()); }
struct MinMaxFiniteValue<Eigen::bfloat16> {
static double max() {
return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
}
static double min() { return -max(); }
};

@@ -4285,10 +4286,9 @@ bool HloParserImpl::ParseSingleInstruction(HloModule* module) {
// The missing instruction hook we register creates the shaped instruction on
// the fly as a parameter and returns it.
int64 parameter_count = 0;
create_missing_instruction_ =
[this, &builder, &parameter_count](
const std::string& name,
const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
create_missing_instruction_ = [this, &builder, &parameter_count](
const std::string& name,
const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
std::string new_name = name.empty() ? StrCat("_", parameter_count) : name;
HloInstruction* parameter = builder.AddInstruction(
HloInstruction::CreateParameter(parameter_count++, shape, new_name));
20 changes: 13 additions & 7 deletions tensorflow/core/framework/bfloat16_test.cc
@@ -35,14 +35,15 @@ TEST(Bfloat16Test, FlushDenormalsToZero) {
for (float denorm = -std::numeric_limits<float>::denorm_min();
denorm < std::numeric_limits<float>::denorm_min();
denorm = std::nextafterf(denorm, 1.0f)) {
bfloat16 bf_trunc = bfloat16::truncate_to_bfloat16(denorm);
bfloat16 bf_trunc =
bfloat16(Eigen::bfloat16_impl::truncate_to_bfloat16(denorm));
ASSERT_EQ(static_cast<float>(bf_trunc), 0.0f);
if (std::signbit(denorm)) {
ASSERT_EQ(bf_trunc.value, 0x8000) << denorm;
} else {
ASSERT_EQ(bf_trunc.value, 0x0000) << denorm;
}
bfloat16 bf_round = bfloat16::round_to_bfloat16(denorm);
bfloat16 bf_round = bfloat16(denorm);
ASSERT_EQ(static_cast<float>(bf_round), 0.0f);
if (std::signbit(denorm)) {
ASSERT_EQ(bf_round.value, 0x8000) << denorm;
@@ -88,7 +89,8 @@ class Bfloat16Test : public ::testing::Test,
public ::testing::WithParamInterface<Bfloat16TestParam> {};

TEST_P(Bfloat16Test, TruncateTest) {
bfloat16 truncated = bfloat16::truncate_to_bfloat16((GetParam().input));
bfloat16 truncated =
bfloat16(Eigen::bfloat16_impl::truncate_to_bfloat16((GetParam().input)));

if (std::isnan(GetParam().input)) {
EXPECT_TRUE(std::isnan(float(truncated)) || std::isinf(float(truncated)));
@@ -97,7 +99,7 @@ TEST_P(Bfloat16Test, TruncateTest) {

EXPECT_EQ(GetParam().expected_truncation, float(truncated));

bfloat16 rounded = bfloat16::round_to_bfloat16((GetParam().input));
bfloat16 rounded = bfloat16((GetParam().input));
if (std::isnan(GetParam().input)) {
EXPECT_TRUE(std::isnan(float(rounded)) || std::isinf(float(rounded)));
return;
@@ -172,9 +174,13 @@ TEST(Bfloat16Test, Conversion) {
}

TEST(Bfloat16Test, Epsilon) {
EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f)));
EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) +
bfloat16(1.0f)));
EXPECT_LT(1.0f,
static_cast<float>(Eigen::NumTraits<Eigen::bfloat16>::epsilon() +
bfloat16(1.0f)));
EXPECT_EQ(1.0f,
static_cast<float>((Eigen::NumTraits<Eigen::bfloat16>::epsilon() /
bfloat16(2.0f)) +
bfloat16(1.0f)));
}

TEST(Bfloat16Test, Negate) {
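Note for reviewers: the removed bfloat16::truncate_to_bfloat16 / bfloat16::round_to_bfloat16 members map onto Eigen as the updated tests do. A minimal sketch of that mapping, assuming the bundled Eigen revision; the wrapper names below are hypothetical, not part of this change.

#include "third_party/eigen3/Eigen/Core"

// Truncation: drop the low mantissa bits without rounding, as the old
// truncate_to_bfloat16 member did.
inline Eigen::bfloat16 TruncateToBf16(float f) {
  return Eigen::bfloat16(Eigen::bfloat16_impl::truncate_to_bfloat16(f));
}

// Round to nearest even: this is what the Eigen::bfloat16(float) constructor
// does, replacing the old round_to_bfloat16 member.
inline Eigen::bfloat16 RoundToBf16(float f) {
  return Eigen::bfloat16(f);
}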
64 changes: 0 additions & 64 deletions tensorflow/core/framework/numeric_types.h
@@ -43,47 +43,7 @@ typedef Eigen::QUInt16 quint16;

} // namespace tensorflow




static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) {
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
return *reinterpret_cast<tensorflow::bfloat16*>(
reinterpret_cast<uint16_t*>(&float_val));
#else
return *reinterpret_cast<tensorflow::bfloat16*>(
&(reinterpret_cast<uint16_t*>(&float_val)[1]));
#endif
}

namespace Eigen {
// TODO(xpan): We probably need to overwrite more methods to have correct eigen
// behavior. E.g. epsilon(), dummy_precision, etc. See NumTraits.h in eigen.
template <>
struct NumTraits<tensorflow::bfloat16>
: GenericNumTraits<tensorflow::bfloat16> {
enum {
IsInteger = 0,
IsSigned = 1,
RequireInitialization = 0
};
static EIGEN_STRONG_INLINE tensorflow::bfloat16 highest() {
return FloatToBFloat16(NumTraits<float>::highest());
}

static EIGEN_STRONG_INLINE tensorflow::bfloat16 lowest() {
return FloatToBFloat16(NumTraits<float>::lowest());
}

static EIGEN_STRONG_INLINE tensorflow::bfloat16 infinity() {
return FloatToBFloat16(NumTraits<float>::infinity());
}

static EIGEN_STRONG_INLINE tensorflow::bfloat16 quiet_NaN() {
return FloatToBFloat16(NumTraits<float>::quiet_NaN());
}
};

template <>
struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
enum {
@@ -104,30 +64,6 @@ struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
static inline tensorflow::tstring quiet_NaN();
};

using ::tensorflow::operator==;
using ::tensorflow::operator!=;

namespace numext {

template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 log(
const tensorflow::bfloat16& x) {
return static_cast<tensorflow::bfloat16>(::logf(static_cast<float>(x)));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 exp(
const tensorflow::bfloat16& x) {
return static_cast<tensorflow::bfloat16>(::expf(static_cast<float>(x)));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 abs(
const tensorflow::bfloat16& x) {
return static_cast<tensorflow::bfloat16>(::fabsf(static_cast<float>(x)));
}

} // namespace numext
} // namespace Eigen

#if defined(_MSC_VER) && !defined(__clang__)
48 changes: 3 additions & 45 deletions tensorflow/core/kernels/cast_op.h
@@ -16,13 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// Note that the GPU cast functor templates need to be instantiated unlike the
// CPU ones, and hence their specializations are different than that for CPUs.
@@ -72,7 +72,7 @@ limitations under the License.
SPECIALIZE_CAST(devname, Eigen::half, float) \
SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>) \
SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>) \
SPECIALIZE_CAST(devname, bfloat16, float) \
SPECIALIZE_CAST(devname, Eigen::bfloat16, float) \
template <typename OUT_TYPE, typename IN_OUT> \
struct CastFunctor<devname, OUT_TYPE, IN_OUT> { \
void operator()(const devname& d, \
@@ -131,7 +131,7 @@ constexpr int MantissaWidth<Eigen::half>() {
}

template <>
constexpr int MantissaWidth<bfloat16>() {
constexpr int MantissaWidth<Eigen::bfloat16>() {
// Remember, there's 1 hidden bit
return 7 + 1;
}
@@ -278,48 +278,6 @@ template <typename From, typename To>
struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
: functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};

// Specialized cast op impls for bfloat16.
template <>
struct scalar_cast_op<::tensorflow::bfloat16, float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef float result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
const ::tensorflow::bfloat16& a) const {
float ret;
uint16_t* p = reinterpret_cast<uint16_t*>(&ret);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
p[0] = a.value;
p[1] = 0;
#else
static_assert(::tensorflow::port::kLittleEndian,
"Not a little endian system!");
p[0] = 0;
p[1] = a.value;
#endif
return ret;
}
};

template <>
struct functor_traits<scalar_cast_op<::tensorflow::bfloat16, float>> {
enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
};

template <>
struct scalar_cast_op<float, ::tensorflow::bfloat16> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
typedef ::tensorflow::bfloat16 result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()(
const float a) const {
return ::tensorflow::bfloat16(a);
}
};

template <>
struct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16>> {
enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
};

} // namespace internal
} // namespace Eigen

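Note for reviewers: with the TensorFlow-specific scalar_cast_op specializations deleted, bfloat16 <-> float casts are expected to go through Eigen::bfloat16's own conversions. A minimal sketch under that assumption; the helper names are illustrative only.

#include "third_party/eigen3/Eigen/Core"

// float -> bfloat16 uses Eigen's converting constructor (round to nearest even),
// which replaces the hand-rolled scalar_cast_op<float, bfloat16> above.
inline Eigen::bfloat16 FloatToBf16(float f) { return Eigen::bfloat16(f); }

// bfloat16 -> float widens exactly; every bfloat16 value is representable as
// a float, so no endianness-specific bit shuffling is needed anymore.
inline float Bf16ToFloat(Eigen::bfloat16 b) { return static_cast<float>(b); }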
10 changes: 5 additions & 5 deletions tensorflow/core/kernels/sparse_matmul_op.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op.h"
@@ -37,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "include/libxsmm_intrinsics_x86.h"
#include "include/libxsmm_malloc.h"
@@ -165,7 +165,7 @@ bool IsZero(T v);

template <>
ALWAYS_INLINE bool IsZero(bfloat16 v) {
return v.IsZero();
return float(v) == 0.0f;
}

template <>
@@ -977,9 +977,9 @@ class SparseMatMulOp : public OpKernel {
const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);

OP_REQUIRES(ctx, k == k2,
errors::InvalidArgument(
"Matrix size incompatible: a: ", a.shape().DebugString(),
", b: ", b.shape().DebugString()));
errors::InvalidArgument("Matrix size incompatible: a: ",
a.shape().DebugString(), ", b: ",
b.shape().DebugString()));
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));

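Note for reviewers: the IsZero change above is needed because Eigen::bfloat16 has no IsZero() member. A minimal sketch of the resulting pattern; the generic template is paraphrased from the surrounding file, not copied verbatim.

#include "third_party/eigen3/Eigen/Core"

// Generic zero test.
template <typename T>
inline bool IsZero(T v) { return v == T(0); }

// bfloat16 specialization used by sparse_matmul: compare through float, which
// also treats negative zero as zero.
template <>
inline bool IsZero(Eigen::bfloat16 v) { return float(v) == 0.0f; }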
10 changes: 1 addition & 9 deletions tensorflow/core/lib/bfloat16/bfloat16.cc
@@ -17,12 +17,4 @@ limitations under the License.

#include "third_party/eigen3/Eigen/Core"

namespace tensorflow {

const uint16_t bfloat16::NAN_VALUE;
const uint16_t bfloat16::ZERO_VALUE;

B16_DEVICE_FUNC bfloat16::operator Eigen::half() const {
return static_cast<Eigen::half>(float(*this));
}
} // end namespace tensorflow
namespace tensorflow {} // end namespace tensorflow