Skip to content

Commit 8077aac

Browse files
Not all binary functions are symmetric; modified py_binary_ufunc accordingly
Introduced contig_row_contig_matrix_broadcasting_impl_fn_ptr_t and the corresponding dispatch table. Implemented it for Add by falling back on contig_matrix_contig_row_broadcasting_fn. It would be good to have a specialization for symmetric variants.
1 parent 300cc47 commit 8077aac

File tree

3 files changed

+103
-17
lines changed

3 files changed

+103
-17
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,67 @@ struct AddContigMatrixContigRowBroadcastFactory
436436
}
437437
};
438438

439+
// Pointer-to-function type for implementations of add(row, matrix), i.e. the
// variant where the first operand is a contiguous row and the second operand
// is a C-contiguous matrix. The slot comments below follow the parameter
// names of add_contig_row_contig_matrix_broadcast_impl.
using add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t =
    sycl::event (*)(sycl::queue,                       // exec_q
                    std::vector<sycl::event> &,        // host_tasks
                    size_t,                            // n0
                    size_t,                            // n1
                    const char *,                      // vec_p
                    py::ssize_t,                       // vec_offset
                    const char *,                      // mat_p
                    py::ssize_t,                       // mat_offset
                    char *,                            // res_p
                    py::ssize_t,                       // res_offset
                    const std::vector<sycl::event> &); // depends
451+
452+
template <typename argT1, typename argT2, typename resT>
453+
sycl::event add_contig_row_contig_matrix_broadcast_impl(
454+
sycl::queue exec_q,
455+
std::vector<sycl::event> &host_tasks,
456+
size_t n0,
457+
size_t n1,
458+
const char *vec_p, // typeless pointer to (n1,) contiguous row
459+
py::ssize_t vec_offset,
460+
const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix
461+
py::ssize_t mat_offset,
462+
char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix,
463+
// res[i,j] = mat[i,j] + vec[j]
464+
py::ssize_t res_offset,
465+
const std::vector<sycl::event> &depends = {})
466+
{
467+
return add_contig_matrix_contig_row_broadcast_impl<argT2, argT1, resT>(
468+
exec_q, host_tasks, n0, n1, mat_p, mat_offset, vec_p, vec_offset, res_p,
469+
res_offset, depends);
470+
};
471+
472+
template <typename fnT, typename T1, typename T2>
473+
struct AddContigRowContigMatrixBroadcastFactory
474+
{
475+
fnT get()
476+
{
477+
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
478+
void>) {
479+
fnT fn = nullptr;
480+
return fn;
481+
}
482+
else {
483+
using resT = typename AddOutputType<T1, T2>::value_type;
484+
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
485+
dpctl::tensor::type_utils::is_complex<T2>::value ||
486+
dpctl::tensor::type_utils::is_complex<resT>::value)
487+
{
488+
fnT fn = nullptr;
489+
return fn;
490+
}
491+
else {
492+
fnT fn =
493+
add_contig_row_contig_matrix_broadcast_impl<T1, T2, resT>;
494+
return fn;
495+
}
496+
}
497+
}
498+
};
499+
439500
} // namespace add
440501
} // namespace kernels
441502
} // namespace tensor

dpctl/tensor/libtensor/source/elementwise_functions.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ namespace fn_ns = dpctl::tensor::kernels::add;
283283

284284
using fn_ns::add_contig_impl_fn_ptr_t;
285285
using fn_ns::add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t;
286+
using fn_ns::add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t;
286287
using fn_ns::add_strided_impl_fn_ptr_t;
287288

288289
static add_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types]
@@ -292,35 +293,51 @@ static int add_output_id_table[td_ns::num_types][td_ns::num_types];
292293
static add_strided_impl_fn_ptr_t add_strided_dispatch_table[td_ns::num_types]
293294
[td_ns::num_types];
294295

296+
// add(matrix, row)
295297
static add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t
296298
add_contig_matrix_contig_row_broadcast_dispatch_table[td_ns::num_types]
297299
[td_ns::num_types];
298300

301+
// add(row, matrix)
302+
static add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t
303+
add_contig_row_contig_matrix_broadcast_dispatch_table[td_ns::num_types]
304+
[td_ns::num_types];
305+
299306
void populate_add_dispatch_tables(void)
300307
{
301308
using namespace td_ns;
302309

303-
using fn_ns::AddContigFactory;
304-
DispatchTableBuilder<add_contig_impl_fn_ptr_t, AddContigFactory, num_types>
305-
dtb1;
306-
dtb1.populate_dispatch_table(add_contig_dispatch_table);
310+
// which input types are supported, and what is the type of the result
311+
using fn_ns::AddTypeMapFactory;
312+
DispatchTableBuilder<int, AddTypeMapFactory, num_types> dtb1;
313+
dtb1.populate_dispatch_table(add_output_id_table);
307314

315+
// function pointers for operation on general strided arrays
308316
using fn_ns::AddStridedFactory;
309317
DispatchTableBuilder<add_strided_impl_fn_ptr_t, AddStridedFactory,
310318
num_types>
311319
dtb2;
312320
dtb2.populate_dispatch_table(add_strided_dispatch_table);
313321

314-
using fn_ns::AddTypeMapFactory;
315-
DispatchTableBuilder<int, AddTypeMapFactory, num_types> dtb3;
316-
dtb3.populate_dispatch_table(add_output_id_table);
322+
// function pointers for operation on contiguous inputs and outputs
323+
using fn_ns::AddContigFactory;
324+
DispatchTableBuilder<add_contig_impl_fn_ptr_t, AddContigFactory, num_types>
325+
dtb3;
326+
dtb3.populate_dispatch_table(add_contig_dispatch_table);
317327

318328
using fn_ns::AddContigMatrixContigRowBroadcastFactory;
319329
DispatchTableBuilder<add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t,
320330
AddContigMatrixContigRowBroadcastFactory, num_types>
321331
dtb4;
322332
dtb4.populate_dispatch_table(
323333
add_contig_matrix_contig_row_broadcast_dispatch_table);
334+
335+
using fn_ns::AddContigRowContigMatrixBroadcastFactory;
336+
DispatchTableBuilder<add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t,
337+
AddContigRowContigMatrixBroadcastFactory, num_types>
338+
dtb5;
339+
dtb5.populate_dispatch_table(
340+
add_contig_row_contig_matrix_broadcast_dispatch_table);
324341
};
325342

326343
} // namespace impl
@@ -365,6 +382,7 @@ void init_elementwise_functions(py::module_ m)
365382
impl::populate_add_dispatch_tables();
366383
using impl::add_contig_dispatch_table;
367384
using impl::add_contig_matrix_contig_row_broadcast_dispatch_table;
385+
using impl::add_contig_row_contig_matrix_broadcast_dispatch_table;
368386
using impl::add_output_id_table;
369387
using impl::add_strided_dispatch_table;
370388

@@ -382,7 +400,10 @@ void init_elementwise_functions(py::module_ m)
382400
add_strided_dispatch_table,
383401
// function pointers to handle operation of c-contig matrix and
384402
// c-contig row with broadcasting (may be nullptr)
385-
add_contig_matrix_contig_row_broadcast_dispatch_table);
403+
add_contig_matrix_contig_row_broadcast_dispatch_table,
404+
// function pointers to handle operation of c-contig row and
405+
// c-contig matrix with broadcasting (may be nullptr)
406+
add_contig_row_contig_matrix_broadcast_dispatch_table);
386407
};
387408
auto add_result_type_pyapi = [&](py::dtype dtype1, py::dtype dtype2) {
388409
return py_binary_ufunc_result_type(dtype1, dtype2,

dpctl/tensor/libtensor/source/elementwise_functions.hpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ bool isEqual(Container const &c, std::initializer_list<T> const &l)
297297
template <typename output_typesT,
298298
typename contig_dispatchT,
299299
typename strided_dispatchT,
300-
typename matrix_row_dispatchT>
300+
typename contig_matrix_row_dispatchT,
301+
typename contig_row_matrix_dispatchT>
301302
std::pair<sycl::event, sycl::event> py_binary_ufunc(
302303
dpctl::tensor::usm_ndarray src1,
303304
dpctl::tensor::usm_ndarray src2,
@@ -308,7 +309,10 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
308309
const output_typesT &output_type_table,
309310
const contig_dispatchT &contig_dispatch_table,
310311
const strided_dispatchT &strided_dispatch_table,
311-
const matrix_row_dispatchT &contig_matrix_row_broadcast_dispatch_table)
312+
const contig_matrix_row_dispatchT
313+
&contig_matrix_row_broadcast_dispatch_table,
314+
const contig_row_matrix_dispatchT
315+
&contig_row_matrix_broadcast_dispatch_table)
312316
{
313317
// check type_nums
314318
int src1_typenum = src1.get_typenum();
@@ -507,15 +511,15 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
507511
isEqual(simplified_src2_strides, {one, simplified_shape[0]}) &&
508512
isEqual(simplified_dst_strides, {one, simplified_shape[0]}))
509513
{
510-
auto matrix_row_broadcast_fn =
511-
contig_matrix_row_broadcast_dispatch_table[src2_typeid]
512-
[src1_typeid];
513-
if (matrix_row_broadcast_fn != nullptr) {
514+
auto row_matrix_broadcast_fn =
515+
contig_row_matrix_broadcast_dispatch_table[src1_typeid]
516+
[src2_typeid];
517+
if (row_matrix_broadcast_fn != nullptr) {
514518
size_t n0 = simplified_shape[1];
515519
size_t n1 = simplified_shape[0];
516-
sycl::event comp_ev = matrix_row_broadcast_fn(
517-
exec_q, host_tasks, n0, n1, src2_data, src2_offset,
518-
src1_data, src1_offset, dst_data, dst_offset, depends);
520+
sycl::event comp_ev = row_matrix_broadcast_fn(
521+
exec_q, host_tasks, n0, n1, src1_data, src1_offset,
522+
src2_data, src2_offset, dst_data, dst_offset, depends);
519523

520524
return std::make_pair(
521525
dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst},

0 commit comments

Comments
 (0)