Skip to content

Commit ec46177

Browse files
authored
Conversion from raw to multi_ptr should be done with address_space_cast (#1538)
1 parent 7cb1fd7 commit ec46177

File tree

4 files changed

+49
-43
lines changed

4 files changed

+49
-43
lines changed

dpnp/backend/kernels/dpnp_krnl_bitwise.cpp

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -68,12 +68,14 @@ DPCTLSyclEventRef dpnp_invert_c(DPCTLSyclQueueRef q_ref,
6868
sg.get_group_id()[0] * max_sg_size);
6969

7070
if (start + static_cast<size_t>(vec_sz) * max_sg_size < size) {
71-
using multi_ptrT =
72-
sycl::multi_ptr<_DataType,
73-
sycl::access::address_space::global_space>;
71+
auto input_multi_ptr = sycl::address_space_cast<
72+
sycl::access::address_space::global_space,
73+
sycl::access::decorated::yes>(&input_data[start]);
74+
auto result_multi_ptr = sycl::address_space_cast<
75+
sycl::access::address_space::global_space,
76+
sycl::access::decorated::yes>(&result[start]);
7477

75-
sycl::vec<_DataType, vec_sz> x =
76-
sg.load<vec_sz>(multi_ptrT(&input_data[start]));
78+
sycl::vec<_DataType, vec_sz> x = sg.load<vec_sz>(input_multi_ptr);
7779
sycl::vec<_DataType, vec_sz> res_vec;
7880

7981
if constexpr (std::is_same_v<_DataType, bool>) {
@@ -86,7 +88,7 @@ DPCTLSyclEventRef dpnp_invert_c(DPCTLSyclQueueRef q_ref,
8688
res_vec = ~x;
8789
}
8890

89-
sg.store<vec_sz>(multi_ptrT(&result[start]), res_vec);
91+
sg.store<vec_sz>(result_multi_ptr, res_vec);
9092
}
9193
else {
9294
for (size_t k = start + sg.get_local_id()[0]; k < size;

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,12 +1344,17 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
13441344
\
13451345
if (start + static_cast<size_t>(vec_sz) * max_sg_size < \
13461346
result_size) { \
1347-
using input1_ptrT = \
1348-
sycl::multi_ptr<_DataType_input1, global_space>; \
1349-
using input2_ptrT = \
1350-
sycl::multi_ptr<_DataType_input2, global_space>; \
1351-
using result_ptrT = \
1352-
sycl::multi_ptr<_DataType_output, global_space>; \
1347+
auto input1_multi_ptr = sycl::address_space_cast< \
1348+
sycl::access::address_space::global_space, \
1349+
sycl::access::decorated::yes>( \
1350+
&input1_data[start]); \
1351+
auto input2_multi_ptr = sycl::address_space_cast< \
1352+
sycl::access::address_space::global_space, \
1353+
sycl::access::decorated::yes>( \
1354+
&input2_data[start]); \
1355+
auto result_multi_ptr = sycl::address_space_cast< \
1356+
sycl::access::address_space::global_space, \
1357+
sycl::access::decorated::yes>(&result[start]); \
13531358
\
13541359
sycl::vec<_DataType_output, vec_sz> res_vec; \
13551360
\
@@ -1363,11 +1368,9 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
13631368
_DataType_output>) \
13641369
{ \
13651370
sycl::vec<_DataType_input1, vec_sz> x1 = \
1366-
sg.load<vec_sz>( \
1367-
input1_ptrT(&input1_data[start])); \
1371+
sg.load<vec_sz>(input1_multi_ptr); \
13681372
sycl::vec<_DataType_input2, vec_sz> x2 = \
1369-
sg.load<vec_sz>( \
1370-
input2_ptrT(&input2_data[start])); \
1373+
sg.load<vec_sz>(input2_multi_ptr); \
13711374
\
13721375
res_vec = __vec_operation__; \
13731376
} \
@@ -1377,33 +1380,28 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
13771380
sycl::vec<_DataType_output, vec_sz> x1 = \
13781381
dpnp_vec_cast<_DataType_output, \
13791382
_DataType_input1, vec_sz>( \
1380-
sg.load<vec_sz>(input1_ptrT( \
1381-
&input1_data[start]))); \
1383+
sg.load<vec_sz>(input1_multi_ptr)); \
13821384
sycl::vec<_DataType_output, vec_sz> x2 = \
13831385
dpnp_vec_cast<_DataType_output, \
13841386
_DataType_input2, vec_sz>( \
1385-
sg.load<vec_sz>(input2_ptrT( \
1386-
&input2_data[start]))); \
1387+
sg.load<vec_sz>(input2_multi_ptr)); \
13871388
\
13881389
res_vec = __vec_operation__; \
13891390
} \
13901391
} \
13911392
else { \
13921393
sycl::vec<_DataType_input1, vec_sz> x1 = \
1393-
sg.load<vec_sz>( \
1394-
input1_ptrT(&input1_data[start])); \
1394+
sg.load<vec_sz>(input1_multi_ptr); \
13951395
sycl::vec<_DataType_input2, vec_sz> x2 = \
1396-
sg.load<vec_sz>( \
1397-
input2_ptrT(&input2_data[start])); \
1396+
sg.load<vec_sz>(input2_multi_ptr); \
13981397
\
13991398
for (size_t k = 0; k < vec_sz; ++k) { \
14001399
const _DataType_output input1_elem = x1[k]; \
14011400
const _DataType_output input2_elem = x2[k]; \
14021401
res_vec[k] = __operation__; \
14031402
} \
14041403
} \
1405-
sg.store<vec_sz>(result_ptrT(&result[start]), \
1406-
res_vec); \
1404+
sg.store<vec_sz>(result_multi_ptr, res_vec); \
14071405
} \
14081406
else { \
14091407
for (size_t k = start + sg.get_local_id()[0]; \

dpnp/backend/kernels/dpnp_krnl_logic.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -537,22 +537,28 @@ DPCTLSyclEventRef (*dpnp_any_ext_c)(DPCTLSyclQueueRef,
537537
\
538538
if (start + static_cast<size_t>(vec_sz) * max_sg_size < \
539539
result_size) { \
540-
sycl::vec<_DataType_input1, vec_sz> x1 = sg.load<vec_sz>( \
541-
sycl::multi_ptr<_DataType_input1, global_space>( \
542-
&input1_data[start])); \
543-
sycl::vec<_DataType_input2, vec_sz> x2 = sg.load<vec_sz>( \
544-
sycl::multi_ptr<_DataType_input2, global_space>( \
545-
&input2_data[start])); \
540+
auto input1_multi_ptr = sycl::address_space_cast< \
541+
sycl::access::address_space::global_space, \
542+
sycl::access::decorated::yes>(&input1_data[start]); \
543+
auto input2_multi_ptr = sycl::address_space_cast< \
544+
sycl::access::address_space::global_space, \
545+
sycl::access::decorated::yes>(&input2_data[start]); \
546+
auto result_multi_ptr = sycl::address_space_cast< \
547+
sycl::access::address_space::global_space, \
548+
sycl::access::decorated::yes>(&result[start]); \
549+
\
550+
sycl::vec<_DataType_input1, vec_sz> x1 = \
551+
sg.load<vec_sz>(input1_multi_ptr); \
552+
sycl::vec<_DataType_input2, vec_sz> x2 = \
553+
sg.load<vec_sz>(input2_multi_ptr); \
546554
sycl::vec<bool, vec_sz> res_vec; \
547555
\
548556
for (size_t k = 0; k < vec_sz; ++k) { \
549557
const _DataType_input1 input1_elem = x1[k]; \
550558
const _DataType_input2 input2_elem = x2[k]; \
551559
res_vec[k] = __operation__; \
552560
} \
553-
sg.store<vec_sz>( \
554-
sycl::multi_ptr<bool, global_space>(&result[start]), \
555-
res_vec); \
561+
sg.store<vec_sz>(result_multi_ptr, res_vec); \
556562
} \
557563
else { \
558564
for (size_t k = start; k < result_size; ++k) { \

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,6 @@ DPCTLSyclEventRef
151151

152152
constexpr size_t lws = 64;
153153
constexpr unsigned int vec_sz = 8;
154-
constexpr sycl::access::address_space global_space =
155-
sycl::access::address_space::global_space;
156154

157155
auto gws_range =
158156
sycl::range<1>(((size + lws * vec_sz - 1) / (lws * vec_sz)) * lws);
@@ -166,18 +164,20 @@ DPCTLSyclEventRef
166164
sg.get_group_id()[0] * max_sg_size);
167165

168166
if (start + static_cast<size_t>(vec_sz) * max_sg_size < size) {
169-
using input_ptrT =
170-
sycl::multi_ptr<_DataType_input, global_space>;
171-
using result_ptrT =
172-
sycl::multi_ptr<_DataType_output, global_space>;
167+
auto array_multi_ptr = sycl::address_space_cast<
168+
sycl::access::address_space::global_space,
169+
sycl::access::decorated::yes>(&array1[start]);
170+
auto result_multi_ptr = sycl::address_space_cast<
171+
sycl::access::address_space::global_space,
172+
sycl::access::decorated::yes>(&result[start]);
173173

174174
sycl::vec<_DataType_input, vec_sz> data_vec =
175-
sg.load<vec_sz>(input_ptrT(&array1[start]));
175+
sg.load<vec_sz>(array_multi_ptr);
176176

177177
sycl::vec<_DataType_output, vec_sz> res_vec =
178178
sycl::abs(data_vec);
179179

180-
sg.store<vec_sz>(result_ptrT(&result[start]), res_vec);
180+
sg.store<vec_sz>(result_multi_ptr, res_vec);
181181
}
182182
else {
183183
for (size_t k = start + sg.get_local_id()[0]; k < size;

0 commit comments

Comments
 (0)