@@ -52,15 +52,15 @@ template <typename OrthogIndexerT,
5252 typename LocalAccessorT>
5353struct MaskedExtractStridedFunctor
5454{
55- MaskedExtractStridedFunctor (const char *src_data_p,
56- const char *cumsum_data_p,
57- char *dst_data_p,
55+ MaskedExtractStridedFunctor (const dataT *src_data_p,
56+ const indT *cumsum_data_p,
57+ dataT *dst_data_p,
5858 size_t masked_iter_size,
5959 const OrthogIndexerT &orthog_src_dst_indexer_,
6060 const MaskedSrcIndexerT &masked_src_indexer_,
6161 const MaskedDstIndexerT &masked_dst_indexer_,
6262 const LocalAccessorT &lacc_)
63- : src_cp (src_data_p), cumsum_cp (cumsum_data_p), dst_cp (dst_data_p),
63+ : src (src_data_p), cumsum (cumsum_data_p), dst (dst_data_p),
6464 masked_nelems (masked_iter_size),
6565 orthog_src_dst_indexer(orthog_src_dst_indexer_),
6666 masked_src_indexer(masked_src_indexer_),
@@ -72,24 +72,20 @@ struct MaskedExtractStridedFunctor
7272
7373 void operator ()(sycl::nd_item<2 > ndit) const
7474 {
75- const dataT *src_data = reinterpret_cast <const dataT *>(src_cp);
76- dataT *dst_data = reinterpret_cast <dataT *>(dst_cp);
77- const indT *cumsum_data = reinterpret_cast <const indT *>(cumsum_cp);
78-
79- const size_t orthog_i = ndit.get_global_id (0 );
80- const size_t group_i = ndit.get_group (1 );
75+ const std::size_t orthog_i = ndit.get_global_id (0 );
8176 const std::uint32_t l_i = ndit.get_local_id (1 );
8277 const std::uint32_t lws = ndit.get_local_range (1 );
8378
84- const size_t masked_block_start = group_i * lws ;
85- const size_t masked_i = masked_block_start + l_i;
79+ const std:: size_t masked_i = ndit. get_global_id ( 1 ) ;
80+ const std:: size_t masked_block_start = masked_i - l_i;
8681
82+ const std::size_t max_offset = masked_nelems + 1 ;
8783 for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
8884 const size_t offset = masked_block_start + i;
8985 lacc[i] = (offset == 0 ) ? indT (0 )
90- : (offset - 1 < masked_nelems )
91- ? cumsum_data [offset - 1 ]
92- : cumsum_data [masked_nelems - 1 ] + 1 ;
86+ : (offset < max_offset )
87+ ? cumsum [offset - 1 ]
88+ : cumsum [masked_nelems - 1 ] + 1 ;
9389 }
9490
9591 sycl::group_barrier (ndit.get_group ());
@@ -110,14 +106,14 @@ struct MaskedExtractStridedFunctor
110106 masked_dst_indexer (current_running_count - 1 ) +
111107 orthog_offsets.get_second_offset ();
112108
113- dst_data [total_dst_offset] = src_data [total_src_offset];
109+ dst [total_dst_offset] = src [total_src_offset];
114110 }
115111 }
116112
117113private:
118- const char *src_cp = nullptr ;
119- const char *cumsum_cp = nullptr ;
120- char *dst_cp = nullptr ;
114+ const dataT *src = nullptr ;
115+ const indT *cumsum = nullptr ;
116+ dataT *dst = nullptr ;
121117 const size_t masked_nelems = 0 ;
122118 // has nd, shape, src_strides, dst_strides for
123119 // dimensions that ARE NOT masked
@@ -138,15 +134,15 @@ template <typename OrthogIndexerT,
138134 typename LocalAccessorT>
139135struct MaskedPlaceStridedFunctor
140136{
141- MaskedPlaceStridedFunctor (char *dst_data_p,
142- const char *cumsum_data_p,
143- const char *rhs_data_p,
137+ MaskedPlaceStridedFunctor (dataT *dst_data_p,
138+ const indT *cumsum_data_p,
139+ const dataT *rhs_data_p,
144140 size_t masked_iter_size,
145141 const OrthogIndexerT &orthog_dst_rhs_indexer_,
146142 const MaskedDstIndexerT &masked_dst_indexer_,
147143 const MaskedRhsIndexerT &masked_rhs_indexer_,
148144 const LocalAccessorT &lacc_)
149- : dst_cp (dst_data_p), cumsum_cp (cumsum_data_p), rhs_cp (rhs_data_p),
145+ : dst (dst_data_p), cumsum (cumsum_data_p), rhs (rhs_data_p),
150146 masked_nelems (masked_iter_size),
151147 orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
152148 masked_dst_indexer(masked_dst_indexer_),
@@ -158,24 +154,20 @@ struct MaskedPlaceStridedFunctor
158154
159155 void operator ()(sycl::nd_item<2 > ndit) const
160156 {
161- dataT *dst_data = reinterpret_cast <dataT *>(dst_cp);
162- const indT *cumsum_data = reinterpret_cast <const indT *>(cumsum_cp);
163- const dataT *rhs_data = reinterpret_cast <const dataT *>(rhs_cp);
164-
165157 const std::size_t orthog_i = ndit.get_global_id (0 );
166- const std::size_t group_i = ndit.get_group (1 );
167158 const std::uint32_t l_i = ndit.get_local_id (1 );
168159 const std::uint32_t lws = ndit.get_local_range (1 );
169160
170- const size_t masked_block_start = group_i * lws ;
171- const size_t masked_i = masked_block_start + l_i;
161+ const size_t masked_i = ndit. get_global_id ( 1 ) ;
162+ const size_t masked_block_start = masked_i - l_i;
172163
164+ const std::size_t max_offset = masked_nelems + 1 ;
173165 for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
174166 const size_t offset = masked_block_start + i;
175167 lacc[i] = (offset == 0 ) ? indT (0 )
176- : (offset - 1 < masked_nelems )
177- ? cumsum_data [offset - 1 ]
178- : cumsum_data [masked_nelems - 1 ] + 1 ;
168+ : (offset < max_offset )
169+ ? cumsum [offset - 1 ]
170+ : cumsum [masked_nelems - 1 ] + 1 ;
179171 }
180172
181173 sycl::group_barrier (ndit.get_group ());
@@ -196,14 +188,14 @@ struct MaskedPlaceStridedFunctor
196188 masked_rhs_indexer (current_running_count - 1 ) +
197189 orthog_offsets.get_second_offset ();
198190
199- dst_data [total_dst_offset] = rhs_data [total_rhs_offset];
191+ dst [total_dst_offset] = rhs [total_rhs_offset];
200192 }
201193 }
202194
203195private:
204- char *dst_cp = nullptr ;
205- const char *cumsum_cp = nullptr ;
206- const char *rhs_cp = nullptr ;
196+ dataT *dst = nullptr ;
197+ const indT *cumsum = nullptr ;
198+ const dataT *rhs = nullptr ;
207199 const size_t masked_nelems = 0 ;
208200 // has nd, shape, dst_strides, rhs_strides for
209201 // dimensions that ARE NOT masked
@@ -218,6 +210,26 @@ struct MaskedPlaceStridedFunctor
218210
219211// ======= Masked extraction ================================
220212
213+ namespace {
214+
215+ template <std::size_t I, std::size_t ... IR>
216+ std::size_t _get_lws_impl (std::size_t n) {
217+ if constexpr (sizeof ...(IR) == 0 ) {
218+ return I;
219+ } else {
220+ return (n < I) ? _get_lws_impl<IR...>(n) : I;
221+ }
222+ }
223+
224+ std::size_t get_lws (std::size_t n) {
225+ constexpr std::size_t lws0 = 256u ;
226+ constexpr std::size_t lws1 = 128u ;
227+ constexpr std::size_t lws2 = 64u ;
228+ return _get_lws_impl<lws0, lws1, lws2>(n);
229+ }
230+
231+ } // end of anonymous namespace
232+
221233template <typename MaskedDstIndexerT, typename dataT, typename indT>
222234class masked_extract_all_slices_contig_impl_krn ;
223235
@@ -258,16 +270,21 @@ sycl::event masked_extract_all_slices_contig_impl(
258270 Strided1DIndexer, dataT, indT,
259271 LocalAccessorT>;
260272
261- constexpr std::size_t nominal_lws = 256 ;
262273 const std::size_t masked_extent = iteration_size;
263- const std::size_t lws = std::min (masked_extent, nominal_lws);
274+
275+ const std::size_t lws = get_lws (masked_extent);
276+
264277 const std::size_t n_groups = (iteration_size + lws - 1 ) / lws;
265278
266279 sycl::range<2 > gRange {1 , n_groups * lws};
267280 sycl::range<2 > lRange{1 , lws};
268281
269282 sycl::nd_range<2 > ndRange (gRange , lRange);
270283
284+ const dataT *src_tp = reinterpret_cast <const dataT *>(src_p);
285+ const indT *cumsum_tp = reinterpret_cast <const indT *>(cumsum_p);
286+ dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
287+
271288 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
272289 cgh.depends_on (depends);
273290
@@ -276,7 +293,7 @@ sycl::event masked_extract_all_slices_contig_impl(
276293
277294 cgh.parallel_for <KernelName>(
278295 ndRange,
279- Impl (src_p, cumsum_p, dst_p , masked_extent, orthog_src_dst_indexer,
296+ Impl (src_tp, cumsum_tp, dst_tp , masked_extent, orthog_src_dst_indexer,
280297 masked_src_indexer, masked_dst_indexer, lacc));
281298 });
282299
@@ -332,16 +349,21 @@ sycl::event masked_extract_all_slices_strided_impl(
332349 StridedIndexer, Strided1DIndexer,
333350 dataT, indT, LocalAccessorT>;
334351
335- constexpr std::size_t nominal_lws = 256 ;
336352 const std::size_t masked_nelems = iteration_size;
337- const std::size_t lws = std::min (masked_nelems, nominal_lws);
353+
354+ const std::size_t lws = get_lws (masked_nelems);
355+
338356 const std::size_t n_groups = (masked_nelems + lws - 1 ) / lws;
339357
340358 sycl::range<2 > gRange {1 , n_groups * lws};
341359 sycl::range<2 > lRange{1 , lws};
342360
343361 sycl::nd_range<2 > ndRange (gRange , lRange);
344362
363+ const dataT *src_tp = reinterpret_cast <const dataT *>(src_p);
364+ const indT *cumsum_tp = reinterpret_cast <const indT *>(cumsum_p);
365+ dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
366+
345367 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
346368 cgh.depends_on (depends);
347369
@@ -350,7 +372,7 @@ sycl::event masked_extract_all_slices_strided_impl(
350372
351373 cgh.parallel_for <KernelName>(
352374 ndRange,
353- Impl (src_p, cumsum_p, dst_p , iteration_size, orthog_src_dst_indexer,
375+ Impl (src_tp, cumsum_tp, dst_tp , iteration_size, orthog_src_dst_indexer,
354376 masked_src_indexer, masked_dst_indexer, lacc));
355377 });
356378
@@ -422,9 +444,10 @@ sycl::event masked_extract_some_slices_strided_impl(
422444 StridedIndexer, Strided1DIndexer,
423445 dataT, indT, LocalAccessorT>;
424446
425- const size_t nominal_lws = 256 ;
426447 const std::size_t masked_extent = masked_nelems;
427- const size_t lws = std::min (masked_extent, nominal_lws);
448+
449+ const std::size_t lws = get_lws (masked_extent);
450+
428451 const size_t n_groups = ((masked_extent + lws - 1 ) / lws);
429452 const size_t orthog_extent = static_cast <size_t >(orthog_nelems);
430453
@@ -433,6 +456,10 @@ sycl::event masked_extract_some_slices_strided_impl(
433456
434457 sycl::nd_range<2 > ndRange (gRange , lRange);
435458
459+ const dataT *src_tp = reinterpret_cast <const dataT *>(src_p);
460+ const indT *cumsum_tp = reinterpret_cast <const indT *>(cumsum_p);
461+ dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
462+
436463 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
437464 cgh.depends_on (depends);
438465
@@ -442,7 +469,7 @@ sycl::event masked_extract_some_slices_strided_impl(
442469
443470 cgh.parallel_for <KernelName>(
444471 ndRange,
445- Impl (src_p, cumsum_p, dst_p , masked_nelems, orthog_src_dst_indexer,
472+ Impl (src_tp, cumsum_tp, dst_tp , masked_nelems, orthog_src_dst_indexer,
446473 masked_src_indexer, masked_dst_indexer, lacc));
447474 });
448475
@@ -567,6 +594,10 @@ sycl::event masked_place_all_slices_strided_impl(
567594
568595 using LocalAccessorT = sycl::local_accessor<indT, 1 >;
569596
597+ dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
598+ const dataT *rhs_tp = reinterpret_cast <const dataT *>(rhs_p);
599+ const indT *cumsum_tp = reinterpret_cast <const indT *>(cumsum_p);
600+
570601 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
571602 cgh.depends_on (depends);
572603
@@ -578,7 +609,7 @@ sycl::event masked_place_all_slices_strided_impl(
578609 MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
579610 Strided1DCyclicIndexer, dataT, indT,
580611 LocalAccessorT>(
581- dst_p, cumsum_p, rhs_p , iteration_size, orthog_dst_rhs_indexer,
612+ dst_tp, cumsum_tp, rhs_tp , iteration_size, orthog_dst_rhs_indexer,
582613 masked_dst_indexer, masked_rhs_indexer, lacc));
583614 });
584615
@@ -659,6 +690,10 @@ sycl::event masked_place_some_slices_strided_impl(
659690
660691 using LocalAccessorT = sycl::local_accessor<indT, 1 >;
661692
693+ dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
694+ const dataT *rhs_tp = reinterpret_cast <const dataT *>(rhs_p);
695+ const indT* cumsum_tp = reinterpret_cast <const indT *>(cumsum_p);
696+
662697 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
663698 cgh.depends_on (depends);
664699
@@ -670,7 +705,7 @@ sycl::event masked_place_some_slices_strided_impl(
670705 MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
671706 Strided1DCyclicIndexer, dataT, indT,
672707 LocalAccessorT>(
673- dst_p, cumsum_p, rhs_p , masked_nelems, orthog_dst_rhs_indexer,
708+ dst_tp, cumsum_tp, rhs_tp , masked_nelems, orthog_dst_rhs_indexer,
674709 masked_dst_indexer, masked_rhs_indexer, lacc));
675710 });
676711
0 commit comments