Skip to content

Commit 2a383f1

Browse files
authored
[SYCL] Implement bf16 conversions on host device (#5954)
They are implemented in a way of RNE conversion. Signed-off-by: Dmitry Sidorov <dmitry.sidorov@intel.com>
1 parent 728e5b4 commit 2a383f1

File tree

2 files changed

+107
-6
lines changed

2 files changed

+107
-6
lines changed

sycl/include/sycl/ext/oneapi/experimental/bfloat16.hpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
#include <CL/__spirv/spirv_ops.hpp>
1212
#include <sycl/half_type.hpp>
1313

14+
#if !defined(__SYCL_DEVICE_ONLY__)
15+
#include <cmath>
16+
#endif
17+
1418
namespace sycl {
1519
__SYCL_INLINE_VER_NAMESPACE(_V1) {
1620
namespace ext {
@@ -35,9 +39,17 @@ class bfloat16 {
3539
return __spirv_ConvertFToBF16INTEL(a);
3640
#endif
3741
#else
38-
(void)a;
39-
throw exception{errc::feature_not_supported,
40-
"Bfloat16 conversion is not supported on host device"};
42+
// In case of float value is nan - propagate bfloat16's qnan
43+
if (std::isnan(a))
44+
return 0xffc1;
45+
union {
46+
uint32_t intStorage;
47+
float floatValue;
48+
};
49+
floatValue = a;
50+
// Do RNE and truncate
51+
uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF;
52+
return static_cast<uint16_t>((intStorage + roundingBias) >> 16);
4153
#endif
4254
}
4355
static float to_float(const storage_t &a) {
@@ -51,9 +63,10 @@ class bfloat16 {
5163
return __spirv_ConvertBF16ToFINTEL(a);
5264
#endif
5365
#else
54-
(void)a;
55-
throw exception{errc::feature_not_supported,
56-
"Bfloat16 conversion is not supported on host device"};
66+
// Shift temporary variable to silence the warning
67+
uint32_t bits = a;
68+
bits <<= 16;
69+
return static_cast<float>(bits);
5770
#endif
5871
}
5972

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
//==------------ bfloat16_host.cpp - SYCL vectors test ---------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// RUN: %clangxx -fsycl %s -o %t.out
10+
// RUN: %RUN_ON_HOST %t.out
11+
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
12+
#include <sycl/sycl.hpp>
13+
14+
#include <cmath>
15+
#include <cstdint>
16+
#include <iostream>
17+
#include <limits>
18+
#include <string>
19+
20+
// Helper to convert the expected bits to float value to compare with the result
21+
typedef union {
22+
float Value;
23+
struct {
24+
uint32_t Mantissa : 23;
25+
uint32_t Exponent : 8;
26+
uint32_t Sign : 1;
27+
} RawData;
28+
} floatConvHelper;
29+
30+
float bitsToFloatConv(std::string Bits) {
31+
floatConvHelper Helper;
32+
Helper.RawData.Sign = static_cast<uint32_t>(Bits[0] - '0');
33+
uint32_t Exponent = 0;
34+
for (size_t I = 1; I != 9; ++I)
35+
Exponent = Exponent + static_cast<uint32_t>(Bits[I] - '0') * pow(2, 8 - I);
36+
Helper.RawData.Exponent = Exponent;
37+
uint32_t Mantissa = 0;
38+
for (size_t I = 9; I != 32; ++I)
39+
Mantissa = Mantissa + static_cast<uint32_t>(Bits[I] - '0') * pow(2, 31 - I);
40+
Helper.RawData.Mantissa = Mantissa;
41+
return Helper.Value;
42+
}
43+
44+
bool check_bf16_from_float(float Val, uint16_t Expected) {
45+
uint16_t Result = sycl::ext::oneapi::experimental::bfloat16::from_float(Val);
46+
if (Result != Expected) {
47+
std::cout << "from_float check for Val = " << Val << " failed!\n"
48+
<< "Expected " << Expected << " Got " << Result << "\n";
49+
return false;
50+
}
51+
return true;
52+
}
53+
54+
bool check_bf16_to_float(uint16_t Val, float Expected) {
55+
float Result = sycl::ext::oneapi::experimental::bfloat16::to_float(Val);
56+
if (Result != Expected) {
57+
std::cout << "to_float check for Val = " << Val << " failed!\n"
58+
<< "Expected " << Expected << " Got " << Result << "\n";
59+
return false;
60+
}
61+
return true;
62+
}
63+
64+
int main() {
65+
bool Success =
66+
check_bf16_from_float(0.0f, std::stoi("0000000000000000", nullptr, 2));
67+
Success &=
68+
check_bf16_from_float(42.0f, std::stoi("100001000101000", nullptr, 2));
69+
Success &= check_bf16_from_float(std::numeric_limits<float>::min(),
70+
std::stoi("0000000010000000", nullptr, 2));
71+
Success &= check_bf16_from_float(std::numeric_limits<float>::max(),
72+
std::stoi("0111111110000000", nullptr, 2));
73+
Success &= check_bf16_from_float(std::numeric_limits<float>::quiet_NaN(),
74+
std::stoi("1111111111000001", nullptr, 2));
75+
76+
Success &= check_bf16_to_float(
77+
0, bitsToFloatConv(std::string("00000000000000000000000000000000")));
78+
Success &= check_bf16_to_float(
79+
1, bitsToFloatConv(std::string("01000111100000000000000000000000")));
80+
Success &= check_bf16_to_float(
81+
42, bitsToFloatConv(std::string("01001010001010000000000000000000")));
82+
Success &= check_bf16_to_float(
83+
std::numeric_limits<uint16_t>::max(),
84+
bitsToFloatConv(std::string("01001111011111111111111100000000")));
85+
if (!Success)
86+
return -1;
87+
return 0;
88+
}

0 commit comments

Comments
 (0)