@@ -1029,26 +1029,50 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
1029
1029
\
1030
1030
if (start + static_cast <size_t >(vec_sz) * max_sg_size < result_size) \
1031
1031
{ \
1032
- sycl::vec <_DataType_input1, vec_sz> x1 = \
1033
- sg. load <vec_sz>( sycl::multi_ptr<_DataType_input1 , global_space>(&input1_data[start])); \
1034
- sycl::vec<_DataType_input2, vec_sz> x2 = \
1035
- sg. load <vec_sz>(sycl::multi_ptr<_DataType_input2, global_space>(&input2_data[start])); \
1032
+ using input1_ptrT = sycl::multi_ptr <_DataType_input1, global_space>; \
1033
+ using input2_ptrT = sycl::multi_ptr<_DataType_input2 , global_space>; \
1034
+ using result_ptrT = sycl::multi_ptr<_DataType_output, global_space>; \
1035
+ \
1036
1036
sycl::vec<_DataType_output, vec_sz> res_vec; \
1037
1037
\
1038
- if constexpr (both_types_are_same <_DataType_input1, _DataType_input2, __vec_types__>) \
1038
+ if constexpr (both_types_are_any_of <_DataType_input1, _DataType_input2, __vec_types__>) \
1039
1039
{ \
1040
- res_vec = __vec_operation__; \
1040
+ if constexpr (both_types_are_same<_DataType_input1, _DataType_input2, _DataType_output>) \
1041
+ { \
1042
+ sycl::vec<_DataType_input1, vec_sz> x1 = \
1043
+ sg.load <vec_sz>(input1_ptrT (&input1_data[start])); \
1044
+ sycl::vec<_DataType_input2, vec_sz> x2 = \
1045
+ sg.load <vec_sz>(input2_ptrT (&input2_data[start])); \
1046
+ \
1047
+ res_vec = __vec_operation__; \
1048
+ } \
1049
+ else /* input types don't match result type, so explicit casting is required */ \
1050
+ { \
1051
+ sycl::vec<_DataType_output, vec_sz> x1 = \
1052
+ dpnp_vec_cast<_DataType_output, _DataType_input1, vec_sz>( \
1053
+ sg.load <vec_sz>(input1_ptrT (&input1_data[start]))); \
1054
+ sycl::vec<_DataType_output, vec_sz> x2 = \
1055
+ dpnp_vec_cast<_DataType_output, _DataType_input2, vec_sz>( \
1056
+ sg.load <vec_sz>(input2_ptrT (&input2_data[start]))); \
1057
+ \
1058
+ res_vec = __vec_operation__; \
1059
+ } \
1041
1060
} \
1042
1061
else \
1043
1062
{ \
1063
+ sycl::vec<_DataType_input1, vec_sz> x1 = \
1064
+ sg.load <vec_sz>(input1_ptrT (&input1_data[start])); \
1065
+ sycl::vec<_DataType_input2, vec_sz> x2 = \
1066
+ sg.load <vec_sz>(input2_ptrT (&input2_data[start])); \
1067
+ \
1044
1068
for (size_t k = 0 ; k < vec_sz; ++k) \
1045
1069
{ \
1046
1070
const _DataType_output input1_elem = x1[k]; \
1047
1071
const _DataType_output input2_elem = x2[k]; \
1048
1072
res_vec[k] = __operation__; \
1049
1073
} \
1050
1074
} \
1051
- sg.store <vec_sz>(sycl::multi_ptr<_DataType_output, global_space> (&result[start]), res_vec); \
1075
+ sg.store <vec_sz>(result_ptrT (&result[start]), res_vec); \
1052
1076
} \
1053
1077
else \
1054
1078
{ \
@@ -1173,6 +1197,47 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
1173
1197
1174
1198
#include < dpnp_gen_2arg_3type_tbl.hpp>
1175
1199
1200
+ template <DPNPFuncType FT1, DPNPFuncType FT2, typename has_fp64 = std::true_type>
1201
+ static constexpr DPNPFuncType get_divide_res_type ()
1202
+ {
1203
+ constexpr auto widest_type = populate_func_types<FT1, FT2>();
1204
+ constexpr auto shortes_type = (widest_type == FT1) ? FT2 : FT1;
1205
+
1206
+ if constexpr (widest_type == DPNPFuncType::DPNP_FT_CMPLX128 || widest_type == DPNPFuncType::DPNP_FT_DOUBLE)
1207
+ {
1208
+ return widest_type;
1209
+ }
1210
+ else if constexpr (widest_type == DPNPFuncType::DPNP_FT_CMPLX64)
1211
+ {
1212
+ if constexpr (shortes_type == DPNPFuncType::DPNP_FT_DOUBLE)
1213
+ {
1214
+ return DPNPFuncType::DPNP_FT_CMPLX128;
1215
+ }
1216
+ else if constexpr (has_fp64::value &&
1217
+ (shortes_type == DPNPFuncType::DPNP_FT_INT || shortes_type == DPNPFuncType::DPNP_FT_LONG))
1218
+ {
1219
+ return DPNPFuncType::DPNP_FT_CMPLX128;
1220
+ }
1221
+ }
1222
+ else if constexpr (widest_type == DPNPFuncType::DPNP_FT_FLOAT)
1223
+ {
1224
+ if constexpr (has_fp64::value &&
1225
+ (shortes_type == DPNPFuncType::DPNP_FT_INT || shortes_type == DPNPFuncType::DPNP_FT_LONG))
1226
+ {
1227
+ return DPNPFuncType::DPNP_FT_DOUBLE;
1228
+ }
1229
+ }
1230
+ else if constexpr (has_fp64::value)
1231
+ {
1232
+ return DPNPFuncType::DPNP_FT_DOUBLE;
1233
+ }
1234
+ else
1235
+ {
1236
+ return DPNPFuncType::DPNP_FT_FLOAT;
1237
+ }
1238
+ return widest_type;
1239
+ }
1240
+
1176
1241
template <DPNPFuncType FT1, DPNPFuncType... FTs>
1177
1242
static void func_map_elemwise_2arg_3type_core (func_map_t & fmap)
1178
1243
{
@@ -1194,6 +1259,16 @@ static void func_map_elemwise_2arg_3type_core(func_map_t& fmap)
1194
1259
func_type_map_t ::find_type<FT1>,
1195
1260
func_type_map_t ::find_type<FTs>>}),
1196
1261
...);
1262
+ ((fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][FT1][FTs] =
1263
+ {get_divide_res_type<FT1, FTs>(),
1264
+ (void *)dpnp_divide_c_ext<func_type_map_t ::find_type<get_divide_res_type<FT1, FTs>()>,
1265
+ func_type_map_t ::find_type<FT1>,
1266
+ func_type_map_t ::find_type<FTs>>,
1267
+ get_divide_res_type<FT1, FTs, std::false_type>(),
1268
+ (void *)dpnp_divide_c_ext<func_type_map_t ::find_type<get_divide_res_type<FT1, FTs, std::false_type>()>,
1269
+ func_type_map_t ::find_type<FT1>,
1270
+ func_type_map_t ::find_type<FTs>>}),
1271
+ ...);
1197
1272
}
1198
1273
1199
1274
template <DPNPFuncType... FTs>
@@ -1402,39 +1477,6 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
1402
1477
fmap[DPNPFuncName::DPNP_FN_DIVIDE][eft_DBL][eft_DBL] = {eft_DBL,
1403
1478
(void *)dpnp_divide_c_default<double , double , double >};
1404
1479
1405
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_INT] = {eft_DBL,
1406
- (void *)dpnp_divide_c_ext<double , int32_t , int32_t >};
1407
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_LNG] = {eft_DBL,
1408
- (void *)dpnp_divide_c_ext<double , int32_t , int64_t >};
1409
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_FLT] = {eft_DBL,
1410
- (void *)dpnp_divide_c_ext<double , int32_t , float >};
1411
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_DBL] = {eft_DBL,
1412
- (void *)dpnp_divide_c_ext<double , int32_t , double >};
1413
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_INT] = {eft_DBL,
1414
- (void *)dpnp_divide_c_ext<double , int64_t , int32_t >};
1415
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_LNG] = {eft_DBL,
1416
- (void *)dpnp_divide_c_ext<double , int64_t , int64_t >};
1417
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_FLT] = {eft_DBL,
1418
- (void *)dpnp_divide_c_ext<double , int64_t , float >};
1419
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_DBL] = {eft_DBL,
1420
- (void *)dpnp_divide_c_ext<double , int64_t , double >};
1421
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_INT] = {eft_DBL,
1422
- (void *)dpnp_divide_c_ext<double , float , int32_t >};
1423
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_LNG] = {eft_DBL,
1424
- (void *)dpnp_divide_c_ext<double , float , int64_t >};
1425
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_FLT] = {eft_FLT,
1426
- (void *)dpnp_divide_c_ext<float , float , float >};
1427
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_DBL] = {eft_DBL,
1428
- (void *)dpnp_divide_c_ext<double , float , double >};
1429
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_INT] = {eft_DBL,
1430
- (void *)dpnp_divide_c_ext<double , double , int32_t >};
1431
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_LNG] = {eft_DBL,
1432
- (void *)dpnp_divide_c_ext<double , double , int64_t >};
1433
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_FLT] = {eft_DBL,
1434
- (void *)dpnp_divide_c_ext<double , double , float >};
1435
- fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_DBL] = {eft_DBL,
1436
- (void *)dpnp_divide_c_ext<double , double , double >};
1437
-
1438
1480
fmap[DPNPFuncName::DPNP_FN_FMOD][eft_INT][eft_INT] = {eft_INT,
1439
1481
(void *)dpnp_fmod_c_default<int32_t , int32_t , int32_t >};
1440
1482
fmap[DPNPFuncName::DPNP_FN_FMOD][eft_INT][eft_LNG] = {eft_LNG,
0 commit comments