Skip to content

Commit 2e7cbe0

Browse files
committed
Simpler dispatching for in-place broadcast kernels and changes requested by @vtavana
1 parent 7b82b2b commit 2e7cbe0

File tree

12 files changed

+41
-46
lines changed

12 files changed

+41
-46
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ template <typename fnT, typename T1, typename T2> struct AddTypeMapFactory
218218
};
219219

220220
template <typename T1, typename T2, typename resT, typename IndexerT>
221-
class add_strided_strided_kernel;
221+
class add_strided_kernel;
222222

223223
template <typename argTy1, typename argTy2>
224224
sycl::event add_strided_impl(sycl::queue exec_q,
@@ -235,8 +235,7 @@ sycl::event add_strided_impl(sycl::queue exec_q,
235235
const std::vector<sycl::event> &additional_depends)
236236
{
237237
return elementwise_common::binary_strided_impl<
238-
argTy1, argTy2, AddOutputType, AddStridedFunctor,
239-
add_strided_strided_kernel>(
238+
argTy1, argTy2, AddOutputType, AddStridedFunctor, add_strided_kernel>(
240239
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
241240
arg2_offset, res_p, res_offset, depends, additional_depends);
242241
}
@@ -515,14 +514,13 @@ struct AddInplaceRowMatrixBroadcastFactory
515514
fnT get()
516515
{
517516
using resT = typename AddOutputType<T1, T2>::value_type;
518-
if constexpr (std::is_same_v<resT, void>) {
517+
if constexpr (!std::is_same_v<resT, T2>) {
519518
fnT fn = nullptr;
520519
return fn;
521520
}
522521
else {
523522
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
524-
dpctl::tensor::type_utils::is_complex<T2>::value ||
525-
dpctl::tensor::type_utils::is_complex<resT>::value)
523+
dpctl::tensor::type_utils::is_complex<T2>::value)
526524
{
527525
fnT fn = nullptr;
528526
return fn;

dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ template <typename fnT, typename T1, typename T2> struct EqualTypeMapFactory
201201
};
202202

203203
template <typename T1, typename T2, typename resT, typename IndexerT>
204-
class equal_strided_strided_kernel;
204+
class equal_strided_kernel;
205205

206206
template <typename argTy1, typename argTy2>
207207
sycl::event
@@ -220,9 +220,9 @@ equal_strided_impl(sycl::queue exec_q,
220220
{
221221
return elementwise_common::binary_strided_impl<
222222
argTy1, argTy2, EqualOutputType, EqualStridedFunctor,
223-
equal_strided_strided_kernel>(
224-
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
225-
arg2_offset, res_p, res_offset, depends, additional_depends);
223+
equal_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
224+
arg1_offset, arg2_p, arg2_offset, res_p,
225+
res_offset, depends, additional_depends);
226226
}
227227

228228
template <typename fnT, typename T1, typename T2> struct EqualStridedFactory

dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ struct FloorDivideTypeMapFactory
235235
};
236236

237237
template <typename T1, typename T2, typename resT, typename IndexerT>
238-
class floor_divide_strided_strided_kernel;
238+
class floor_divide_strided_kernel;
239239

240240
template <typename argTy1, typename argTy2>
241241
sycl::event
@@ -254,7 +254,7 @@ floor_divide_strided_impl(sycl::queue exec_q,
254254
{
255255
return elementwise_common::binary_strided_impl<
256256
argTy1, argTy2, FloorDivideOutputType, FloorDivideStridedFunctor,
257-
floor_divide_strided_strided_kernel>(
257+
floor_divide_strided_kernel>(
258258
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
259259
arg2_offset, res_p, res_offset, depends, additional_depends);
260260
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ template <typename fnT, typename T1, typename T2> struct GreaterTypeMapFactory
255255
};
256256

257257
template <typename T1, typename T2, typename resT, typename IndexerT>
258-
class greater_strided_strided_kernel;
258+
class greater_strided_kernel;
259259

260260
template <typename argTy1, typename argTy2>
261261
sycl::event
@@ -289,7 +289,7 @@ greater_strided_impl(sycl::queue exec_q,
289289
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
290290

291291
cgh.parallel_for<
292-
greater_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
292+
greater_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
293293
{nelems}, GreaterStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
294294
arg1_tp, arg2_tp, res_tp, indexer));
295295
});

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ struct GreaterEqualTypeMapFactory
261261
};
262262

263263
template <typename T1, typename T2, typename resT, typename IndexerT>
264-
class greater_equal_strided_strided_kernel;
264+
class greater_equal_strided_kernel;
265265

266266
template <typename argTy1, typename argTy2>
267267
sycl::event
@@ -295,8 +295,8 @@ greater_equal_strided_impl(sycl::queue exec_q,
295295
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
296296
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
297297

298-
cgh.parallel_for<greater_equal_strided_strided_kernel<argTy1, argTy2,
299-
resTy, IndexerT>>(
298+
cgh.parallel_for<
299+
greater_equal_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
300300
{nelems},
301301
GreaterEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
302302
arg1_tp, arg2_tp, res_tp, indexer));

dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ template <typename fnT, typename T1, typename T2> struct LessTypeMapFactory
253253
};
254254

255255
template <typename T1, typename T2, typename resT, typename IndexerT>
256-
class less_strided_strided_kernel;
256+
class less_strided_kernel;
257257

258258
template <typename argTy1, typename argTy2>
259259
sycl::event
@@ -286,8 +286,7 @@ less_strided_impl(sycl::queue exec_q,
286286
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
287287
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
288288

289-
cgh.parallel_for<
290-
less_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
289+
cgh.parallel_for<less_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
291290
{nelems}, LessStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
292291
arg1_tp, arg2_tp, res_tp, indexer));
293292
});

dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ template <typename fnT, typename T1, typename T2> struct LessEqualTypeMapFactory
256256
};
257257

258258
template <typename T1, typename T2, typename resT, typename IndexerT>
259-
class less_equal_strided_strided_kernel;
259+
class less_equal_strided_kernel;
260260

261261
template <typename argTy1, typename argTy2>
262262
sycl::event
@@ -290,7 +290,7 @@ less_equal_strided_impl(sycl::queue exec_q,
290290
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
291291

292292
cgh.parallel_for<
293-
less_equal_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
293+
less_equal_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
294294
{nelems}, LessEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
295295
arg1_tp, arg2_tp, res_tp, indexer));
296296
});

dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ template <typename fnT, typename T1, typename T2> struct MultiplyTypeMapFactory
221221
};
222222

223223
template <typename T1, typename T2, typename resT, typename IndexerT>
224-
class multiply_strided_strided_kernel;
224+
class multiply_strided_kernel;
225225

226226
template <typename argTy1, typename argTy2>
227227
sycl::event
@@ -240,9 +240,9 @@ multiply_strided_impl(sycl::queue exec_q,
240240
{
241241
return elementwise_common::binary_strided_impl<
242242
argTy1, argTy2, MultiplyOutputType, MultiplyStridedFunctor,
243-
multiply_strided_strided_kernel>(
244-
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
245-
arg2_offset, res_p, res_offset, depends, additional_depends);
243+
multiply_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
244+
arg1_offset, arg2_p, arg2_offset, res_p,
245+
res_offset, depends, additional_depends);
246246
}
247247

248248
template <typename fnT, typename T1, typename T2> struct MultiplyStridedFactory
@@ -531,14 +531,13 @@ struct MultiplyInplaceRowMatrixBroadcastFactory
531531
fnT get()
532532
{
533533
using resT = typename MultiplyOutputType<T1, T2>::value_type;
534-
if constexpr (std::is_same_v<resT, void>) {
534+
if constexpr (!std::is_same_v<resT, T2>) {
535535
fnT fn = nullptr;
536536
return fn;
537537
}
538538
else {
539539
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
540-
dpctl::tensor::type_utils::is_complex<T2>::value ||
541-
dpctl::tensor::type_utils::is_complex<resT>::value)
540+
dpctl::tensor::type_utils::is_complex<T2>::value)
542541
{
543542
fnT fn = nullptr;
544543
return fn;

dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ template <typename fnT, typename T1, typename T2> struct NotEqualTypeMapFactory
218218
};
219219

220220
template <typename T1, typename T2, typename resT, typename IndexerT>
221-
class not_equal_strided_strided_kernel;
221+
class not_equal_strided_kernel;
222222

223223
template <typename argTy1, typename argTy2>
224224
sycl::event
@@ -237,9 +237,9 @@ not_equal_strided_impl(sycl::queue exec_q,
237237
{
238238
return elementwise_common::binary_strided_impl<
239239
argTy1, argTy2, NotEqualOutputType, NotEqualStridedFunctor,
240-
not_equal_strided_strided_kernel>(
241-
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
242-
arg2_offset, res_p, res_offset, depends, additional_depends);
240+
not_equal_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
241+
arg1_offset, arg2_p, arg2_offset, res_p,
242+
res_offset, depends, additional_depends);
243243
}
244244

245245
template <typename fnT, typename T1, typename T2> struct NotEqualStridedFactory

dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ template <typename fnT, typename T1, typename T2> struct SubtractTypeMapFactory
218218
};
219219

220220
template <typename T1, typename T2, typename resT, typename IndexerT>
221-
class subtract_strided_strided_kernel;
221+
class subtract_strided_kernel;
222222

223223
template <typename argTy1, typename argTy2>
224224
sycl::event
@@ -237,9 +237,9 @@ subtract_strided_impl(sycl::queue exec_q,
237237
{
238238
return elementwise_common::binary_strided_impl<
239239
argTy1, argTy2, SubtractOutputType, SubtractStridedFunctor,
240-
subtract_strided_strided_kernel>(
241-
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
242-
arg2_offset, res_p, res_offset, depends, additional_depends);
240+
subtract_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
241+
arg1_offset, arg2_p, arg2_offset, res_p,
242+
res_offset, depends, additional_depends);
243243
}
244244

245245
template <typename fnT, typename T1, typename T2> struct SubtractStridedFactory
@@ -544,14 +544,13 @@ struct SubtractInplaceRowMatrixBroadcastFactory
544544
fnT get()
545545
{
546546
using resT = typename SubtractOutputType<T1, T2>::value_type;
547-
if constexpr (std::is_same_v<resT, void>) {
547+
if constexpr (!std::is_same_v<resT, T2>) {
548548
fnT fn = nullptr;
549549
return fn;
550550
}
551551
else {
552552
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
553-
dpctl::tensor::type_utils::is_complex<T2>::value ||
554-
dpctl::tensor::type_utils::is_complex<resT>::value)
553+
dpctl::tensor::type_utils::is_complex<T2>::value)
555554
{
556555
fnT fn = nullptr;
557556
return fn;

dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ struct TrueDivideTypeMapFactory
201201
};
202202

203203
template <typename T1, typename T2, typename resT, typename IndexerT>
204-
class true_divide_strided_strided_kernel;
204+
class true_divide_strided_kernel;
205205

206206
template <typename argTy1, typename argTy2>
207207
sycl::event
@@ -220,7 +220,7 @@ true_divide_strided_impl(sycl::queue exec_q,
220220
{
221221
return elementwise_common::binary_strided_impl<
222222
argTy1, argTy2, TrueDivideOutputType, TrueDivideStridedFunctor,
223-
true_divide_strided_strided_kernel>(
223+
true_divide_strided_kernel>(
224224
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
225225
arg2_offset, res_p, res_offset, depends, additional_depends);
226226
}

dpctl/tensor/libtensor/source/elementwise_functions.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
517517

518518
if (strided_fn == nullptr) {
519519
throw std::runtime_error(
520-
"Contiguous implementation is missing for src1_typeid=" +
520+
"Strided implementation is missing for src1_typeid=" +
521521
std::to_string(src1_typeid) +
522522
" and src2_typeid=" + std::to_string(src2_typeid));
523523
}
@@ -627,7 +627,7 @@ py_binary_inplace_ufunc(dpctl::tensor::usm_ndarray lhs,
627627

628628
if (output_typeid != lhs_typeid) {
629629
throw py::value_error(
630-
"Destination array has unexpected elemental data type.");
630+
"Left-hand side array has unexpected elemental data type.");
631631
}
632632

633633
// check that queues are compatible
@@ -696,7 +696,7 @@ py_binary_inplace_ufunc(dpctl::tensor::usm_ndarray lhs,
696696

697697
// dispatch for contiguous inputs
698698
if (both_c_contig || both_f_contig) {
699-
auto contig_fn = contig_dispatch_table[lhs_typeid][rhs_typeid];
699+
auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];
700700

701701
if (contig_fn != nullptr) {
702702
auto comp_ev = contig_fn(exec_q, rhs_nelems, rhs_data, 0, lhs_data,
@@ -781,7 +781,7 @@ py_binary_inplace_ufunc(dpctl::tensor::usm_ndarray lhs,
781781

782782
if (strided_fn == nullptr) {
783783
throw std::runtime_error(
784-
"Contiguous implementation is missing for rhs_typeid=" +
784+
"Strided implementation is missing for rhs_typeid=" +
785785
std::to_string(rhs_typeid) +
786786
" and lhs_typeid=" + std::to_string(lhs_typeid));
787787
}

0 commit comments

Comments
 (0)