Introduce bfloat16 type #9067

Merged: 26 commits from introduce_bfloat16_type into master on Sep 19, 2022
Commits (26):
7eacafa  introduce_bfloat16_type (clackhan, Sep 7, 2022)
8e111fa  storage (clackhan, Sep 7, 2022)
70faa94  fix compile error (clackhan, Sep 8, 2022)
0ea1fc5  support bfloat16 ep operator (clackhan, Sep 8, 2022)
c55d4f8  support create cpu bfloat tensor (clackhan, Sep 8, 2022)
b93f726  refine code (clackhan, Sep 8, 2022)
91e1ac2  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (clackhan, Sep 8, 2022)
f3862c6  Merge branch 'master' into introduce_bfloat16_type (clackhan, Sep 8, 2022)
48a56a6  Merge branch 'introduce_bfloat16_type' of https://github.com/Oneflow-… (clackhan, Sep 8, 2022)
c2fee22  minor fix (clackhan, Sep 8, 2022)
47d4c8b  fix static check error (clackhan, Sep 8, 2022)
4e3349b  Merge branch 'master' into introduce_bfloat16_type (clackhan, Sep 9, 2022)
50704d7  Merge branch 'master' into introduce_bfloat16_type (clackhan, Sep 13, 2022)
17ef5ce  reslove comment (clackhan, Sep 13, 2022)
dea14c4  add more test case (clackhan, Sep 13, 2022)
bc1942f  fix bfloat16 numeric_limits (clackhan, Sep 13, 2022)
d4d30ec  Merge branch 'master' into introduce_bfloat16_type (clackhan, Sep 14, 2022)
5eccff8  Merge branch 'master' into introduce_bfloat16_type (clackhan, Sep 14, 2022)
96481d6  Merge branch 'master' into introduce_bfloat16_type (clackhan, Sep 15, 2022)
b2d8534  fix error (clackhan, Sep 15, 2022)
a807b45  Merge branch 'introduce_bfloat16_type' of https://github.com/Oneflow-… (clackhan, Sep 15, 2022)
319273c  Merge branch 'master' into introduce_bfloat16_type (clackhan, Sep 15, 2022)
3c5af4b  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (clackhan, Sep 15, 2022)
48c8d2b  Merge branch 'introduce_bfloat16_type' of https://github.com/Oneflow-… (clackhan, Sep 15, 2022)
87e66fc  Merge branch 'master' into introduce_bfloat16_type (clackhan, Sep 19, 2022)
2af9f97  Merge branch 'master' into introduce_bfloat16_type (mergify[bot], Sep 19, 2022)
Files changed:
oneflow/api/python/utils/tensor_utils.cpp (5 additions, 2 deletions)

@@ -157,8 +157,11 @@ Maybe<Tensor> MakeLocalTensorFromData(PyObject* data, const Optional<Symbol<DTyp
                                        const bool requires_grad, const bool pin_memory) {
   bool is_bfloat16_dtype = dtype ? JUST(dtype)->data_type() == DataType::kBFloat16 : false;
   bool is_cuda_device = device ? JUST(device)->enum_type() == DeviceType::kCUDA : false;
-  if (is_bfloat16_dtype && !is_cuda_device) {
-    return Error::RuntimeError() << "Can build a bfloat16 tensor on cpu.";
+  if (is_bfloat16_dtype && is_cuda_device) {
+#if CUDA_VERSION < 11000
+    return Error::RuntimeError()
+        << "Cannot create a bfloat16 tensor on gpu under cuda version: 11000";
+#endif  // CUDA_VERSION < 11000
   }
   PyObject* array = NULL;
   PyArray_Descr* np_dtype =
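
The behavior change is easy to miss in the double negative: the old code rejected bfloat16 tensors on CPU outright, while the new code accepts them on CPU and instead rejects CUDA bfloat16 tensors when the toolkit predates CUDA 11.0 (the first release to ship __nv_bfloat16). A minimal sketch of the resulting gate, using a hypothetical helper name that is not part of the PR:

#include <cassert>

// Hypothetical helper (not part of the PR) mirroring the gate above:
// CPU bfloat16 tensors are always allowed; CUDA bfloat16 tensors require
// CUDA 11.0+ (CUDA_VERSION 11000).
bool AllowBFloat16Tensor(bool is_cuda_device, int cuda_version) {
  if (!is_cuda_device) { return true; }  // CPU path, newly allowed by this PR
  return cuda_version >= 11000;          // GPU path, gated on the toolkit version
}

int main() {
  assert(AllowBFloat16Tensor(false, 10020));  // CPU + CUDA 10.2: ok
  assert(!AllowBFloat16Tensor(true, 10020));  // GPU + CUDA 10.2: rejected
  assert(AllowBFloat16Tensor(true, 11020));   // GPU + CUDA 11.2: ok
  return 0;
}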
oneflow/core/common/bfloat16.h (new file, 314 additions)

@@ -0,0 +1,314 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BFLOAT16_H_
#define ONEFLOW_CORE_COMMON_BFLOAT16_H_

#include <stdint.h>
#include <limits>
#include <cmath>
#include <cstring>

namespace oneflow {

#if defined(__CUDACC__)
#define OF_DEVICE_FUNCTION __device__ __host__ __forceinline__
#else
#define OF_DEVICE_FUNCTION inline
#endif

struct alignas(2) bfloat16 {
  uint16_t x;

  bfloat16() = default;
  bfloat16(const bfloat16& o) = default;
  bfloat16& operator=(const bfloat16& o) = default;
  bfloat16(bfloat16&& o) = default;
  bfloat16& operator=(bfloat16&& o) = default;
  ~bfloat16() = default;

  struct from_bits_t {};
  static constexpr inline from_bits_t from_bits() { return from_bits_t(); }

  constexpr inline bfloat16(unsigned short bits, from_bits_t) : x(bits) {}

  // reference: pytorch/c10/util/BFloat16.h
  // https://github.com/pytorch/pytorch/blob/release/1.12/c10/util/BFloat16.h
  bfloat16(float value) {
    if (std::isnan(value)) {
      x = 0x7FC0;  // canonicalize every NaN input to the quiet-NaN bit pattern
    } else {
      // Round to nearest, ties to even: add 0x7FFF plus the lowest bit that
      // will be kept, then truncate to the high 16 bits of the float.
      union {
        uint32_t U32;
        float F32;
      };

      F32 = value;
      uint32_t rounding_bias = ((U32 >> 16) & 1) + 0x7FFFU;
      x = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
    }
  }

  // The stored bits are the high half of a float, so widening back to float
  // is just a 16-bit shift into a zeroed low half.
  inline operator float() const {
    float res = 0;
    uint32_t tmp = x;
    tmp <<= 16;
    std::memcpy(&res, &tmp, sizeof(tmp));
    return res;
  }

  // Note: bitwise equality, which differs from IEEE semantics for NaN
  // (NaN == NaN is true here) and signed zero (+0 != -0 here).
  inline bool operator==(const bfloat16& other) const { return x == other.x; }

  // Mask off the sign bit so both +0.0 and -0.0 convert to false.
  inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  inline explicit operator int8_t() const { return static_cast<int8_t>(static_cast<float>(*this)); }

  inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<float>(*this));
  }

  inline explicit operator int16_t() const {
    return static_cast<int16_t>(static_cast<float>(*this));
  }

  inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(static_cast<float>(*this));
  }

  inline explicit operator int32_t() const {
    return static_cast<int32_t>(static_cast<float>(*this));
  }

  inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(static_cast<float>(*this));
  }

  inline explicit operator int64_t() const {
    return static_cast<int64_t>(static_cast<float>(*this));
  }

  inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(static_cast<float>(*this));
  }

  inline explicit operator double() const { return static_cast<double>(static_cast<float>(*this)); }
};

// Arithmetic

inline bfloat16 operator+(const bfloat16& a, const bfloat16& b) {
  return static_cast<float>(a) + static_cast<float>(b);
}

inline bfloat16 operator-(const bfloat16& a, const bfloat16& b) {
  return static_cast<float>(a) - static_cast<float>(b);
}

inline bfloat16 operator*(const bfloat16& a, const bfloat16& b) {
  return static_cast<float>(a) * static_cast<float>(b);
}

inline bfloat16 operator/(const bfloat16& a, const bfloat16& b) {
  return static_cast<float>(a) / static_cast<float>(b);
}

// Unary minus only flips the IEEE sign bit; exponent and mantissa are untouched.
inline bfloat16 operator-(const bfloat16& a) {
  bfloat16 output;
  output.x = a.x ^ 0x8000U;
  return output;
}

inline bfloat16& operator+=(bfloat16& a, const bfloat16& b) {
  a = a + b;
  return a;
}

inline bfloat16& operator-=(bfloat16& a, const bfloat16& b) {
  a = a - b;
  return a;
}

inline bfloat16& operator*=(bfloat16& a, const bfloat16& b) {
  a = a * b;
  return a;
}

inline bfloat16& operator/=(bfloat16& a, const bfloat16& b) {
  a = a / b;
  return a;
}

// Note: these bitwise operators mutate their left operand and return it,
// i.e. they behave like |=, ^=, and &= on the raw bits.
inline bfloat16& operator|(bfloat16& a, const bfloat16& b) {
  a.x = a.x | b.x;
  return a;
}

inline bfloat16& operator^(bfloat16& a, const bfloat16& b) {
  a.x = a.x ^ b.x;
  return a;
}

inline bfloat16& operator&(bfloat16& a, const bfloat16& b) {
  a.x = a.x & b.x;
  return a;
}

// Arithmetic with floats

inline float operator+(bfloat16 a, float b) { return static_cast<float>(a) + b; }
inline float operator-(bfloat16 a, float b) { return static_cast<float>(a) - b; }
inline float operator*(bfloat16 a, float b) { return static_cast<float>(a) * b; }
inline float operator/(bfloat16 a, float b) { return static_cast<float>(a) / b; }

inline float operator+(float a, bfloat16 b) { return a + static_cast<float>(b); }
inline float operator-(float a, bfloat16 b) { return a - static_cast<float>(b); }
inline float operator*(float a, bfloat16 b) { return a * static_cast<float>(b); }
inline float operator/(float a, bfloat16 b) { return a / static_cast<float>(b); }

inline float& operator+=(float& a, const bfloat16& b) { return a += static_cast<float>(b); }
inline float& operator-=(float& a, const bfloat16& b) { return a -= static_cast<float>(b); }
inline float& operator*=(float& a, const bfloat16& b) { return a *= static_cast<float>(b); }
inline float& operator/=(float& a, const bfloat16& b) { return a /= static_cast<float>(b); }

// Arithmetic with doubles

inline double operator+(bfloat16 a, double b) { return static_cast<double>(a) + b; }
inline double operator-(bfloat16 a, double b) { return static_cast<double>(a) - b; }
inline double operator*(bfloat16 a, double b) { return static_cast<double>(a) * b; }
inline double operator/(bfloat16 a, double b) { return static_cast<double>(a) / b; }

inline double operator+(double a, bfloat16 b) { return a + static_cast<double>(b); }
inline double operator-(double a, bfloat16 b) { return a - static_cast<double>(b); }
inline double operator*(double a, bfloat16 b) { return a * static_cast<double>(b); }
inline double operator/(double a, bfloat16 b) { return a / static_cast<double>(b); }

// Arithmetic with int32_t

inline bfloat16 operator+(bfloat16 a, int32_t b) { return a + static_cast<bfloat16>(b); }
inline bfloat16 operator-(bfloat16 a, int32_t b) { return a - static_cast<bfloat16>(b); }
inline bfloat16 operator*(bfloat16 a, int32_t b) { return a * static_cast<bfloat16>(b); }
inline bfloat16 operator/(bfloat16 a, int32_t b) { return a / static_cast<bfloat16>(b); }

inline bfloat16 operator+(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) + b; }
inline bfloat16 operator-(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) - b; }
inline bfloat16 operator*(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) * b; }
inline bfloat16 operator/(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) / b; }

// Arithmetic with int64_t

inline bfloat16 operator+(bfloat16 a, int64_t b) { return a + static_cast<bfloat16>(b); }
inline bfloat16 operator-(bfloat16 a, int64_t b) { return a - static_cast<bfloat16>(b); }
inline bfloat16 operator*(bfloat16 a, int64_t b) { return a * static_cast<bfloat16>(b); }
inline bfloat16 operator/(bfloat16 a, int64_t b) { return a / static_cast<bfloat16>(b); }

inline bfloat16 operator+(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) + b; }
inline bfloat16 operator-(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) - b; }
inline bfloat16 operator*(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) * b; }
inline bfloat16 operator/(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) / b; }

// Comparison operators
//
// Note: these take non-const references, so they apply only to mutable
// lvalues; rvalue operands fall back to the implicit conversion to float.

inline bool operator>(bfloat16& lhs, bfloat16& rhs) {
  return static_cast<float>(lhs) > static_cast<float>(rhs);
}

inline bool operator>=(bfloat16& lhs, bfloat16& rhs) {
  return static_cast<float>(lhs) >= static_cast<float>(rhs);
}

inline bool operator<(bfloat16& lhs, bfloat16& rhs) {
  return static_cast<float>(lhs) < static_cast<float>(rhs);
}

inline bool operator<=(bfloat16& lhs, bfloat16& rhs) {
  return static_cast<float>(lhs) <= static_cast<float>(rhs);
}

// Unlike the bitwise member operator==, these compare as floats, so they
// follow IEEE semantics for NaN and signed zero.
inline bool operator==(bfloat16& lhs, bfloat16& rhs) {
  return static_cast<float>(lhs) == static_cast<float>(rhs);
}

inline bool operator!=(bfloat16& lhs, bfloat16& rhs) {
  return static_cast<float>(lhs) != static_cast<float>(rhs);
}

} // namespace oneflow

namespace std {

// NaN: exponent bits all ones and a nonzero mantissa (sign bit ignored).
inline bool isnan(const oneflow::bfloat16& value) { return (value.x & 0x7FFFU) > 0x7F80U; }

// Infinity: exponent bits all ones and a zero mantissa; mask the sign bit so
// both +inf (0x7F80) and -inf (0xFF80) are recognized.
inline bool isinf(const oneflow::bfloat16& value) { return (value.x & 0x7FFFU) == 0x7F80U; }

inline bool isfinite(const oneflow::bfloat16& value) { return !isinf(value) && !isnan(value); }

template<>
class numeric_limits<oneflow::bfloat16> {
 public:
  static constexpr bool is_signed = true;
  static constexpr bool is_specialized = true;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = true;
  static constexpr bool has_quiet_NaN = true;
  static constexpr bool has_signaling_NaN = true;
  static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
  static constexpr auto has_denorm_loss = numeric_limits<float>::has_denorm_loss;
  static constexpr auto round_style = numeric_limits<float>::round_style;
  static constexpr bool is_iec559 = false;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  static constexpr int digits = 8;
  static constexpr int digits10 = 2;
  static constexpr int max_digits10 = 4;
  static constexpr int radix = 2;
  static constexpr int min_exponent = -125;
  static constexpr int min_exponent10 = -37;
  static constexpr int max_exponent = 128;
  static constexpr int max_exponent10 = 38;
  static constexpr auto traps = numeric_limits<float>::traps;
  static constexpr auto tinyness_before = numeric_limits<float>::tinyness_before;
  static constexpr oneflow::bfloat16 min() {
    return oneflow::bfloat16(0x0080U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 lowest() {
    return oneflow::bfloat16(0xFF7FU, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 max() {
    return oneflow::bfloat16(0x7F7FU, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 epsilon() {
    return oneflow::bfloat16(0x3C00U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 round_error() {
    return oneflow::bfloat16(0x3F00U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 infinity() {
    return oneflow::bfloat16(0x7F80U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 quiet_NaN() {
    return oneflow::bfloat16(0x7FC0U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 signaling_NaN() {
    // Exponent all ones, top mantissa bit clear, nonzero payload
    // (0x7F80U would be the bit pattern of +infinity, not a NaN).
    return oneflow::bfloat16(0x7FA0U, oneflow::bfloat16::from_bits());
  }
  static constexpr oneflow::bfloat16 denorm_min() {
    return oneflow::bfloat16(0x0001U, oneflow::bfloat16::from_bits());
  }
};

} // namespace std

#endif // ONEFLOW_CORE_COMMON_BFLOAT16_H_
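
To make the storage format concrete: the struct keeps the high 16 bits of an IEEE-754 float (1 sign bit, 8 exponent bits, 7 mantissa bits), and the float constructor rounds to nearest with ties to even. A standalone sketch, assuming the header above is available at its OneFlow include path:

#include <cstdio>
#include <limits>
#include "oneflow/core/common/bfloat16.h"

int main() {
  using oneflow::bfloat16;

  bfloat16 one(1.0f);  // low 16 float bits are zero, so this is exact: 0x3F80
  std::printf("1.0f     -> bits 0x%04X -> %f\n", static_cast<unsigned>(one.x),
              static_cast<float>(one));

  bfloat16 tie(1.00390625f);  // 1 + 2^-8 is a tie; rounds to the even neighbor 1.0
  std::printf("1 + 2^-8 -> bits 0x%04X -> %f\n", static_cast<unsigned>(tie.x),
              static_cast<float>(tie));

  // epsilon() is 2^-7 = 0.0078125, consistent with digits == 8.
  std::printf("epsilon  -> %f\n",
              static_cast<float>(std::numeric_limits<bfloat16>::epsilon()));

  // Unary minus flips only the sign bit: 0x3F80 -> 0xBF80.
  std::printf("-1.0     -> bits 0x%04X\n", static_cast<unsigned>((-one).x));
  return 0;
}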
oneflow/core/common/bfloat16_math.h (new file, 66 additions)

@@ -0,0 +1,66 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_
#define ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_

#include "oneflow/core/common/bfloat16.h"

namespace std {

// reference: pytorch/c10/util/BFloat16-math.h
// https://github.com/pytorch/pytorch/blob/release/1.12/c10/util/BFloat16-math.h
inline oneflow::bfloat16 acos(oneflow::bfloat16 a) { return std::acos(static_cast<float>(a)); }
inline oneflow::bfloat16 asin(oneflow::bfloat16 a) { return std::asin(static_cast<float>(a)); }
inline oneflow::bfloat16 atan(oneflow::bfloat16 a) { return std::atan(static_cast<float>(a)); }
inline oneflow::bfloat16 erf(oneflow::bfloat16 a) { return std::erf(static_cast<float>(a)); }
inline oneflow::bfloat16 erfc(oneflow::bfloat16 a) { return std::erfc(static_cast<float>(a)); }
inline oneflow::bfloat16 exp(oneflow::bfloat16 a) { return std::exp(static_cast<float>(a)); }
inline oneflow::bfloat16 expm1(oneflow::bfloat16 a) { return std::expm1(static_cast<float>(a)); }
inline oneflow::bfloat16 log(oneflow::bfloat16 a) { return std::log(static_cast<float>(a)); }
inline oneflow::bfloat16 log10(oneflow::bfloat16 a) { return std::log10(static_cast<float>(a)); }
inline oneflow::bfloat16 log1p(oneflow::bfloat16 a) { return std::log1p(static_cast<float>(a)); }
inline oneflow::bfloat16 log2(oneflow::bfloat16 a) { return std::log2(static_cast<float>(a)); }
inline oneflow::bfloat16 ceil(oneflow::bfloat16 a) { return std::ceil(static_cast<float>(a)); }
inline oneflow::bfloat16 cos(oneflow::bfloat16 a) { return std::cos(static_cast<float>(a)); }
inline oneflow::bfloat16 floor(oneflow::bfloat16 a) { return std::floor(static_cast<float>(a)); }
inline oneflow::bfloat16 nearbyint(oneflow::bfloat16 a) {
  return std::nearbyint(static_cast<float>(a));
}
inline oneflow::bfloat16 sin(oneflow::bfloat16 a) { return std::sin(static_cast<float>(a)); }
inline oneflow::bfloat16 tan(oneflow::bfloat16 a) { return std::tan(static_cast<float>(a)); }
inline oneflow::bfloat16 sinh(oneflow::bfloat16 a) { return std::sinh(static_cast<float>(a)); }
inline oneflow::bfloat16 cosh(oneflow::bfloat16 a) { return std::cosh(static_cast<float>(a)); }
inline oneflow::bfloat16 tanh(oneflow::bfloat16 a) { return std::tanh(static_cast<float>(a)); }
inline oneflow::bfloat16 trunc(oneflow::bfloat16 a) { return std::trunc(static_cast<float>(a)); }
inline oneflow::bfloat16 lgamma(oneflow::bfloat16 a) { return std::lgamma(static_cast<float>(a)); }
inline oneflow::bfloat16 sqrt(oneflow::bfloat16 a) { return std::sqrt(static_cast<float>(a)); }
inline oneflow::bfloat16 rsqrt(oneflow::bfloat16 a) {
  return 1.0 / std::sqrt(static_cast<float>(a));
}
inline oneflow::bfloat16 abs(oneflow::bfloat16 a) { return std::abs(static_cast<float>(a)); }
inline oneflow::bfloat16 pow(oneflow::bfloat16 a, double b) {
  return std::pow(static_cast<float>(a), b);
}
inline oneflow::bfloat16 pow(oneflow::bfloat16 a, oneflow::bfloat16 b) {
  return std::pow(static_cast<float>(a), static_cast<float>(b));
}
inline oneflow::bfloat16 fmod(oneflow::bfloat16 a, oneflow::bfloat16 b) {
  return std::fmod(static_cast<float>(a), static_cast<float>(b));
}

} // namespace std

#endif // ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_
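
A short usage sketch of these overloads, again assuming the OneFlow include path: every function routes through float, so results are computed at float precision and only the final value is rounded back to bfloat16.

#include <cstdio>
#include "oneflow/core/common/bfloat16_math.h"

int main() {
  oneflow::bfloat16 x(2.0f);

  // Overload resolution picks the bfloat16 versions declared above.
  oneflow::bfloat16 r = std::sqrt(x);  // sqrt(2) ~ 1.414..., rounds to 1.4140625
  oneflow::bfloat16 p = std::pow(x, oneflow::bfloat16(10.0f));  // 2^10 = 1024, exact

  std::printf("sqrt(2) = %f, pow(2, 10) = %f\n",
              static_cast<float>(r), static_cast<float>(p));
  return 0;
}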