Skip to content

Commit 7b82b2b

Browse files
committed
Optimization for in-place arithmetic operators between rows and matrices
- In-place operations from C-contiguous rows into C-contiguous matrices show significant performance improvements
1 parent 81553f8 commit 7b82b2b

File tree

6 files changed

+386
-10
lines changed

6 files changed

+386
-10
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,61 @@ struct AddInplaceStridedFactory
480480
}
481481
};
482482

483+
template <typename argT, typename resT>
484+
class add_inplace_row_matrix_broadcast_sg_krn;
485+
486+
template <typename argT, typename resT>
487+
using AddInplaceRowMatrixBroadcastingFunctor =
488+
elementwise_common::BinaryInplaceRowMatrixBroadcastingFunctor<
489+
argT,
490+
resT,
491+
AddInplaceFunctor<argT, resT>>;
492+
493+
template <typename argT, typename resT>
494+
sycl::event add_inplace_row_matrix_broadcast_impl(
495+
sycl::queue exec_q,
496+
std::vector<sycl::event> &host_tasks,
497+
size_t n0,
498+
size_t n1,
499+
const char *vec_p, // typeless pointer to (n1,) contiguous row
500+
py::ssize_t vec_offset,
501+
char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix
502+
py::ssize_t mat_offset,
503+
const std::vector<sycl::event> &depends = {})
504+
{
505+
return elementwise_common::binary_inplace_row_matrix_broadcast_impl<
506+
argT, resT, AddInplaceRowMatrixBroadcastingFunctor,
507+
add_inplace_row_matrix_broadcast_sg_krn>(exec_q, host_tasks, n0, n1,
508+
vec_p, vec_offset, mat_p,
509+
mat_offset, depends);
510+
}
511+
512+
template <typename fnT, typename T1, typename T2>
513+
struct AddInplaceRowMatrixBroadcastFactory
514+
{
515+
fnT get()
516+
{
517+
using resT = typename AddOutputType<T1, T2>::value_type;
518+
if constexpr (std::is_same_v<resT, void>) {
519+
fnT fn = nullptr;
520+
return fn;
521+
}
522+
else {
523+
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
524+
dpctl::tensor::type_utils::is_complex<T2>::value ||
525+
dpctl::tensor::type_utils::is_complex<resT>::value)
526+
{
527+
fnT fn = nullptr;
528+
return fn;
529+
}
530+
else {
531+
fnT fn = add_inplace_row_matrix_broadcast_impl<T1, T2>;
532+
return fn;
533+
}
534+
}
535+
}
536+
};
537+
483538
} // namespace add
484539
} // namespace kernels
485540
} // namespace tensor

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,60 @@ struct BinaryInplaceStridedFunctor
191191
}
192192
};
193193

194+
template <typename argT, typename resT, typename BinaryOperatorT>
195+
struct BinaryInplaceRowMatrixBroadcastingFunctor
196+
{
197+
private:
198+
const argT *padded_vec;
199+
resT *mat;
200+
size_t n_elems;
201+
size_t n1;
202+
203+
public:
204+
BinaryInplaceRowMatrixBroadcastingFunctor(const argT *row_tp,
205+
resT *mat_tp,
206+
size_t n_elems_in_mat,
207+
size_t n_elems_in_row)
208+
: padded_vec(row_tp), mat(mat_tp), n_elems(n_elems_in_mat),
209+
n1(n_elems_in_row)
210+
{
211+
}
212+
213+
void operator()(sycl::nd_item<1> ndit) const
214+
{
215+
BinaryOperatorT op{};
216+
static_assert(BinaryOperatorT::supports_sg_loadstore::value);
217+
218+
auto sg = ndit.get_sub_group();
219+
size_t gid = ndit.get_global_linear_id();
220+
221+
std::uint8_t sgSize = sg.get_local_range()[0];
222+
size_t base = gid - sg.get_local_id()[0];
223+
224+
if (base + sgSize < n_elems) {
225+
using in_ptrT =
226+
sycl::multi_ptr<const argT,
227+
sycl::access::address_space::global_space>;
228+
using res_ptrT =
229+
sycl::multi_ptr<resT,
230+
sycl::access::address_space::global_space>;
231+
232+
const argT vec_el = sg.load(in_ptrT(&padded_vec[base % n1]));
233+
resT mat_el = sg.load(res_ptrT(&mat[base]));
234+
235+
op(mat_el, vec_el);
236+
237+
sg.store(res_ptrT(&mat[base]), mat_el);
238+
}
239+
else {
240+
for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
241+
k += sgSize) {
242+
op(mat[k], padded_vec[k % n1]);
243+
}
244+
}
245+
}
246+
};
247+
194248
// Typedefs for function pointers
195249

196250
typedef sycl::event (*binary_inplace_contig_impl_fn_ptr_t)(
@@ -214,6 +268,17 @@ typedef sycl::event (*binary_inplace_strided_impl_fn_ptr_t)(
214268
const std::vector<sycl::event> &,
215269
const std::vector<sycl::event> &);
216270

271+
typedef sycl::event (*binary_inplace_row_matrix_broadcast_impl_fn_ptr_t)(
272+
sycl::queue,
273+
std::vector<sycl::event> &,
274+
size_t,
275+
size_t,
276+
const char *,
277+
py::ssize_t,
278+
char *,
279+
py::ssize_t,
280+
const std::vector<sycl::event> &);
281+
217282
template <typename argTy,
218283
typename resTy,
219284
template <typename T1, typename T2, unsigned int vs, unsigned int nv>
@@ -289,6 +354,79 @@ binary_inplace_strided_impl(sycl::queue exec_q,
289354
return comp_ev;
290355
}
291356

357+
template <typename argT,
358+
typename resT,
359+
template <typename T1, typename T3>
360+
class BinaryInplaceRowMatrixBroadcastFunctorT,
361+
template <typename T1, typename T3>
362+
class kernel_name>
363+
sycl::event binary_inplace_row_matrix_broadcast_impl(
364+
sycl::queue exec_q,
365+
std::vector<sycl::event> &host_tasks,
366+
size_t n0,
367+
size_t n1,
368+
const char *vec_p, // typeless pointer to (n1,) contiguous row
369+
py::ssize_t vec_offset,
370+
char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix
371+
py::ssize_t mat_offset,
372+
const std::vector<sycl::event> &depends = {})
373+
{
374+
const argT *vec = reinterpret_cast<const argT *>(vec_p) + vec_offset;
375+
resT *mat = reinterpret_cast<resT *>(mat_p) + mat_offset;
376+
377+
const auto &dev = exec_q.get_device();
378+
const auto &sg_sizes = dev.get_info<sycl::info::device::sub_group_sizes>();
379+
// Get device-specific kernel info max_sub_group_size
380+
size_t max_sgSize =
381+
*(std::max_element(std::begin(sg_sizes), std::end(sg_sizes)));
382+
383+
size_t n1_padded = n1 + max_sgSize;
384+
argT *padded_vec = sycl::malloc_device<argT>(n1_padded, exec_q);
385+
386+
if (padded_vec == nullptr) {
387+
throw std::runtime_error("Could not allocate memory on the device");
388+
}
389+
sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
390+
cgh.depends_on(depends); // ensure vec contains actual data
391+
cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
392+
auto i = id[0];
393+
padded_vec[i] = vec[i % n1];
394+
});
395+
});
396+
397+
// sub-group spans work-items [I, I + sgSize)
398+
// base = ndit.get_global_linear_id() - sg.get_local_id()[0]
399+
// Generically, sg.load( &mat[base]) may load arrays from
400+
// different rows of mat. The start corresponds to row (base / n0)
401+
// We read sg.load(&padded_vec[(base / n0)]). The vector is padded to
402+
// ensure that reads are accessible
403+
404+
size_t lws = 64;
405+
406+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
407+
cgh.depends_on(make_padded_vec_ev);
408+
409+
auto lwsRange = sycl::range<1>(lws);
410+
size_t n_elems = n0 * n1;
411+
size_t n_groups = (n_elems + lws - 1) / lws;
412+
auto gwsRange = sycl::range<1>(n_groups * lws);
413+
414+
cgh.parallel_for<class kernel_name<argT, resT>>(
415+
sycl::nd_range<1>(gwsRange, lwsRange),
416+
BinaryInplaceRowMatrixBroadcastFunctorT<argT, resT>(padded_vec, mat,
417+
n_elems, n1));
418+
});
419+
420+
sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
421+
cgh.depends_on(comp_ev);
422+
sycl::context ctx = exec_q.get_context();
423+
cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); });
424+
});
425+
host_tasks.push_back(tmp_cleanup_ev);
426+
427+
return comp_ev;
428+
}
429+
292430
} // namespace elementwise_common
293431
} // namespace kernels
294432
} // namespace tensor

dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,61 @@ struct MultiplyInplaceStridedFactory
496496
}
497497
};
498498

499+
template <typename argT, typename resT>
500+
class multiply_inplace_row_matrix_broadcast_sg_krn;
501+
502+
template <typename argT, typename resT>
503+
using MultiplyInplaceRowMatrixBroadcastingFunctor =
504+
elementwise_common::BinaryInplaceRowMatrixBroadcastingFunctor<
505+
argT,
506+
resT,
507+
MultiplyInplaceFunctor<argT, resT>>;
508+
509+
template <typename argT, typename resT>
510+
sycl::event multiply_inplace_row_matrix_broadcast_impl(
511+
sycl::queue exec_q,
512+
std::vector<sycl::event> &host_tasks,
513+
size_t n0,
514+
size_t n1,
515+
const char *vec_p, // typeless pointer to (n1,) contiguous row
516+
py::ssize_t vec_offset,
517+
char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix
518+
py::ssize_t mat_offset,
519+
const std::vector<sycl::event> &depends = {})
520+
{
521+
return elementwise_common::binary_inplace_row_matrix_broadcast_impl<
522+
argT, resT, MultiplyInplaceRowMatrixBroadcastingFunctor,
523+
multiply_inplace_row_matrix_broadcast_sg_krn>(
524+
exec_q, host_tasks, n0, n1, vec_p, vec_offset, mat_p, mat_offset,
525+
depends);
526+
}
527+
528+
template <typename fnT, typename T1, typename T2>
529+
struct MultiplyInplaceRowMatrixBroadcastFactory
530+
{
531+
fnT get()
532+
{
533+
using resT = typename MultiplyOutputType<T1, T2>::value_type;
534+
if constexpr (std::is_same_v<resT, void>) {
535+
fnT fn = nullptr;
536+
return fn;
537+
}
538+
else {
539+
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
540+
dpctl::tensor::type_utils::is_complex<T2>::value ||
541+
dpctl::tensor::type_utils::is_complex<resT>::value)
542+
{
543+
fnT fn = nullptr;
544+
return fn;
545+
}
546+
else {
547+
fnT fn = multiply_inplace_row_matrix_broadcast_impl<T1, T2>;
548+
return fn;
549+
}
550+
}
551+
}
552+
};
553+
499554
} // namespace multiply
500555
} // namespace kernels
501556
} // namespace tensor

dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,61 @@ struct SubtractInplaceStridedFactory
509509
}
510510
};
511511

512+
template <typename argT, typename resT>
513+
class subtract_inplace_row_matrix_broadcast_sg_krn;
514+
515+
template <typename argT, typename resT>
516+
using SubtractInplaceRowMatrixBroadcastingFunctor =
517+
elementwise_common::BinaryInplaceRowMatrixBroadcastingFunctor<
518+
argT,
519+
resT,
520+
SubtractInplaceFunctor<argT, resT>>;
521+
522+
template <typename argT, typename resT>
523+
sycl::event subtract_inplace_row_matrix_broadcast_impl(
524+
sycl::queue exec_q,
525+
std::vector<sycl::event> &host_tasks,
526+
size_t n0,
527+
size_t n1,
528+
const char *vec_p, // typeless pointer to (n1,) contiguous row
529+
py::ssize_t vec_offset,
530+
char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix
531+
py::ssize_t mat_offset,
532+
const std::vector<sycl::event> &depends = {})
533+
{
534+
return elementwise_common::binary_inplace_row_matrix_broadcast_impl<
535+
argT, resT, SubtractInplaceRowMatrixBroadcastingFunctor,
536+
subtract_inplace_row_matrix_broadcast_sg_krn>(
537+
exec_q, host_tasks, n0, n1, vec_p, vec_offset, mat_p, mat_offset,
538+
depends);
539+
}
540+
541+
template <typename fnT, typename T1, typename T2>
542+
struct SubtractInplaceRowMatrixBroadcastFactory
543+
{
544+
fnT get()
545+
{
546+
using resT = typename SubtractOutputType<T1, T2>::value_type;
547+
if constexpr (std::is_same_v<resT, void>) {
548+
fnT fn = nullptr;
549+
return fn;
550+
}
551+
else {
552+
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
553+
dpctl::tensor::type_utils::is_complex<T2>::value ||
554+
dpctl::tensor::type_utils::is_complex<resT>::value)
555+
{
556+
fnT fn = nullptr;
557+
return fn;
558+
}
559+
else {
560+
fnT fn = subtract_inplace_row_matrix_broadcast_impl<T1, T2>;
561+
return fn;
562+
}
563+
}
564+
}
565+
};
566+
512567
} // namespace subtract
513568
} // namespace kernels
514569
} // namespace tensor

0 commit comments

Comments
 (0)