Skip to content

Commit 12af248

Browse files
[SYCL] Reintroduce experimental bfloat16 math functions (#7567)
#6524 accidentally removed the experimental bfloat16 math functions while moving bfloat16 out of the experimental namespace. This commit reintroduces these in the bfloat16_math.hpp header file. Changes to sub_group.hpp are to resolve detail namespace ambiguities and are NFC. Signed-off-by: Larsen, Steffen <steffen.larsen@intel.com>
1 parent 334d0e9 commit 12af248

File tree

4 files changed

+281
-52
lines changed

4 files changed

+281
-52
lines changed

sycl/include/sycl/ext/oneapi/bfloat16.hpp

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,21 @@ __SYCL_INLINE_VER_NAMESPACE(_V1) {
2525
namespace ext {
2626
namespace oneapi {
2727

28+
class bfloat16;
29+
30+
// Forward declarations of the bit-level helpers. They are declared ahead of
// the bfloat16 class so that the class can name them in friend declarations.
namespace detail {
31+
// Integer type used for the raw storage of a bfloat16 value.
using Bfloat16StorageT = uint16_t;
32+
// Returns the raw storage bits of Value.
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value);
33+
// Builds a bfloat16 whose raw storage bits are Value.
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value);
34+
} // namespace detail
35+
2836
class bfloat16 {
29-
using storage_t = uint16_t;
30-
storage_t value;
37+
detail::Bfloat16StorageT value;
38+
39+
friend inline detail::Bfloat16StorageT
40+
detail::bfloat16ToBits(const bfloat16 &Value);
41+
friend inline bfloat16
42+
detail::bitsToBfloat16(const detail::Bfloat16StorageT Value);
3143

3244
public:
3345
bfloat16() = default;
@@ -36,7 +48,7 @@ class bfloat16 {
3648

3749
private:
3850
// Explicit conversion functions
39-
static storage_t from_float(const float &a) {
51+
static detail::Bfloat16StorageT from_float(const float &a) {
4052
#if defined(__SYCL_DEVICE_ONLY__)
4153
#if defined(__NVPTX__)
4254
#if (__CUDA_ARCH__ >= 800)
@@ -72,7 +84,7 @@ class bfloat16 {
7284
#endif
7385
}
7486

75-
static float to_float(const storage_t &a) {
87+
static float to_float(const detail::Bfloat16StorageT &a) {
7688
#if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
7789
return __devicelib_ConvertBF16ToFINTEL(a);
7890
#else
@@ -85,12 +97,6 @@ class bfloat16 {
8597
#endif
8698
}
8799

88-
static bfloat16 from_bits(const storage_t &a) {
89-
bfloat16 res;
90-
res.value = a;
91-
return res;
92-
}
93-
94100
public:
95101
// Implicit conversion from float to bfloat16
96102
bfloat16(const float &a) { value = from_float(a); }
@@ -122,7 +128,7 @@ class bfloat16 {
122128
#if defined(__SYCL_DEVICE_ONLY__)
123129
#if defined(__NVPTX__)
124130
#if (__CUDA_ARCH__ >= 800)
125-
return from_bits(__nvvm_neg_bf16(lhs.value));
131+
return detail::bitsToBfloat16(__nvvm_neg_bf16(lhs.value));
126132
#else
127133
return -to_float(lhs.value);
128134
#endif
@@ -203,6 +209,23 @@ class bfloat16 {
203209
// for floating-point types.
204210
};
205211

212+
namespace detail {
213+
214+
// Helper function for getting the internal representation of a bfloat16.
// Returns a copy of the private `value` member as-is; no numeric conversion
// is performed. Access to the member relies on the friend declaration in the
// bfloat16 class.
215+
inline Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value) {
216+
return Value.value;
217+
}
218+
219+
// Helper function for creating a bfloat16 from a value with the same type as
220+
// the internal representation. The bits are stored verbatim in the private
// `value` member (no float conversion), which is why this function has to be
// a friend of the bfloat16 class.
221+
inline bfloat16 bitsToBfloat16(const Bfloat16StorageT Value) {
222+
bfloat16 res;
223+
res.value = Value;
224+
return res;
225+
}
226+
227+
} // namespace detail
228+
206229
} // namespace oneapi
207230
} // namespace ext
208231

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
//==-------- bfloat16_math.hpp - SYCL bfloat16 math functions --------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
#include <sycl/detail/defines_elementary.hpp>
12+
#include <sycl/exception.hpp>
13+
#include <sycl/ext/oneapi/bfloat16.hpp>
14+
#include <sycl/marray.hpp>
15+
16+
#include <cstring>
17+
#include <tuple>
18+
#include <type_traits>
19+
20+
namespace sycl {
21+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
22+
namespace ext {
23+
namespace oneapi {
24+
namespace experimental {
25+
26+
namespace detail {
27+
// Copies sizeof(uint32_t) bytes, beginning at element `start` of `x`, into a
// uint32_t and returns it.
// NOTE(review): assuming sizeof(bfloat16) == 2, this packs elements `start`
// and `start + 1`; nothing here checks that `start + 1 < N`, so callers must
// guarantee it.
template <size_t N>
28+
uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
29+
uint32_t res;
30+
// memcpy moves the raw bytes without going through a pointer cast.
std::memcpy(&res, &x[start], sizeof(uint32_t));
31+
return res;
32+
}
33+
} // namespace detail
34+
35+
// Scalar bfloat16 fabs. The enable_if restricts this overload to T being
// exactly bfloat16.
// When compiling device code for NVPTX, the raw bits of `x` are handed to
// __clc_fabs and the returned bits are wrapped back into a bfloat16.
// (__clc_fabs is declared outside this file; presumably it computes the
// absolute value on the raw bits - confirm.)
// In every other configuration the function throws runtime_error with
// PI_ERROR_INVALID_DEVICE.
// NOTE(review): the #else branch is also taken for device compilation on
// non-NVPTX targets, although the message only mentions the host device.
template <typename T>
36+
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {
37+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
38+
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
39+
return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
40+
#else
41+
// Assigning to std::ignore marks the parameter as used in this branch.
std::ignore = x;
42+
throw runtime_error("bfloat16 is not currently supported on the host device.",
43+
PI_ERROR_INVALID_DEVICE);
44+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
45+
}
46+
47+
// marray overload of fabs for bfloat16 elements.
// On NVPTX device compilation the elements are processed two at a time: each
// pair is packed into a uint32_t, passed to __clc_fabs, and the result bits
// are copied back into the matching pair of `res`. If N is odd, the last
// element is handled on its own through the 16-bit raw-bits path.
// (The 32-bit __clc_fabs is declared outside this file; presumably it applies
// the operation to both packed halves - confirm.)
// In every other configuration the function throws runtime_error with
// PI_ERROR_INVALID_DEVICE.
template <size_t N>
48+
sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
49+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
50+
sycl::marray<bfloat16, N> res;
51+
52+
// One iteration per pair of elements: indices i * 2 and i * 2 + 1.
for (size_t i = 0; i < N / 2; i++) {
53+
auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2));
54+
std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
55+
}
56+
57+
// Odd N leaves one trailing element that the pairwise loop did not cover.
if (N % 2) {
58+
oneapi::detail::Bfloat16StorageT XBits =
59+
oneapi::detail::bfloat16ToBits(x[N - 1]);
60+
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
61+
}
62+
return res;
63+
#else
64+
// Assigning to std::ignore marks the parameter as used in this branch.
std::ignore = x;
65+
throw runtime_error("bfloat16 is not currently supported on the host device.",
66+
PI_ERROR_INVALID_DEVICE);
67+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
68+
}
69+
70+
// Scalar bfloat16 fmin. The enable_if restricts this overload to T being
// exactly bfloat16.
// When compiling device code for NVPTX, the raw bits of `x` and `y` are
// handed to __clc_fmin and the returned bits are wrapped back into a
// bfloat16. (__clc_fmin is declared outside this file; presumably it returns
// the smaller operand - confirm, including its NaN behaviour.)
// In every other configuration the function throws runtime_error with
// PI_ERROR_INVALID_DEVICE.
template <typename T>
71+
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) {
72+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
73+
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
74+
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
75+
return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
76+
#else
77+
// Assigning to std::ignore marks the parameters as used in this branch.
std::ignore = x;
78+
std::ignore = y;
79+
throw runtime_error("bfloat16 is not currently supported on the host device.",
80+
PI_ERROR_INVALID_DEVICE);
81+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
82+
}
83+
84+
// marray overload of fmin for bfloat16 elements; `x` and `y` have the same
// length N and the result is computed element by element.
// On NVPTX device compilation the elements are processed two at a time: the
// matching pairs of `x` and `y` are each packed into a uint32_t, passed to
// __clc_fmin, and the result bits are copied back into the matching pair of
// `res`. If N is odd, the last element is handled on its own through the
// 16-bit raw-bits path.
// (The 32-bit __clc_fmin is declared outside this file; presumably it applies
// the operation to both packed halves - confirm.)
// In every other configuration the function throws runtime_error with
// PI_ERROR_INVALID_DEVICE.
template <size_t N>
85+
sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
86+
sycl::marray<bfloat16, N> y) {
87+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
88+
sycl::marray<bfloat16, N> res;
89+
90+
// One iteration per pair of elements: indices i * 2 and i * 2 + 1.
for (size_t i = 0; i < N / 2; i++) {
91+
auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2),
92+
detail::to_uint32_t(y, i * 2));
93+
std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
94+
}
95+
96+
// Odd N leaves one trailing element that the pairwise loop did not cover.
if (N % 2) {
97+
oneapi::detail::Bfloat16StorageT XBits =
98+
oneapi::detail::bfloat16ToBits(x[N - 1]);
99+
oneapi::detail::Bfloat16StorageT YBits =
100+
oneapi::detail::bfloat16ToBits(y[N - 1]);
101+
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
102+
}
103+
104+
return res;
105+
#else
106+
// Assigning to std::ignore marks the parameters as used in this branch.
std::ignore = x;
107+
std::ignore = y;
108+
throw runtime_error("bfloat16 is not currently supported on the host device.",
109+
PI_ERROR_INVALID_DEVICE);
110+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
111+
}
112+
113+
// Scalar bfloat16 fmax. The enable_if restricts this overload to T being
// exactly bfloat16.
// When compiling device code for NVPTX, the raw bits of `x` and `y` are
// handed to __clc_fmax and the returned bits are wrapped back into a
// bfloat16. (__clc_fmax is declared outside this file; presumably it returns
// the larger operand - confirm, including its NaN behaviour.)
// In every other configuration the function throws runtime_error with
// PI_ERROR_INVALID_DEVICE.
template <typename T>
114+
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) {
115+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
116+
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
117+
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
118+
return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
119+
#else
120+
// Assigning to std::ignore marks the parameters as used in this branch.
std::ignore = x;
121+
std::ignore = y;
122+
throw runtime_error("bfloat16 is not currently supported on the host device.",
123+
PI_ERROR_INVALID_DEVICE);
124+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
125+
}
126+
127+
// marray overload of fmax for bfloat16 elements; `x` and `y` have the same
// length N and the result is computed element by element.
// On NVPTX device compilation the elements are processed two at a time: the
// matching pairs of `x` and `y` are each packed into a uint32_t, passed to
// __clc_fmax, and the result bits are copied back into the matching pair of
// `res`. If N is odd, the last element is handled on its own through the
// 16-bit raw-bits path.
// (The 32-bit __clc_fmax is declared outside this file; presumably it applies
// the operation to both packed halves - confirm.)
// In every other configuration the function throws runtime_error with
// PI_ERROR_INVALID_DEVICE.
template <size_t N>
128+
sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
129+
sycl::marray<bfloat16, N> y) {
130+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
131+
sycl::marray<bfloat16, N> res;
132+
133+
// One iteration per pair of elements: indices i * 2 and i * 2 + 1.
for (size_t i = 0; i < N / 2; i++) {
134+
auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2),
135+
detail::to_uint32_t(y, i * 2));
136+
std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
137+
}
138+
139+
// Odd N leaves one trailing element that the pairwise loop did not cover.
if (N % 2) {
140+
oneapi::detail::Bfloat16StorageT XBits =
141+
oneapi::detail::bfloat16ToBits(x[N - 1]);
142+
oneapi::detail::Bfloat16StorageT YBits =
143+
oneapi::detail::bfloat16ToBits(y[N - 1]);
144+
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
145+
}
146+
return res;
147+
#else
148+
// Assigning to std::ignore marks the parameters as used in this branch.
std::ignore = x;
149+
std::ignore = y;
150+
throw runtime_error("bfloat16 is not currently supported on the host device.",
151+
PI_ERROR_INVALID_DEVICE);
152+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
153+
}
154+
155+
// Scalar bfloat16 fma. The enable_if restricts this overload to T being
// exactly bfloat16.
// When compiling device code for NVPTX, the raw bits of `x`, `y` and `z` are
// handed to __clc_fma (in that order) and the returned bits are wrapped back
// into a bfloat16. (__clc_fma is declared outside this file; presumably it
// computes x * y + z as a fused operation - confirm.)
// In every other configuration the function throws runtime_error with
// PI_ERROR_INVALID_DEVICE.
template <typename T>
156+
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fma(T x, T y, T z) {
157+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
158+
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
159+
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
160+
oneapi::detail::Bfloat16StorageT ZBits = oneapi::detail::bfloat16ToBits(z);
161+
return oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
162+
#else
163+
// Assigning to std::ignore marks the parameters as used in this branch.
std::ignore = x;
164+
std::ignore = y;
165+
std::ignore = z;
166+
throw runtime_error("bfloat16 is not currently supported on the host device.",
167+
PI_ERROR_INVALID_DEVICE);
168+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
169+
}
170+
171+
// marray overload of fma for bfloat16 elements; `x`, `y` and `z` have the
// same length N and the result is computed element by element.
// On NVPTX device compilation the elements are processed two at a time: the
// matching pairs of `x`, `y` and `z` are each packed into a uint32_t, passed
// to __clc_fma, and the result bits are copied back into the matching pair of
// `res`. If N is odd, the last element is handled on its own through the
// 16-bit raw-bits path.
// (The 32-bit __clc_fma is declared outside this file; presumably it applies
// the operation to both packed halves - confirm.)
// In every other configuration the function throws runtime_error with
// PI_ERROR_INVALID_DEVICE.
template <size_t N>
172+
sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
173+
sycl::marray<bfloat16, N> y,
174+
sycl::marray<bfloat16, N> z) {
175+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
176+
sycl::marray<bfloat16, N> res;
177+
178+
// One iteration per pair of elements: indices i * 2 and i * 2 + 1.
for (size_t i = 0; i < N / 2; i++) {
179+
auto partial_res =
180+
__clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2),
181+
detail::to_uint32_t(z, i * 2));
182+
std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
183+
}
184+
185+
// Odd N leaves one trailing element that the pairwise loop did not cover.
if (N % 2) {
186+
oneapi::detail::Bfloat16StorageT XBits =
187+
oneapi::detail::bfloat16ToBits(x[N - 1]);
188+
oneapi::detail::Bfloat16StorageT YBits =
189+
oneapi::detail::bfloat16ToBits(y[N - 1]);
190+
oneapi::detail::Bfloat16StorageT ZBits =
191+
oneapi::detail::bfloat16ToBits(z[N - 1]);
192+
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
193+
}
194+
return res;
195+
#else
196+
// Assigning to std::ignore marks the parameters as used in this branch.
std::ignore = x;
197+
std::ignore = y;
198+
std::ignore = z;
199+
throw runtime_error("bfloat16 is not currently supported on the host device.",
200+
PI_ERROR_INVALID_DEVICE);
201+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
202+
}
203+
204+
} // namespace experimental
205+
} // namespace oneapi
206+
} // namespace ext
207+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
208+
} // namespace sycl

sycl/include/sycl/ext/oneapi/experimental/builtins.hpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,6 @@ namespace ext {
3232
namespace oneapi {
3333
namespace experimental {
3434

35-
namespace detail {
36-
template <size_t N>
37-
uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
38-
uint32_t res;
39-
std::memcpy(&res, &x[start], sizeof(uint32_t));
40-
return res;
41-
}
42-
} // namespace detail
43-
4435
// Provides functionality to print data from kernels in a C way:
4536
// - On non-host devices this function is directly mapped to printf from
4637
// OpenCL C

0 commit comments

Comments
 (0)