Skip to content

[SYCL] Implement bf16 conversions on host device #5954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 24, 2022
Merged
25 changes: 19 additions & 6 deletions sycl/include/sycl/ext/oneapi/experimental/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
#include <CL/__spirv/spirv_ops.hpp>
#include <sycl/half_type.hpp>

#if !defined(__SYCL_DEVICE_ONLY__)
#include <cmath>
#endif

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext {
Expand All @@ -35,9 +39,17 @@ class bfloat16 {
return __spirv_ConvertFToBF16INTEL(a);
#endif
#else
(void)a;
throw exception{errc::feature_not_supported,
"Bfloat16 conversion is not supported on host device"};
// In case of float value is nan - propagate bfloat16's qnan
if (std::isnan(a))
return 0xffc1;
union {
uint32_t intStorage;
float floatValue;
};
floatValue = a;
// Do RNE and truncate
uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF;
return static_cast<uint16_t>((intStorage + roundingBias) >> 16);
#endif
}
static float to_float(const storage_t &a) {
Expand All @@ -51,9 +63,10 @@ class bfloat16 {
return __spirv_ConvertBF16ToFINTEL(a);
#endif
#else
(void)a;
throw exception{errc::feature_not_supported,
"Bfloat16 conversion is not supported on host device"};
// Shift temporary variable to silence the warning
uint32_t bits = a;
bits <<= 16;
return static_cast<float>(bits);
#endif
}

Expand Down
88 changes: 88 additions & 0 deletions sycl/test/extensions/bfloat16_host.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//==------------ bfloat16_host.cpp - SYCL vectors test ---------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %RUN_ON_HOST %t.out
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
#include <sycl/sycl.hpp>

#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <string>

// Helper to convert the expected bits to float value to compare with the result
typedef union {
float Value;
struct {
uint32_t Mantissa : 23;
uint32_t Exponent : 8;
uint32_t Sign : 1;
} RawData;
} floatConvHelper;

float bitsToFloatConv(std::string Bits) {
floatConvHelper Helper;
Helper.RawData.Sign = static_cast<uint32_t>(Bits[0] - '0');
uint32_t Exponent = 0;
for (size_t I = 1; I != 9; ++I)
Exponent = Exponent + static_cast<uint32_t>(Bits[I] - '0') * pow(2, 8 - I);
Helper.RawData.Exponent = Exponent;
uint32_t Mantissa = 0;
for (size_t I = 9; I != 32; ++I)
Mantissa = Mantissa + static_cast<uint32_t>(Bits[I] - '0') * pow(2, 31 - I);
Helper.RawData.Mantissa = Mantissa;
return Helper.Value;
}

bool check_bf16_from_float(float Val, uint16_t Expected) {
uint16_t Result = sycl::ext::oneapi::experimental::bfloat16::from_float(Val);
if (Result != Expected) {
std::cout << "from_float check for Val = " << Val << " failed!\n"
<< "Expected " << Expected << " Got " << Result << "\n";
return false;
}
return true;
}

bool check_bf16_to_float(uint16_t Val, float Expected) {
float Result = sycl::ext::oneapi::experimental::bfloat16::to_float(Val);
if (Result != Expected) {
std::cout << "to_float check for Val = " << Val << " failed!\n"
<< "Expected " << Expected << " Got " << Result << "\n";
return false;
}
return true;
}

int main() {
bool Success =
check_bf16_from_float(0.0f, std::stoi("0000000000000000", nullptr, 2));
Success &=
check_bf16_from_float(42.0f, std::stoi("100001000101000", nullptr, 2));
Success &= check_bf16_from_float(std::numeric_limits<float>::min(),
std::stoi("0000000010000000", nullptr, 2));
Success &= check_bf16_from_float(std::numeric_limits<float>::max(),
std::stoi("0111111110000000", nullptr, 2));
Success &= check_bf16_from_float(std::numeric_limits<float>::quiet_NaN(),
std::stoi("1111111111000001", nullptr, 2));

Success &= check_bf16_to_float(
0, bitsToFloatConv(std::string("00000000000000000000000000000000")));
Success &= check_bf16_to_float(
1, bitsToFloatConv(std::string("01000111100000000000000000000000")));
Success &= check_bf16_to_float(
42, bitsToFloatConv(std::string("01001010001010000000000000000000")));
Success &= check_bf16_to_float(
std::numeric_limits<uint16_t>::max(),
bitsToFloatConv(std::string("01001111011111111111111100000000")));
if (!Success)
return -1;
return 0;
}