Skip to content

Commit

Permalink
[SYCL] Initial changes for the second version of sycl_ext_oneapi_grou…
Browse files Browse the repository at this point in the history
…p_sort extension

Current implementation is aligned with the first version of the
extension: https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_group_sort.asciidoc

This second version of the extension:
https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_group_sort.asciidoc
changes the API by introducing separate sorter objects for
sort_over_group and joint_sort.

Keep both versions of the API until the second version is fully implemented. Once the PRs supporting the second version
are fully merged and the macro is updated, the old APIs will be removed.

Currently this PR doesn't include array input or key/value sorting support.
It's split from larger PR: intel#13713

Co-authored-by: "Andrei Fedorov [andrey.fedorov@intel.com](mailto:andrey.fedorov@intel.com)"
Co-authored-by: "Romanov Vlad [vlad.romanov@intel.com](mailto:vlad.romanov@intel.com)"
  • Loading branch information
againull committed May 30, 2024
1 parent f2a2de3 commit 71e6b36
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 7 deletions.
192 changes: 191 additions & 1 deletion sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ template <typename Group, size_t Extent> class group_with_scratchpad {
sycl::span<std::byte, Extent> get_memory() const { return scratch; }
};

// ---- sorters
// Default sorter provided by the first version of the extension specification.
template <typename Compare = std::less<>> class default_sorter {
Compare comp;
sycl::span<std::byte> scratch;
Expand Down Expand Up @@ -122,6 +122,7 @@ template <typename T> struct ConvertToComp<T, sorting_order::descending> {
};
} // namespace detail

// Radix sorter provided by the first version of the extension specification.
template <typename ValT, sorting_order OrderT = sorting_order::ascending,
unsigned int BitsPerPass = 4>
class radix_sorter {
Expand Down Expand Up @@ -199,6 +200,195 @@ class radix_sorter {
}
};

// Default sorters provided by the second version of the extension
// specification.
namespace default_sorters {

/// Comparison-based sorter for `joint_sort`, provided by the second version
/// of the sycl_ext_oneapi_group_sort extension.
///
/// Stores a comparator and a view of caller-provided scratch memory; the
/// sorter itself does not allocate.
///
/// \tparam CompareT comparison function object type (defaults to std::less<>).
template <typename CompareT = std::less<>> class joint_sorter {
  CompareT comp;
  sycl::span<std::byte> scratch;

public:
  /// \param scratch_ scratch memory handed to the sort implementation.
  /// \param comp_    comparator used to order the elements.
  template <size_t Extent>
  joint_sorter(sycl::span<std::byte, Extent> scratch_,
               CompareT comp_ = CompareT())
      : comp(comp_), scratch(scratch_) {}

  /// Sorts the `last - first` elements starting at `first` by forwarding to
  /// sycl::detail::merge_sort with the stored comparator and scratch memory.
  ///
  /// \throws sycl::exception when compiled for the host, where this sorter is
  ///         not supported.
  template <typename Group, typename Ptr>
  void operator()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
                  [[maybe_unused]] Ptr last) {
#ifdef __SYCL_DEVICE_ONLY__
    // Per extension specification if scratch size is less than the value
    // returned by memory_required then behavior is undefined, so we don't
    // check that the scratch size satisfies the requirement.
    sycl::detail::merge_sort(g, first, last - first, comp, scratch.data());
#else
    // The message names the member that actually throws; it previously
    // blamed the constructor of the first-version default_sorter.
    throw sycl::exception(
        std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
        "default_sorters::joint_sorter operator() is not supported on host "
        "device.");
#endif
  }

  /// Number of scratch bytes needed to sort `range_size` elements of type T:
  /// room for `range_size` values plus `alignof(T)` extra bytes.
  /// The memory scope argument is not used in the computation.
  template <typename T>
  static constexpr size_t memory_required(sycl::memory_scope,
                                          size_t range_size) {
    return range_size * sizeof(T) + alignof(T);
  }
};

template <typename T, typename CompareT = std::less<>,
std::size_t ElementsPerWorkItem = 1>
class group_sorter {
CompareT comp;
sycl::span<std::byte> scratch;

public:
template <std::size_t Extent>
group_sorter(sycl::span<std::byte, Extent> scratch_,
CompareT comp_ = CompareT{})
: comp(comp_), scratch(scratch_) {}

template <typename Group> T operator()([[maybe_unused]] Group g, T val) {
#ifdef __SYCL_DEVICE_ONLY__
// Per extension specification if scratch size is less than the value
// returned by memory_required then behavior is undefined, so we don't check
// that the scratch size statisfies the requirement.
auto range_size = g.get_local_range().size();
size_t local_id = g.get_local_linear_id();
T *temp = reinterpret_cast<T *>(scratch.data());
::new (temp + local_id) T(val);
sycl::detail::merge_sort(g, temp, range_size, comp,
scratch.data() + range_size * sizeof(T));
val = temp[local_id];
#else
throw sycl::exception(
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
"default_sorter operator() is not supported on host device.");
#endif
return val;
}

static constexpr std::size_t memory_required(sycl::memory_scope scope,
size_t range_size) {
return 2 * joint_sorter<>::template memory_required<T>(
scope, range_size * ElementsPerWorkItem);
}
};

} // namespace default_sorters

// Radix sorters provided by the second version of the extension specification.
namespace radix_sorters {

/// Radix sorter for `joint_sort`, provided by the second version of the
/// sycl_ext_oneapi_group_sort extension.
///
/// \tparam ValT        element type; must be arithmetic, sycl::half or
///                     sycl::ext::oneapi::bfloat16.
/// \tparam OrderT      ascending or descending order.
/// \tparam BitsPerPass value forwarded to the sort implementation and used
///                     to size the scratch memory.
template <typename ValT, sorting_order OrderT = sorting_order::ascending,
          unsigned int BitsPerPass = 4>
class joint_sorter {

  sycl::span<std::byte> scratch;
  uint32_t first_bit = 0;
  uint32_t last_bit = 0;

  static constexpr uint32_t bits = BitsPerPass;
  using bitset_t = std::bitset<sizeof(ValT) * CHAR_BIT>;

public:
  /// \param scratch_ scratch memory handed to the sort implementation.
  /// \param mask     bit mask; by default every bit is set. The constructor
  ///                 records the index of the lowest set bit and the index
  ///                 one past the end of the run of set bits starting there.
  template <std::size_t Extent>
  joint_sorter(sycl::span<std::byte, Extent> scratch_,
               const bitset_t mask = bitset_t{}.set())
      : scratch(scratch_) {
    static_assert(std::is_arithmetic_v<ValT> ||
                      std::is_same_v<ValT, sycl::half> ||
                      std::is_same_v<ValT, sycl::ext::oneapi::bfloat16>,
                  "radix sort is not supported for the given type");

    // Skip the leading cleared bits to find the first set one.
    first_bit = 0;
    while (first_bit < mask.size() && !mask[first_bit])
      ++first_bit;

    // Walk the contiguous run of set bits that starts at first_bit.
    last_bit = first_bit;
    while (last_bit < mask.size() && mask[last_bit])
      ++last_bit;
  }

  /// Sorts the `last - first` elements starting at `first` by forwarding to
  /// sycl::detail::privateDynamicSort with the recorded bit range.
  ///
  /// \throws sycl::exception when compiled for the host.
  template <typename GroupT, typename PtrT>
  void operator()([[maybe_unused]] GroupT g, [[maybe_unused]] PtrT first,
                  [[maybe_unused]] PtrT last) {
#ifdef __SYCL_DEVICE_ONLY__
    constexpr bool is_ascending = OrderT == sorting_order::ascending;
    sycl::detail::privateDynamicSort</*is_key_value=*/false, is_ascending,
                                     /*empty*/ 1, BitsPerPass>(
        g, first, /*empty*/ first, last - first, scratch.data(), first_bit,
        last_bit);
#else
    throw sycl::exception(
        std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
        "radix_sorter is not supported on host device.");
#endif
  }

  /// Scratch bytes required to sort `range_size` elements: one ValT per
  /// element, 2^BitsPerPass uint32_t slots per element, plus
  /// `alignof(uint32_t)` extra bytes. The memory scope is not used.
  static constexpr std::size_t
  memory_required([[maybe_unused]] sycl::memory_scope scope,
                  std::size_t range_size) {
    const std::size_t value_bytes = range_size * sizeof(ValT);
    const std::size_t slot_bytes =
        (1 << bits) * range_size * sizeof(uint32_t);
    return value_bytes + slot_bytes + alignof(uint32_t);
  }
};

/// Radix sorter for `sort_over_group`, provided by the second version of the
/// sycl_ext_oneapi_group_sort extension.
///
/// \tparam ValT                value type; must be arithmetic, sycl::half or
///                             sycl::ext::oneapi::bfloat16.
/// \tparam OrderT              ascending or descending order.
/// \tparam ElementsPerWorkItem NOTE(review): not referenced anywhere in this
///                             class as written; operator() always sorts one
///                             value per work-item. Confirm whether
///                             memory_required() should account for it.
/// \tparam BitsPerPass         value forwarded to the sort implementation and
///                             used to size the scratch memory.
template <typename ValT, sorting_order OrderT = sorting_order::ascending,
          size_t ElementsPerWorkItem = 1, unsigned int BitsPerPass = 4>
class group_sorter {

  sycl::span<std::byte> scratch;
  uint32_t first_bit = 0;
  uint32_t last_bit = 0;

  static constexpr uint32_t bits = BitsPerPass;
  using bitset_t = std::bitset<sizeof(ValT) * CHAR_BIT>;

public:
  /// \param scratch_ scratch memory handed to the sort implementation.
  /// \param mask     bit mask; by default every bit is set. The constructor
  ///                 records the index of the lowest set bit and the index
  ///                 one past the end of the run of set bits starting there.
  template <std::size_t Extent>
  group_sorter(sycl::span<std::byte, Extent> scratch_,
               const bitset_t mask = bitset_t{}.set())
      : scratch(scratch_) {
    // Same diagnostic as radix_sorters::joint_sorter so that an unsupported
    // ValT produces one consistent, descriptive message.
    static_assert((std::is_arithmetic<ValT>::value ||
                   std::is_same<ValT, sycl::half>::value ||
                   std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
                  "radix sort is not supported for the given type");

    // Find the first set bit of the mask.
    for (first_bit = 0; first_bit < mask.size() && !mask[first_bit];
         ++first_bit)
      ;
    // Find the end of the contiguous run of set bits starting at first_bit.
    for (last_bit = first_bit; last_bit < mask.size() && mask[last_bit];
         ++last_bit)
      ;
  }

  /// Sorts the single value contributed by each work-item of `g` by
  /// forwarding to sycl::detail::privateStaticSort, and returns the value
  /// left in this work-item's one-element array afterwards.
  ///
  /// \throws sycl::exception when compiled for the host.
  template <typename GroupT>
  ValT operator()([[maybe_unused]] GroupT g, [[maybe_unused]] ValT val) {
#ifdef __SYCL_DEVICE_ONLY__
    // privateStaticSort works on an array; wrap the value in a one-element
    // array and pass it for both pointer arguments.
    ValT result[]{val};
    sycl::detail::privateStaticSort</*is_key_value=*/false,
                                    /*is_blocked=*/true,
                                    OrderT == sorting_order::ascending,
                                    /*items_per_work_item=*/1, bits>(
        g, result, /*empty*/ result, scratch.data(), first_bit, last_bit);
    return result[0];
#else
    throw sycl::exception(
        std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
        "radix_sorter is not supported on host device.");
#endif
  }

  /// Scratch bytes required for a group of `range_size` work-items: the
  /// larger of one ValT per work-item and 2^BitsPerPass uint32_t slots per
  /// work-item. The memory scope is not used.
  static constexpr size_t
  memory_required([[maybe_unused]] sycl::memory_scope scope,
                  size_t range_size) {
    return (std::max)(range_size * sizeof(ValT),
                      range_size * (1 << bits) * sizeof(uint32_t));
  }
};

} // namespace radix_sorters

} // namespace ext::oneapi::experimental
} // namespace _V1
} // namespace sycl
Expand Down
8 changes: 4 additions & 4 deletions sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
T value, Compare comp) {
return sort_over_group(
exec.get_group(), value,
experimental::default_sorter<Compare>(exec.get_memory(), comp));
default_sorters::group_sorter<T, Compare, 1>(exec.get_memory(), comp));
}

template <typename Group, typename T, size_t Extent>
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>, T>
sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
T value) {
return sort_over_group(exec.get_group(), value,
experimental::default_sorter<>(exec.get_memory()));
default_sorters::group_sorter<T>(exec.get_memory()));
}

// ---- joint_sort
Expand All @@ -120,15 +120,15 @@ std::enable_if_t<!detail::is_sorter<Compare, Group, Iter>::value, void>
joint_sort(experimental::group_with_scratchpad<Group, Extent> exec, Iter first,
Iter last, Compare comp) {
joint_sort(exec.get_group(), first, last,
experimental::default_sorter<Compare>(exec.get_memory(), comp));
default_sorters::joint_sorter<Compare>(exec.get_memory(), comp));
}

template <typename Group, typename Iter, size_t Extent>
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>, void>
joint_sort(experimental::group_with_scratchpad<Group, Extent> exec, Iter first,
Iter last) {
joint_sort(exec.get_group(), first, last,
experimental::default_sorter<>(exec.get_memory()));
default_sorters::joint_sorter<>(exec.get_memory()));
}

} // namespace ext::oneapi::experimental
Expand Down
Loading

0 comments on commit 71e6b36

Please sign in to comment.