@@ -283,6 +283,7 @@ namespace fn_ns = dpctl::tensor::kernels::add;
283
283
284
284
using fn_ns::add_contig_impl_fn_ptr_t ;
285
285
using fn_ns::add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t ;
286
+ using fn_ns::add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t ;
286
287
using fn_ns::add_strided_impl_fn_ptr_t ;
287
288
288
289
static add_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types]
@@ -292,35 +293,51 @@ static int add_output_id_table[td_ns::num_types][td_ns::num_types];
292
293
// Table of function pointers implementing add for general strided arrays.
// It is a num_types x num_types table; presumably indexed by the type ids
// of the two inputs -- TODO confirm against the lookup site.
static add_strided_impl_fn_ptr_t add_strided_dispatch_table[td_ns::num_types]
293
294
[td_ns::num_types];
294
295
296
+ // add(matrix, row)
// Function pointers for the c-contig matrix / c-contig row broadcasting
// case; the call site notes that entries may be nullptr.
295
297
static add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t
296
298
add_contig_matrix_contig_row_broadcast_dispatch_table[td_ns::num_types]
297
299
[td_ns::num_types];
298
300
301
+ // add(row, matrix)
// Counterpart of the table above with the operand order swapped
// (row first, matrix second).
302
+ static add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t
303
+ add_contig_row_contig_matrix_broadcast_dispatch_table[td_ns::num_types]
304
+ [td_ns::num_types];
305
+
299
306
// Fills the file-level tables used by the add operation: the output type id
// table, the strided and contiguous dispatch tables, and the two broadcast
// dispatch tables (matrix/row and row/matrix). Each table is filled by a
// DispatchTableBuilder instantiated with the matching factory from fn_ns.
// Takes no arguments and returns nothing; it only writes to those tables.
void populate_add_dispatch_tables (void )
300
307
{
301
308
using namespace td_ns ;
302
309
303
- using fn_ns::AddContigFactory;
304
- DispatchTableBuilder< add_contig_impl_fn_ptr_t , AddContigFactory, num_types>
305
- dtb1;
306
- dtb1.populate_dispatch_table (add_contig_dispatch_table );
310
+ // which input types are supported, and what is the type of the result
311
+ using fn_ns::AddTypeMapFactory;
312
+ DispatchTableBuilder< int , AddTypeMapFactory, num_types> dtb1;
313
+ dtb1.populate_dispatch_table (add_output_id_table );
307
314
315
+ // function pointers for operation on general strided arrays
308
316
using fn_ns::AddStridedFactory;
309
317
DispatchTableBuilder<add_strided_impl_fn_ptr_t , AddStridedFactory,
310
318
num_types>
311
319
dtb2;
312
320
dtb2.populate_dispatch_table (add_strided_dispatch_table);
313
321
314
- using fn_ns::AddTypeMapFactory;
315
- DispatchTableBuilder<int , AddTypeMapFactory, num_types> dtb3;
316
- dtb3.populate_dispatch_table (add_output_id_table);
322
+ // function pointers for operation on contiguous inputs and outputs
323
+ using fn_ns::AddContigFactory;
324
+ DispatchTableBuilder<add_contig_impl_fn_ptr_t , AddContigFactory, num_types>
325
+ dtb3;
326
+ dtb3.populate_dispatch_table (add_contig_dispatch_table);
317
327
318
328
// function pointers for the add(matrix, row) broadcast case
using fn_ns::AddContigMatrixContigRowBroadcastFactory;
319
329
DispatchTableBuilder<add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t ,
320
330
AddContigMatrixContigRowBroadcastFactory, num_types>
321
331
dtb4;
322
332
dtb4.populate_dispatch_table (
323
333
add_contig_matrix_contig_row_broadcast_dispatch_table);
334
+
335
// function pointers for the add(row, matrix) broadcast case
+ using fn_ns::AddContigRowContigMatrixBroadcastFactory;
336
+ DispatchTableBuilder<add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t ,
337
+ AddContigRowContigMatrixBroadcastFactory, num_types>
338
+ dtb5;
339
+ dtb5.populate_dispatch_table (
340
+ add_contig_row_contig_matrix_broadcast_dispatch_table);
324
341
// NOTE(review): the ';' after the closing brace is an empty declaration;
// it is redundant after a function body.
};
325
342
326
343
} // namespace impl
@@ -365,6 +382,7 @@ void init_elementwise_functions(py::module_ m)
365
382
impl::populate_add_dispatch_tables ();
366
383
using impl::add_contig_dispatch_table;
367
384
using impl::add_contig_matrix_contig_row_broadcast_dispatch_table;
385
+ using impl::add_contig_row_contig_matrix_broadcast_dispatch_table;
368
386
using impl::add_output_id_table;
369
387
using impl::add_strided_dispatch_table;
370
388
@@ -382,7 +400,10 @@ void init_elementwise_functions(py::module_ m)
382
400
add_strided_dispatch_table,
383
401
// function pointers to handle operation of c-contig matrix and
384
402
// c-contig row with broadcasting (may be nullptr)
385
- add_contig_matrix_contig_row_broadcast_dispatch_table);
403
+ add_contig_matrix_contig_row_broadcast_dispatch_table,
404
+ // function pointers to handle operation of c-contig row and
405
+ // c-contig matrix with broadcasting (may be nullptr)
406
+ add_contig_row_contig_matrix_broadcast_dispatch_table);
386
407
};
387
408
auto add_result_type_pyapi = [&](py::dtype dtype1, py::dtype dtype2) {
388
409
return py_binary_ufunc_result_type (dtype1, dtype2,
0 commit comments