Skip to content

Commit 98e088f

Browse files
[SYCL] Optimize sub-group group_load via BlockRead in simple cases
1 parent f56d7d7 commit 98e088f

File tree

4 files changed

+386
-179
lines changed

4 files changed

+386
-179
lines changed

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,13 @@ template <typename T>
262262
using make_unsinged_integer_t =
263263
make_type_t<T, gtl::scalar_unsigned_integer_list>;
264264

265+
template <int Size>
266+
using cl_unsigned = std::conditional_t<
267+
Size == 1, opencl::cl_uchar,
268+
std::conditional_t<
269+
Size == 2, opencl::cl_ushort,
270+
std::conditional_t<Size == 4, opencl::cl_uint, opencl::cl_ulong>>>;
271+
265272
// select_apply_cl_scalar_t selects from T8/T16/T32/T64 basing on
266273
// sizeof(IN). expected to handle scalar types.
267274
template <typename T, typename T8, typename T16, typename T32, typename T64>

sycl/include/sycl/detail/helpers.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ void loop_impl(std::integer_sequence<size_t, Inds...>, F &&f) {
253253
template <size_t count, class F> void loop(F &&f) {
254254
loop_impl(std::make_index_sequence<count>{}, std::forward<F>(f));
255255
}
256+
inline constexpr bool is_power_of_two(int x) { return (x & (x - 1)) == 0; }
256257
} // namespace detail
257258

258259
} // namespace _V1

sycl/include/sycl/ext/oneapi/experimental/group_load_store.hpp

Lines changed: 165 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <sycl/ext/oneapi/properties/properties.hpp>
1414
#include <sycl/sycl_span.hpp>
1515

16+
#include <cstring>
17+
1618
namespace sycl {
1719
inline namespace _V1 {
1820
namespace ext::oneapi::experimental {
@@ -106,6 +108,116 @@ int get_mem_idx(GroupTy g, int vec_or_array_idx) {
106108
return g.get_local_linear_id() +
107109
g.get_local_linear_range() * vec_or_array_idx;
108110
}
111+
112+
// SPIR-V extension:
113+
// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_subgroups.asciidoc,
114+
// however it doesn't describe limitations/requirements. Those seem to be
115+
// listed in the Intel OpenCL extensions for sub-groups:
116+
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html
117+
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_char.html
118+
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_long.html
119+
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_short.html
120+
// Reads require 4-byte alignment, writes 16-byte alignment. Supported
121+
// sizes:
122+
//
123+
// +--------+-------------+
124+
// | uchar | 1,2,4,8,16 |
125+
// | ushort | 1,2,4,8 |
126+
// | uint | 1,2,4,8 |
127+
// | ulong | 1,2,4,8 |
128+
// +--------+-------------+
129+
//
130+
// Utility type traits below are used to map user type to one of the block
131+
// read/write types above.
132+
133+
template <typename IteratorT, std::size_t ElementsPerWorkItem, bool blocked>
134+
struct BlockInfo {
135+
using value_type =
136+
remove_decoration_t<typename std::iterator_traits<IteratorT>::value_type>;
137+
138+
static constexpr int block_size =
139+
sizeof(value_type) * (blocked ? ElementsPerWorkItem : 1);
140+
static constexpr int num_blocks = blocked ? 1 : ElementsPerWorkItem;
141+
static constexpr bool has_builtin =
142+
detail::is_power_of_two(block_size) &&
143+
detail::is_power_of_two(num_blocks) && block_size <= 8 &&
144+
(num_blocks <= 8 || (num_blocks == 16 && block_size == 1));
145+
};
146+
147+
template <typename BlockInfoTy> struct BlockTypeInfo;
148+
149+
template <typename IteratorT, std::size_t ElementsPerWorkItem, bool blocked>
150+
struct BlockTypeInfo<BlockInfo<IteratorT, ElementsPerWorkItem, blocked>> {
151+
using BlockInfoTy = BlockInfo<IteratorT, ElementsPerWorkItem, blocked>;
152+
static_assert(BlockInfoTy::has_builtin);
153+
154+
using block_type = detail::cl_unsigned<BlockInfoTy::block_size>;
155+
156+
using block_pointer_elem_type = std::conditional_t<
157+
std::is_const_v<std::remove_reference_t<
158+
typename std::iterator_traits<IteratorT>::reference>>,
159+
std::add_const_t<block_type>, block_type>;
160+
161+
using block_pointer_type = typename detail::DecoratedType<
162+
block_pointer_elem_type, access::address_space::global_space>::type *;
163+
using block_load_type = std::conditional_t<
164+
BlockInfoTy::num_blocks == 1, block_type,
165+
detail::ConvertToOpenCLType_t<vec<block_type, BlockInfoTy::num_blocks>>>;
166+
};
167+
168+
// Returns either a pointer suitable to use in a block read/write builtin or
169+
// nullptr if some legality conditions aren't satisfied.
170+
template <int RequiredAlign, std::size_t ElementsPerWorkItem,
171+
typename IteratorT, typename Properties>
172+
auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
173+
using value_type =
174+
remove_decoration_t<typename std::iterator_traits<IteratorT>::value_type>;
175+
using iter_no_cv = std::remove_cv_t<IteratorT>;
176+
177+
constexpr bool blocked = detail::isBlocked(props);
178+
using BlkInfo = BlockInfo<IteratorT, ElementsPerWorkItem, blocked>;
179+
180+
#if defined(__SPIR__)
181+
// TODO: What about non-Intel SPIR-V devices?
182+
constexpr bool is_spir = true;
183+
#else
184+
constexpr bool is_spir = false;
185+
#endif
186+
187+
if constexpr (!is_spir || !BlkInfo::has_builtin) {
188+
return nullptr;
189+
} else if constexpr (!props.template has_property<full_group_key>()) {
190+
return nullptr;
191+
} else if constexpr (detail::is_multi_ptr_v<IteratorT>) {
192+
return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(
193+
iter.get_decorated(), props);
194+
} else if constexpr (!std::is_pointer_v<iter_no_cv>) {
195+
if constexpr (props.template has_property<contiguous_memory_key>())
196+
return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(&*iter,
197+
props);
198+
else
199+
return nullptr;
200+
} else {
201+
__builtin_assume(iter != nullptr);
202+
static_assert(BlkInfo::has_builtin);
203+
bool aligned = alignof(value_type) >= RequiredAlign ||
204+
reinterpret_cast<uintptr_t>(iter) % RequiredAlign == 0;
205+
206+
constexpr auto AS = detail::deduce_AS<iter_no_cv>::value;
207+
using block_pointer_type =
208+
typename BlockTypeInfo<BlkInfo>::block_pointer_type;
209+
if constexpr (AS == access::address_space::global_space) {
210+
return aligned ? reinterpret_cast<block_pointer_type>(iter) : nullptr;
211+
} else if constexpr (AS == access::address_space::generic_space) {
212+
return aligned ? reinterpret_cast<block_pointer_type>(
213+
__SYCL_GenericCastToPtrExplicit_ToGlobal<value_type>(
214+
iter))
215+
: nullptr;
216+
} else {
217+
return nullptr;
218+
}
219+
}
220+
}
109221
} // namespace detail
110222

111223
// Load API span overload.
@@ -117,17 +229,66 @@ std::enable_if_t<detail::verify_load_types<InputIteratorT, OutputT> &&
117229
group_load(Group g, InputIteratorT in_ptr,
118230
span<OutputT, ElementsPerWorkItem> out, Properties props = {}) {
119231
constexpr bool blocked = detail::isBlocked(props);
232+
using use_naive =
233+
detail::merged_properties_t<Properties,
234+
decltype(properties(detail::naive))>;
120235

121236
if constexpr (props.template has_property<detail::naive_key>()) {
122237
group_barrier(g);
123238
for (int i = 0; i < out.size(); ++i)
124239
out[i] = in_ptr[detail::get_mem_idx<blocked, ElementsPerWorkItem>(g, i)];
125240
group_barrier(g);
126-
} else {
127-
using use_naive =
128-
detail::merged_properties_t<Properties,
129-
decltype(properties(detail::naive))>;
241+
return;
242+
} else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
130243
return group_load(g, in_ptr, out, use_naive{});
244+
} else {
245+
auto ptr =
246+
detail::get_block_op_ptr<4 /* load align */, ElementsPerWorkItem>(
247+
in_ptr, props);
248+
if (!ptr)
249+
return group_load(g, in_ptr, out, use_naive{});
250+
251+
if constexpr (!std::is_same_v<std::nullptr_t, decltype(ptr)>) {
252+
// Do optimized load.
253+
using value_type = remove_decoration_t<
254+
typename std::iterator_traits<InputIteratorT>::value_type>;
255+
256+
auto load = __spirv_SubgroupBlockReadINTEL<
257+
typename detail::BlockTypeInfo<detail::BlockInfo<
258+
InputIteratorT, ElementsPerWorkItem, blocked>>::block_load_type>(
259+
ptr);
260+
261+
// TODO: accessor_iterator's value_type is weird, so we need
262+
// `std::remove_const_t` below:
263+
//
264+
// static_assert(
265+
// std::is_same_v<
266+
// typename std::iterator_traits<
267+
// sycl::detail::accessor_iterator<const int, 1>>::value_type,
268+
// const int>);
269+
//
270+
// yet
271+
//
272+
// static_assert(
273+
// std::is_same_v<
274+
// typename std::iterator_traits<const int *>::value_type, int>);
275+
276+
if constexpr (std::is_same_v<std::remove_const_t<value_type>, OutputT>) {
277+
static_assert(sizeof(load) == out.size_bytes());
278+
std::memcpy(out.begin(), &load, out.size_bytes());
279+
} else {
280+
std::remove_const_t<value_type> values[ElementsPerWorkItem];
281+
static_assert(sizeof(load) == sizeof(values));
282+
std::memcpy(values, &load, sizeof(values));
283+
284+
// Note: can't `memcpy` directly into `out` because that might bypass
285+
// an implicit conversion required by the specification.
286+
for (int i = 0; i < ElementsPerWorkItem; ++i)
287+
out[i] = values[i];
288+
}
289+
290+
return;
291+
}
131292
}
132293
}
133294

0 commit comments

Comments
 (0)