Skip to content

Commit 78a3e77

Browse files
authored
[SYCL] Implement sub-group mask extension (#4481)
The specification is available under https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc Complementary test changes are available under intel/llvm-test-suite#441, intel/llvm-test-suite#462
1 parent fc2f897 commit 78a3e77

File tree

10 files changed

+447
-54
lines changed

10 files changed

+447
-54
lines changed

sycl/doc/extensions/GroupMask/GroupMask.asciidoc renamed to sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
= SYCL_EXT_ONEAPI_GROUP_MASK
1+
= SYCL_EXT_ONEAPI_SUB_GROUP_MASK
22
:source-highlighter: coderay
33
:coderay-linenums-mode: table
44

@@ -21,7 +21,7 @@ IMPORTANT: This specification is a draft.
2121

2222
NOTE: Khronos(R) is a registered trademark and SYCL(TM) and SPIR(TM) are trademarks of The Khronos Group Inc. OpenCL(TM) is a trademark of Apple Inc. used by permission by Khronos.
2323

24-
This document describes an extension which adds a `group_mask` type. Such a mask can be used to efficiently represent subsets of work-items in a group for which a given Boolean condition holds. Group mask functionality is currently limited to groups that are instances of the `sub_group` class.
24+
This document describes an extension which adds a `sub_group_mask` type. Such a mask can be used to efficiently represent subsets of work-items in a sub-group for which a given Boolean condition holds.
2525

2626
== Notice
2727

@@ -51,9 +51,9 @@ This extension is written against the SYCL 2020 specification, Revision 3.
5151
This extension provides a feature-test macro as described in the core SYCL
5252
specification section 6.3.3 "Feature test macros". Therefore, an
5353
implementation supporting this extension must predefine the macro
54-
`SYCL_EXT_ONEAPI_GROUP_MASK` to one of the values defined in the table below.
55-
Applications can test for the existence of this macro to determine if the
56-
implementation supports this feature, or applications can test the macro's
54+
`SYCL_EXT_ONEAPI_SUB_GROUP_MASK` to one of the values defined in the table
55+
below. Applications can test for the existence of this macro to determine if
56+
the implementation supports this feature, or applications can test the macro's
5757
value to determine which of the extension's APIs the implementation supports.
5858

5959
[%header,cols="1,5"]
@@ -81,18 +81,18 @@ must be encountered by all work-items in the group in converged control flow.
8181
|===
8282
|Function|Description
8383

84-
|`template <typename Group> Group::mask_type group_ballot(Group g, bool predicate = true) const`
85-
|Return a `group_mask` representing the set of work-items in group _g_ for which _predicate_ is `true`.
84+
|`template <typename Group> Group::mask_type group_ballot(Group g, bool predicate = true)`
85+
|Return a `sub_group_mask` with one bit for each work-item in group _g_. A bit is set in this mask if and only if the corresponding work-item's _predicate_ is `true`.
8686
|===
8787

8888
=== Group Masks
8989

9090
The group mask type is an opaque type, permitting implementations to use any
9191
mask representation that has the same size and alignment across host and
92-
device. The maximum number of bits that can be stored in a `group_mask` is
93-
exposed as a static member variable, `group_mask::max_bits`.
92+
device. The maximum number of bits that can be stored in a `sub_group_mask` is
93+
exposed as a static member variable, `sub_group_mask::max_bits`.
9494

95-
Functions declared in the `group_mask` class can be called independently by
95+
Functions declared in the `sub_group_mask` class can be called independently by
9696
different work-items in the same group. An instance of a group class (e.g.
9797
`group` or `sub_group`) is not required to manipulate a group mask.
9898

@@ -107,7 +107,7 @@ work-item with the id `max_local_range()-1`.
107107
|Return `true` if the bit corresponding to the specified _id_ is set in the
108108
mask.
109109

110-
|`group_mask::reference operator[](id<1> id)`
110+
|`sub_group_mask::reference operator[](id<1> id)`
111111
|Return a reference to the bit corresponding to the specified _id_ in the mask.
112112

113113
|`bool test(id<1> id) const`
@@ -137,17 +137,15 @@ work-item with the id `max_local_range()-1`.
137137
|Return the highest `id` with a corresponding bit set in the mask. If no bits
138138
are set, the return value is equal to `size()`.
139139

140-
|`template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>> void insert_bits(T bits, id<1> pos = 0)`
140+
|`template <typename T> void insert_bits(const T &bits, id<1> pos = 0)`
141141
|Insert `CHAR_BIT * sizeof(T)` bits into the mask, starting from _pos_. `T`
142-
must be an integral type or a SYCL `marray` of integral types. _pos_ must be a
143-
multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+]
142+
must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+]
144143
`CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+]
145144
`CHAR_BIT * sizeof(T)`) bits are ignored.
146145

147-
|`template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>> T extract_bits(id<1> pos = 0) const`
146+
|`template <typename T> void extract_bits(T &out, id<1> pos = 0) const`
148147
|Return `CHAR_BIT * sizeof(T)` bits from the mask, starting from _pos_. `T`
149-
must be an integral type or a SYCL `marray` of integral types. _pos_ must be a
150-
multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+]
148+
must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+]
151149
`CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+]
152150
`CHAR_BIT * sizeof(T)`) bits of the return value are zero.
153151

@@ -178,62 +176,63 @@ work-item with the id `max_local_range()-1`.
178176
|`void flip(id<1> id)`
179177
|Toggle the value of the bit corresponding to the specified _id_.
180178

181-
|`bool operator==(group_mask rhs) const`
179+
|`bool operator==(const sub_group_mask &rhs) const`
182180
|Return true if each bit in this mask is equal to the corresponding bit in
183181
`rhs`.
184182

185-
|`bool operator!=(group_mask rhs) const`
183+
|`bool operator!=(const sub_group_mask &rhs) const`
186184
|Return true if any bit in this mask is not equal to the corresponding bit in
187185
`rhs`.
188186

189-
|`group_mask operator &=(group_mask rhs)`
187+
|`sub_group_mask &operator &=(const sub_group_mask &rhs)`
190188
|Set the bits of this mask to the result of performing a bitwise AND with this
191189
mask and `rhs`.
192190

193-
|`group_mask operator \|=(group_mask rhs)`
191+
|`sub_group_mask &operator \|=(const sub_group_mask &rhs)`
194192
|Set the bits of this mask to the result of performing a bitwise OR with this
195193
mask and `rhs`.
196194

197-
|`group_mask operator ^=(group_mask rhs)`
195+
|`sub_group_mask &operator ^=(const sub_group_mask &rhs)`
198196
|Set the bits of this mask to the result of performing a bitwise XOR with this
199197
mask and `rhs`.
200198

201-
|`group_mask operator pass:[<<=](size_t shift)`
199+
|`sub_group_mask &operator pass:[<<=](size_t shift)`
202200
|Set the bits of this mask to the result of shifting its bits _shift_ positions
203201
to the left using a logical shift. Bits that are shifted out to the left are
204202
discarded, and zeroes are shifted in from the right.
205203

206-
|`group_mask operator >>=(size_t shift)`
204+
|`sub_group_mask &operator >>=(size_t shift)`
207205
|Set the bits of this mask to the result of shifting its bits _shift_ positions
208206
to the right using a logical shift. Bits that are shifted out to the right are
209207
discarded, and zeroes are shifted in from the left.
210208

211-
|`group_mask operator ~() const`
209+
|`sub_group_mask operator ~() const`
212210
|Return a mask representing the result of flipping all the bits in this mask.
213211

214-
|`group_mask operator <<(size_t shift)`
212+
|`sub_group_mask operator <<(size_t shift) const`
215213
|Return a mask representing the result of shifting its bits _shift_ positions
216214
to the left using a logical shift. Bits that are shifted out to the left are
217215
discarded, and zeroes are shifted in from the right.
218216

219-
|`group_mask operator >>(size_t shift)`
217+
|`sub_group_mask operator >>(size_t shift) const`
220218
|Return a mask representing the result of shifting its bits _shift_ positions
221219
to the right using a logical shift. Bits that are shifted out to the right are
222220
discarded, and zeroes are shifted in from the left.
221+
223222
|===
224223

225224
|===
226225
|Function|Description
227226

228-
|`group_mask operator &(const group_mask& lhs, const group_mask& rhs)`
227+
|`sub_group_mask operator &(const sub_group_mask& lhs, const sub_group_mask& rhs)`
229228
|Return a mask representing the result of performing a bitwise AND of `lhs` and
230229
`rhs`.
231230

232-
|`group_mask operator \|(const group_mask& lhs, const group_mask& rhs)`
231+
|`sub_group_mask operator \|(const sub_group_mask& lhs, const sub_group_mask& rhs)`
233232
|Return a mask representing the result of performing a bitwise OR of `lhs` and
234233
`rhs`.
235234

236-
|`group_mask operator ^(const group_mask& lhs, const group_mask& rhs)`
235+
|`sub_group_mask operator ^(const sub_group_mask& lhs, const sub_group_mask& rhs)`
237236
|Return a mask representing the result of performing a bitwise XOR of `lhs` and
238237
`rhs`.
239238

@@ -247,7 +246,7 @@ namespace sycl {
247246
namespace ext {
248247
namespace oneapi {
249248
250-
struct group_mask {
249+
struct sub_group_mask {
251250
252251
// enable reference to individual bit
253252
struct reference {
@@ -271,11 +270,11 @@ struct group_mask {
271270
id<1> find_low() const;
272271
id<1> find_high() const;
273272
274-
template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>>
275-
void insert_bits(T bits, id<1> pos = 0);
273+
template <typename T>
274+
void insert_bits(const T &bits, id<1> pos = 0);
276275
277-
template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>>
278-
T extract_bits(id<1> pos = 0);
276+
template <typename T>
277+
void extract_bits(T &out, id<1> pos = 0);
279278
280279
void set();
281280
void set(id<1> id, bool value = true);
@@ -286,24 +285,24 @@ struct group_mask {
286285
void flip();
287286
void flip(id<1> id);
288287
289-
bool operator==(group_mask rhs) const;
290-
bool operator!=(group_mask rhs) const;
288+
bool operator==(const sub_group_mask &rhs) const;
289+
bool operator!=(const sub_group_mask &rhs) const;
291290
292-
group_mask operator &=(group_mask rhs);
293-
group_mask operator |=(group_mask rhs);
294-
group_mask operator ^=(group_mask rhs);
295-
group_mask operator <<=(size_t);
296-
group_mask operator >>=(size_t rhs);
291+
sub_group_mask &operator &=(const sub_group_mask &rhs);
292+
sub_group_mask &operator |=(const sub_group_mask &rhs);
293+
sub_group_mask &operator ^=(const sub_group_mask &rhs);
294+
sub_group_mask &operator <<=(size_t n);
295+
sub_group_mask &operator >>=(size_t n);
297296
298-
group_mask operator ~() const;
299-
group_mask operator <<(size_t) const;
300-
group_mask operator >>(size_t) const;
297+
sub_group_mask operator ~() const;
298+
sub_group_mask operator <<(size_t n) const;
299+
sub_group_mask operator >>(size_t n) const;
301300
302301
};
303302
304-
group_mask operator &(const group_mask& lhs, const group_mask& rhs);
305-
group_mask operator |(const group_mask& lhs, const group_mask& rhs);
306-
group_mask operator ^(const group_mask& lhs, const group_mask& rhs);
303+
sub_group_mask operator &(const sub_group_mask& lhs, const sub_group_mask& rhs);
304+
sub_group_mask operator |(const sub_group_mask& lhs, const sub_group_mask& rhs);
305+
sub_group_mask operator ^(const sub_group_mask& lhs, const sub_group_mask& rhs);
307306
308307
} // namespace oneapi
309308
} // namespace ext
@@ -328,6 +327,7 @@ None.
328327
|========================================
329328
|Rev|Date|Author|Changes
330329
|1|2021-08-11|John Pennycook|*Initial public working draft*
330+
|2|2021-09-13|Vladimir Lazarev|*Update during implementation*
331331
|========================================
332332
333333
//************************************************************************

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,9 @@ __spirv_ocl_prefetch(const __attribute__((opencl_global)) char *Ptr,
633633
extern SYCL_EXTERNAL uint16_t __spirv_ConvertFToBF16INTEL(float) noexcept;
634634
extern SYCL_EXTERNAL float __spirv_ConvertBF16ToFINTEL(uint16_t) noexcept;
635635

636+
__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT __ocl_vec_t<uint32_t, 4>
637+
__spirv_GroupNonUniformBallot(uint32_t Execution, bool Predicate) noexcept;
638+
636639
#else // if !__SYCL_DEVICE_ONLY__
637640

638641
template <typename dataT>

sycl/include/CL/sycl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,4 @@
5959
#include <sycl/ext/oneapi/matrix/matrix.hpp>
6060
#include <sycl/ext/oneapi/reduction.hpp>
6161
#include <sycl/ext/oneapi/sub_group.hpp>
62+
#include <sycl/ext/oneapi/sub_group_mask.hpp>

sycl/include/CL/sycl/detail/helpers.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ template <int Dims> class range;
3131
template <int Dims> class id;
3232
template <int Dims> class nd_item;
3333
template <int Dims> class h_item;
34+
template <typename Type, std::size_t NumElements> class marray;
3435
enum class memory_order;
3536

3637
namespace detail {
@@ -82,6 +83,11 @@ class Builder {
8283
return group<Dims>(Global, Local, Global / Local, Index);
8384
}
8485

86+
template <class ResType>
87+
static ResType createSubGroupMask(uint32_t Bits, size_t BitsNum) {
88+
return ResType(Bits, BitsNum);
89+
}
90+
8591
template <int Dims, bool WithOffset>
8692
static detail::enable_if_t<WithOffset, item<Dims, WithOffset>>
8793
createItem(const range<Dims> &Extent, const id<Dims> &Index,

sycl/include/CL/sycl/feature_test.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace sycl {
1414

1515
// TODO: Move these feature-test macros to compiler driver.
1616
#define SYCL_EXT_INTEL_DEVICE_INFO 2
17+
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 1
1718
#define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1
1819
// As for SYCL_EXT_ONEAPI_MATRIX:
1920
// 1- provides AOT initial implementation for AMX for the experimental matrix

sycl/include/CL/sycl/marray.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ template <typename Type, std::size_t NumElements> class marray {
149149
}
150150

151151
#define __SYCL_BINOP_INTEGRAL(BINOP, OPASSIGN) \
152-
template <typename T = DataT> \
153-
friend typename std::enable_if<std::is_integral<T>::value, marray> \
154-
operator BINOP(const marray &Lhs, const marray &Rhs) { \
152+
template <typename T = DataT, \
153+
typename = std::enable_if<std::is_integral<T>::value, marray>> \
154+
friend marray operator BINOP(const marray &Lhs, const marray &Rhs) { \
155155
marray Ret; \
156156
for (size_t I = 0; I < NumElements; ++I) { \
157157
Ret[I] = Lhs[I] BINOP Rhs[I]; \
@@ -166,9 +166,9 @@ template <typename Type, std::size_t NumElements> class marray {
166166
operator BINOP(const marray &Lhs, const T &Rhs) { \
167167
return Lhs BINOP marray(static_cast<DataT>(Rhs)); \
168168
} \
169-
template <typename T = DataT> \
170-
friend typename std::enable_if<std::is_integral<T>::value, marray> \
171-
&operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
169+
template <typename T = DataT, \
170+
typename = std::enable_if<std::is_integral<T>::value, marray>> \
171+
friend marray &operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
172172
Lhs = Lhs BINOP Rhs; \
173173
return Lhs; \
174174
} \

0 commit comments

Comments
 (0)