Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 61a138b

Browse files
committed
Reduce work-group size and apply minor style fixes.
1 parent 8a9bead commit 61a138b

File tree

1 file changed

+38
-43
lines changed

1 file changed

+38
-43
lines changed

SYCL/GroupAlgorithm/SYCL2020/sort.cpp

Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,16 @@
22
// RUN: %CPU_RUN_PLACEHOLDER %t.out
33
// RUN: %GPU_RUN_PLACEHOLDER %t.out
44
// RUN: %ACC_RUN_PLACEHOLDER %t.out
5-
//
6-
// RUNx: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -I . -o %t13.out
75

86
#include "support.h"
9-
#include <CL/sycl.hpp>
7+
#include <sycl/sycl.hpp>
108

119
#include <algorithm>
1210
#include <iostream>
1311
#include <random>
1412
#include <vector>
1513

16-
namespace my_sycl = sycl::ext::oneapi::experimental;
14+
using namespace sycl::ext::oneapi::experimental;
1715

1816
auto async_handler_ = [](sycl::exception_list ex_list) {
1917
for (auto &ex : ex_list) {
@@ -38,14 +36,11 @@ struct CustomFunctor {
3836
}
3937
};
4038

41-
// we need it since using std::abs leads to compilation error
42-
template <typename T> T my_abs(T x) { return x >= 0 ? x : -x; }
43-
4439
template <typename T> bool check(T lhs, T rhs, float epsilon) {
45-
return my_abs(lhs - rhs) > epsilon;
40+
return sycl::abs(lhs - rhs) > epsilon;
4641
}
4742
bool check(CustomType lhs, CustomType rhs, float epsilon) {
48-
return my_abs(lhs.x - rhs.x) > epsilon;
43+
return sycl::abs(lhs.x - rhs.x) > epsilon;
4944
}
5045

5146
template <typename T>
@@ -105,7 +100,7 @@ int test_sort_over_group(sycl::queue &q, std::size_t local,
105100

106101
sycl::range<dim> local_range = get_range<dim>(local);
107102

108-
std::size_t local_memory_size = my_sycl::default_sorter<>::memory_required<T>(
103+
std::size_t local_memory_size = default_sorter<>::memory_required<T>(
109104
sycl::memory_scope::work_group, local_range);
110105

111106
if (local_memory_size >
@@ -129,24 +124,24 @@ int test_sort_over_group(sycl::queue &q, std::size_t local,
129124
case 0:
130125
if constexpr (std::is_same_v<Compare, std::less<T>> &&
131126
!std::is_same_v<T, CustomType>)
132-
aI1[local_id] = my_sycl::sort_over_group(
133-
my_sycl::group_with_scratchpad(
127+
aI1[local_id] = sort_over_group(
128+
group_with_scratchpad(
134129
id.get_group(),
135130
sycl::span{&scratch[0], local_memory_size}),
136131
aI1[local_id]);
137132
break;
138133
case 1:
139-
aI1[local_id] = my_sycl::sort_over_group(
140-
my_sycl::group_with_scratchpad(
134+
aI1[local_id] = sort_over_group(
135+
group_with_scratchpad(
141136
id.get_group(),
142137
sycl::span{&scratch[0], local_memory_size}),
143138
aI1[local_id], comp);
144139
break;
145140
case 2:
146-
aI1[local_id] = my_sycl::sort_over_group(
147-
id.get_group(), aI1[local_id],
148-
my_sycl::default_sorter<Compare>(
149-
sycl::span{&scratch[0], local_memory_size}));
141+
aI1[local_id] =
142+
sort_over_group(id.get_group(), aI1[local_id],
143+
default_sorter<Compare>(sycl::span{
144+
&scratch[0], local_memory_size}));
150145
break;
151146
}
152147
});
@@ -160,8 +155,8 @@ int test_joint_sort(sycl::queue &q, std::size_t n_items, std::size_t local,
160155
auto n = bufI1.size();
161156
auto n_groups = (n - 1) / n_items + 1;
162157

163-
std::size_t local_memory_size = my_sycl::default_sorter<>::memory_required<T>(
164-
sycl::memory_scope::work_group, n);
158+
std::size_t local_memory_size =
159+
default_sorter<>::memory_required<T>(sycl::memory_scope::work_group, n);
165160
if (local_memory_size >
166161
q.get_device().template get_info<sycl::info::device::local_mem_size>())
167162
std::cout << "local_memory_size = " << local_memory_size << ", available = "
@@ -187,27 +182,26 @@ int test_joint_sort(sycl::queue &q, std::size_t n_items, std::size_t local,
187182
case 0:
188183
if constexpr (std::is_same_v<Compare, std::less<T>> &&
189184
!std::is_same_v<T, CustomType>)
190-
my_sycl::joint_sort(
191-
my_sycl::group_with_scratchpad(
192-
id.get_group(),
193-
sycl::span{&scratch[0], local_memory_size}),
194-
ptr_keys,
195-
ptr_keys + sycl::min(n_items, n - group_id * n_items));
185+
joint_sort(group_with_scratchpad(
186+
id.get_group(),
187+
sycl::span{&scratch[0], local_memory_size}),
188+
ptr_keys,
189+
ptr_keys +
190+
sycl::min(n_items, n - group_id * n_items));
196191
break;
197192
case 1:
198-
my_sycl::joint_sort(
199-
my_sycl::group_with_scratchpad(
200-
id.get_group(),
201-
sycl::span{&scratch[0], local_memory_size}),
202-
ptr_keys,
203-
ptr_keys + sycl::min(n_items, n - group_id * n_items), comp);
193+
joint_sort(group_with_scratchpad(
194+
id.get_group(),
195+
sycl::span{&scratch[0], local_memory_size}),
196+
ptr_keys,
197+
ptr_keys + sycl::min(n_items, n - group_id * n_items),
198+
comp);
204199
break;
205200
case 2:
206-
my_sycl::joint_sort(
207-
id.get_group(), ptr_keys,
208-
ptr_keys + sycl::min(n_items, n - group_id * n_items),
209-
my_sycl::default_sorter<Compare>(
210-
sycl::span{&scratch[0], local_memory_size}));
201+
joint_sort(id.get_group(), ptr_keys,
202+
ptr_keys + sycl::min(n_items, n - group_id * n_items),
203+
default_sorter<Compare>(
204+
sycl::span{&scratch[0], local_memory_size}));
211205
break;
212206
}
213207
});
@@ -217,7 +211,7 @@ int test_joint_sort(sycl::queue &q, std::size_t n_items, std::size_t local,
217211

218212
template <typename T, typename Compare>
219213
int test_custom_sorter(sycl::queue &q, sycl::buffer<T> &bufI1, Compare comp) {
220-
std::size_t local = 256;
214+
std::size_t local = 4;
221215
auto n = bufI1.size();
222216
if (n > local)
223217
return -1;
@@ -230,9 +224,8 @@ int test_custom_sorter(sycl::queue &q, sycl::buffer<T> &bufI1, Compare comp) {
230224
sycl::nd_range<2>({local, 1}, {local, 1}), [=](sycl::nd_item<2> id) {
231225
auto ptr = aI1.get_pointer();
232226

233-
my_sycl::joint_sort(
234-
id.get_group(), ptr, ptr + n,
235-
bubble_sorter<Compare>{comp, id.get_local_linear_id()});
227+
joint_sort(id.get_group(), ptr, ptr + n,
228+
bubble_sorter<Compare>{comp, id.get_local_linear_id()});
236229
});
237230
}).wait_and_throw();
238231
return 1;
@@ -243,9 +236,11 @@ void run_sort(sycl::queue &q, std::vector<T> &in, std::size_t size,
243236
Compare comp, int test_case, int sort_case) {
244237
std::vector<T> in2(in.begin(), in.begin() + size);
245238
std::vector<T> expected(in.begin(), in.begin() + size);
246-
std::size_t local =
239+
constexpr size_t work_size_limit = 4;
240+
std::size_t local = std::min(
241+
work_size_limit,
247242
q.get_device()
248-
.template get_info<sycl::info::device::max_work_group_size>();
243+
.template get_info<sycl::info::device::max_work_group_size>());
249244
local = std::min(local, size);
250245
auto n_items = items_per_work_item * local;
251246

0 commit comments

Comments
 (0)