13
13
#include < sycl/ext/oneapi/properties/properties.hpp>
14
14
#include < sycl/sycl_span.hpp>
15
15
16
+ #include < cstring>
17
+
16
18
namespace sycl {
17
19
inline namespace _V1 {
18
20
namespace ext ::oneapi::experimental {
@@ -106,6 +108,116 @@ int get_mem_idx(GroupTy g, int vec_or_array_idx) {
106
108
return g.get_local_linear_id () +
107
109
g.get_local_linear_range () * vec_or_array_idx;
108
110
}
111
+
112
+ // SPIR-V extension:
113
+ // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_subgroups.asciidoc,
114
+ // however it doesn't describe limitations/requirements. Those seem to be
115
+ // listed in the Intel OpenCL extensions for sub-groups:
116
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html
117
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_char.html
118
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_long.html
119
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_short.html
120
+ // Reads require 4-byte alignment, writes 16-byte alignment. Supported
121
+ // sizes:
122
+ //
123
+ // +--------+-------------+
124
+ // | uchar | 1,2,4,8,16 |
125
+ // | ushort | 1,2,4,8 |
126
+ // | uint | 1,2,4,8 |
127
+ // | ulong | 1,2,4,8 |
128
+ // +--------+-------------+
129
+ //
130
+ // Utility type traits below are used to map user type to one of the block
131
+ // read/write types above.
132
+
133
+ template <typename IteratorT, std::size_t ElementsPerWorkItem, bool blocked>
134
+ struct BlockInfo {
135
+ using value_type =
136
+ remove_decoration_t <typename std::iterator_traits<IteratorT>::value_type>;
137
+
138
+ static constexpr int block_size =
139
+ sizeof (value_type) * (blocked ? ElementsPerWorkItem : 1 );
140
+ static constexpr int num_blocks = blocked ? 1 : ElementsPerWorkItem;
141
+ static constexpr bool has_builtin =
142
+ detail::is_power_of_two (block_size) &&
143
+ detail::is_power_of_two (num_blocks) && block_size <= 8 &&
144
+ (num_blocks <= 8 || (num_blocks == 16 && block_size == 1 ));
145
+ };
146
+
147
+ template <typename BlockInfoTy> struct BlockTypeInfo ;
148
+
149
+ template <typename IteratorT, std::size_t ElementsPerWorkItem, bool blocked>
150
+ struct BlockTypeInfo <BlockInfo<IteratorT, ElementsPerWorkItem, blocked>> {
151
+ using BlockInfoTy = BlockInfo<IteratorT, ElementsPerWorkItem, blocked>;
152
+ static_assert (BlockInfoTy::has_builtin);
153
+
154
+ using block_type = detail::cl_unsigned<BlockInfoTy::block_size>;
155
+
156
+ using block_pointer_elem_type = std::conditional_t <
157
+ std::is_const_v<std::remove_reference_t <
158
+ typename std::iterator_traits<IteratorT>::reference>>,
159
+ std::add_const_t <block_type>, block_type>;
160
+
161
+ using block_pointer_type = typename detail::DecoratedType<
162
+ block_pointer_elem_type, access::address_space::global_space>::type *;
163
+ using block_load_type = std::conditional_t <
164
+ BlockInfoTy::num_blocks == 1 , block_type,
165
+ detail::ConvertToOpenCLType_t<vec<block_type, BlockInfoTy::num_blocks>>>;
166
+ };
167
+
168
+ // Returns either a pointer suitable to use in a block read/write builtin or
169
+ // nullptr if some legality conditions aren't satisfied.
170
+ template <int RequiredAlign, std::size_t ElementsPerWorkItem,
171
+ typename IteratorT, typename Properties>
172
+ auto get_block_op_ptr (IteratorT iter, [[maybe_unused]] Properties props) {
173
+ using value_type =
174
+ remove_decoration_t <typename std::iterator_traits<IteratorT>::value_type>;
175
+ using iter_no_cv = std::remove_cv_t <IteratorT>;
176
+
177
+ constexpr bool blocked = detail::isBlocked (props);
178
+ using BlkInfo = BlockInfo<IteratorT, ElementsPerWorkItem, blocked>;
179
+
180
+ #if defined(__SPIR__)
181
+ // TODO: What about non-Intel SPIR-V devices?
182
+ constexpr bool is_spir = true ;
183
+ #else
184
+ constexpr bool is_spir = false ;
185
+ #endif
186
+
187
+ if constexpr (!is_spir || !BlkInfo::has_builtin) {
188
+ return nullptr ;
189
+ } else if constexpr (!props.template has_property <full_group_key>()) {
190
+ return nullptr ;
191
+ } else if constexpr (detail::is_multi_ptr_v<IteratorT>) {
192
+ return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(
193
+ iter.get_decorated (), props);
194
+ } else if constexpr (!std::is_pointer_v<iter_no_cv>) {
195
+ if constexpr (props.template has_property <contiguous_memory_key>())
196
+ return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(&*iter,
197
+ props);
198
+ else
199
+ return nullptr ;
200
+ } else {
201
+ __builtin_assume (iter != nullptr );
202
+ static_assert (BlkInfo::has_builtin);
203
+ bool aligned = alignof (value_type) >= RequiredAlign ||
204
+ reinterpret_cast <uintptr_t >(iter) % RequiredAlign == 0 ;
205
+
206
+ constexpr auto AS = detail::deduce_AS<iter_no_cv>::value;
207
+ using block_pointer_type =
208
+ typename BlockTypeInfo<BlkInfo>::block_pointer_type;
209
+ if constexpr (AS == access::address_space::global_space) {
210
+ return aligned ? reinterpret_cast <block_pointer_type>(iter) : nullptr ;
211
+ } else if constexpr (AS == access::address_space::generic_space) {
212
+ return aligned ? reinterpret_cast <block_pointer_type>(
213
+ __SYCL_GenericCastToPtrExplicit_ToGlobal<value_type>(
214
+ iter))
215
+ : nullptr ;
216
+ } else {
217
+ return nullptr ;
218
+ }
219
+ }
220
+ }
109
221
} // namespace detail
110
222
111
223
// Load API span overload.
@@ -117,17 +229,66 @@ std::enable_if_t<detail::verify_load_types<InputIteratorT, OutputT> &&
117
229
group_load (Group g, InputIteratorT in_ptr,
118
230
span<OutputT, ElementsPerWorkItem> out, Properties props = {}) {
119
231
constexpr bool blocked = detail::isBlocked (props);
232
+ using use_naive =
233
+ detail::merged_properties_t <Properties,
234
+ decltype (properties (detail::naive))>;
120
235
121
236
if constexpr (props.template has_property <detail::naive_key>()) {
122
237
group_barrier (g);
123
238
for (int i = 0 ; i < out.size (); ++i)
124
239
out[i] = in_ptr[detail::get_mem_idx<blocked, ElementsPerWorkItem>(g, i)];
125
240
group_barrier (g);
126
- } else {
127
- using use_naive =
128
- detail::merged_properties_t <Properties,
129
- decltype (properties (detail::naive))>;
241
+ return ;
242
+ } else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
130
243
return group_load (g, in_ptr, out, use_naive{});
244
+ } else {
245
+ auto ptr =
246
+ detail::get_block_op_ptr<4 /* load align */ , ElementsPerWorkItem>(
247
+ in_ptr, props);
248
+ if (!ptr)
249
+ return group_load (g, in_ptr, out, use_naive{});
250
+
251
+ if constexpr (!std::is_same_v<std::nullptr_t , decltype (ptr)>) {
252
+ // Do optimized load.
253
+ using value_type = remove_decoration_t <
254
+ typename std::iterator_traits<InputIteratorT>::value_type>;
255
+
256
+ auto load = __spirv_SubgroupBlockReadINTEL<
257
+ typename detail::BlockTypeInfo<detail::BlockInfo<
258
+ InputIteratorT, ElementsPerWorkItem, blocked>>::block_load_type>(
259
+ ptr);
260
+
261
+ // TODO: accessor_iterator's value_type is weird, so we need
262
+ // `std::remove_const_t` below:
263
+ //
264
+ // static_assert(
265
+ // std::is_same_v<
266
+ // typename std::iterator_traits<
267
+ // sycl::detail::accessor_iterator<const int, 1>>::value_type,
268
+ // const int>);
269
+ //
270
+ // yet
271
+ //
272
+ // static_assert(
273
+ // std::is_same_v<
274
+ // typename std::iterator_traits<const int *>::value_type, int>);
275
+
276
+ if constexpr (std::is_same_v<std::remove_const_t <value_type>, OutputT>) {
277
+ static_assert (sizeof (load) == out.size_bytes ());
278
+ std::memcpy (out.begin (), &load, out.size_bytes ());
279
+ } else {
280
+ std::remove_const_t <value_type> values[ElementsPerWorkItem];
281
+ static_assert (sizeof (load) == sizeof (values));
282
+ std::memcpy (values, &load, sizeof (values));
283
+
284
+ // Note: can't `memcpy` directly into `out` because that might bypass
285
+ // an implicit conversion required by the specification.
286
+ for (int i = 0 ; i < ElementsPerWorkItem; ++i)
287
+ out[i] = values[i];
288
+ }
289
+
290
+ return ;
291
+ }
131
292
}
132
293
}
133
294
0 commit comments