@@ -848,6 +848,18 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
848
848
return ;
849
849
}
850
850
851
+ template <typename T>
852
+ constexpr auto dispatch_fmod_op (T elem1, T elem2)
853
+ {
854
+ if constexpr (is_any_v<T, std::int32_t , std::int64_t >)
855
+ {
856
+ return elem1 % elem2;
857
+ }
858
+ else
859
+ {
860
+ return sycl::fmod (elem1, elem2);
861
+ }
862
+ }
851
863
852
864
#define MACRO_2ARG_3TYPES_OP ( \
853
865
__name__, __operation__, __vec_operation__, __vec_types__, __mkl_operation__, __mkl_types__) \
@@ -995,8 +1007,8 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
995
1007
const size_t output_id = global_id[0 ]; /* for (size_t i = 0; i < result_size; ++i) */ \
996
1008
{ \
997
1009
const shape_elem_type* result_strides_data = &dev_strides_data[0 ]; \
998
- const shape_elem_type* input1_strides_data = &dev_strides_data[1 ]; \
999
- const shape_elem_type* input2_strides_data = &dev_strides_data[2 ]; \
1010
+ const shape_elem_type* input1_strides_data = &dev_strides_data[result_ndim]; \
1011
+ const shape_elem_type* input2_strides_data = &dev_strides_data[2 * result_ndim]; \
1000
1012
\
1001
1013
size_t input1_id = 0 ; \
1002
1014
size_t input2_id = 0 ; \
@@ -1261,6 +1273,19 @@ static constexpr DPNPFuncType get_divide_res_type()
1261
1273
return widest_type;
1262
1274
}
1263
1275
1276
+ template <DPNPFuncType FT1, DPNPFuncType FT2>
1277
+ static constexpr DPNPFuncType get_fmod_res_type ()
1278
+ {
1279
+ constexpr auto widest_type = populate_func_types<FT1, FT2>();
1280
+ constexpr auto shortes_type = (widest_type == FT1) ? FT2 : FT1;
1281
+
1282
+ if constexpr (shortes_type == DPNPFuncType::DPNP_FT_BOOL)
1283
+ {
1284
+ return DPNPFuncType::DPNP_FT_INT;
1285
+ }
1286
+ return widest_type;
1287
+ }
1288
+
1264
1289
template <DPNPFuncType FT1, DPNPFuncType... FTs>
1265
1290
static void func_map_elemwise_2arg_3type_core (func_map_t & fmap)
1266
1291
{
@@ -1300,12 +1325,29 @@ static void func_map_elemwise_2arg_3type_core(func_map_t& fmap)
1300
1325
...);
1301
1326
}
1302
1327
1328
+ template <DPNPFuncType FT1, DPNPFuncType... FTs>
1329
+ static void func_map_elemwise_2arg_3type_core_no_complex (func_map_t & fmap)
1330
+ {
1331
+ ((fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][FT1][FTs] =
1332
+ {get_fmod_res_type<FT1, FTs>(),
1333
+ (void *)dpnp_fmod_c_ext<func_type_map_t ::find_type<get_fmod_res_type<FT1, FTs>()>,
1334
+ func_type_map_t ::find_type<FT1>,
1335
+ func_type_map_t ::find_type<FTs>>}),
1336
+ ...);
1337
+ }
1338
+
1303
1339
template <DPNPFuncType... FTs>
1304
1340
static void func_map_elemwise_2arg_3type_helper (func_map_t & fmap)
1305
1341
{
1306
1342
((func_map_elemwise_2arg_3type_core<FTs, FTs...>(fmap)), ...);
1307
1343
}
1308
1344
1345
+ template <DPNPFuncType... FTs>
1346
+ static void func_map_elemwise_2arg_3type_helper_no_complex (func_map_t & fmap)
1347
+ {
1348
+ ((func_map_elemwise_2arg_3type_core_no_complex<FTs, FTs...>(fmap)), ...);
1349
+ }
1350
+
1309
1351
static void func_map_init_elemwise_2arg_3type (func_map_t & fmap)
1310
1352
{
1311
1353
fmap[DPNPFuncName::DPNP_FN_ADD][eft_INT][eft_INT] = {eft_INT,
@@ -1539,39 +1581,6 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
1539
1581
fmap[DPNPFuncName::DPNP_FN_FMOD][eft_DBL][eft_DBL] = {eft_DBL,
1540
1582
(void *)dpnp_fmod_c_default<double , double , double >};
1541
1583
1542
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_INT][eft_INT] = {eft_INT,
1543
- (void *)dpnp_fmod_c_ext<int32_t , int32_t , int32_t >};
1544
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_INT][eft_LNG] = {eft_LNG,
1545
- (void *)dpnp_fmod_c_ext<int64_t , int32_t , int64_t >};
1546
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_INT][eft_FLT] = {eft_DBL,
1547
- (void *)dpnp_fmod_c_ext<double , int32_t , float >};
1548
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_INT][eft_DBL] = {eft_DBL,
1549
- (void *)dpnp_fmod_c_ext<double , int32_t , double >};
1550
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_LNG][eft_INT] = {eft_LNG,
1551
- (void *)dpnp_fmod_c_ext<int64_t , int64_t , int32_t >};
1552
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_LNG][eft_LNG] = {eft_LNG,
1553
- (void *)dpnp_fmod_c_ext<int64_t , int64_t , int64_t >};
1554
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_LNG][eft_FLT] = {eft_DBL,
1555
- (void *)dpnp_fmod_c_ext<double , int64_t , float >};
1556
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_LNG][eft_DBL] = {eft_DBL,
1557
- (void *)dpnp_fmod_c_ext<double , int64_t , double >};
1558
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_FLT][eft_INT] = {eft_DBL,
1559
- (void *)dpnp_fmod_c_ext<double , float , int32_t >};
1560
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_FLT][eft_LNG] = {eft_DBL,
1561
- (void *)dpnp_fmod_c_ext<double , float , int64_t >};
1562
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_FLT][eft_FLT] = {eft_FLT,
1563
- (void *)dpnp_fmod_c_ext<float , float , float >};
1564
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_FLT][eft_DBL] = {eft_DBL,
1565
- (void *)dpnp_fmod_c_ext<double , float , double >};
1566
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_DBL][eft_INT] = {eft_DBL,
1567
- (void *)dpnp_fmod_c_ext<double , double , int32_t >};
1568
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_DBL][eft_LNG] = {eft_DBL,
1569
- (void *)dpnp_fmod_c_ext<double , double , int64_t >};
1570
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_DBL][eft_FLT] = {eft_DBL,
1571
- (void *)dpnp_fmod_c_ext<double , double , float >};
1572
- fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_DBL][eft_DBL] = {eft_DBL,
1573
- (void *)dpnp_fmod_c_ext<double , double , double >};
1574
-
1575
1584
fmap[DPNPFuncName::DPNP_FN_HYPOT][eft_INT][eft_INT] = {eft_DBL,
1576
1585
(void *)dpnp_hypot_c_default<double , int32_t , int32_t >};
1577
1586
fmap[DPNPFuncName::DPNP_FN_HYPOT][eft_INT][eft_LNG] = {eft_DBL,
@@ -1918,6 +1927,7 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
1918
1927
eft_DBL, (void *)dpnp_subtract_c_default<double , double , double >};
1919
1928
1920
1929
func_map_elemwise_2arg_3type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL, eft_C64, eft_C128>(fmap);
1930
+ func_map_elemwise_2arg_3type_helper_no_complex<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL>(fmap);
1921
1931
1922
1932
return ;
1923
1933
}
0 commit comments