Skip to content

Commit 19715de

Browse files
authored
dpnp.divide() doesn't work properly with a scalar (#1295)
* dpnp.add() doesn't work properly with a scalar * dpnp.subtract() doesn't work properly with a scalar * dpnp.divide() doesn't work properly with a scalar * dpnp.divide() doesn't work properly with a scalar * Use std::int32_t and std::int64_t types * Disable floating-point optimizations that assume arguments and results are not NaNs or +-Inf * Fix issue with divide on Iris Xe
1 parent 6640e9e commit 19715de

18 files changed

+312
-149
lines changed

dpnp/backend/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ string(CONCAT COMMON_COMPILE_FLAGS
9393
"-fsycl "
9494
"-fsycl-device-code-split=per_kernel "
9595
"-fno-approx-func "
96+
"-fno-finite-math-only "
9697
)
9798
string(CONCAT COMMON_LINK_FLAGS
9899
"-fsycl "

dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ MACRO_2ARG_3TYPES_OP(dpnp_copysign_c,
132132

133133
MACRO_2ARG_3TYPES_OP(dpnp_divide_c,
134134
input1_elem / input2_elem,
135-
nullptr,
136-
std::false_type,
135+
x1 / x2,
136+
MACRO_UNPACK_TYPES(bool, std::int32_t, std::int64_t),
137137
oneapi::mkl::vm::div,
138-
MACRO_UNPACK_TYPES(float, double))
138+
MACRO_UNPACK_TYPES(float, double, std::complex<float>, std::complex<double>))
139139

140140
MACRO_2ARG_3TYPES_OP(dpnp_fmod_c,
141141
sycl::fmod((double)input1_elem, (double)input2_elem),
@@ -169,7 +169,7 @@ MACRO_2ARG_3TYPES_OP(dpnp_minimum_c,
169169
// pytest "tests/third_party/cupy/creation_tests/test_ranges.py::TestMgrid::test_mgrid3"
170170
// requires multiplication shape1[10] with shape2[10,1] and result expected as shape[10,10]
171171
MACRO_2ARG_3TYPES_OP(dpnp_multiply_c,
172-
input1_elem* input2_elem,
172+
input1_elem * input2_elem,
173173
x1 * x2,
174174
MACRO_UNPACK_TYPES(bool, std::int32_t, std::int64_t),
175175
oneapi::mkl::vm::mul,

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,26 @@ size_t operator-(DPNPFuncType lhs, DPNPFuncType rhs);
417417
*/
418418
typedef struct DPNPFuncData
419419
{
420-
DPNPFuncType return_type; /**< return type identifier which expected by the @ref ptr function */
421-
void* ptr; /**< C++ backend function pointer */
420+
DPNPFuncData(const DPNPFuncType gen_type, void* gen_ptr, const DPNPFuncType type_no_fp64, void* ptr_no_fp64)
421+
: return_type(gen_type)
422+
, ptr(gen_ptr)
423+
, return_type_no_fp64(type_no_fp64)
424+
, ptr_no_fp64(ptr_no_fp64)
425+
{
426+
}
427+
DPNPFuncData(const DPNPFuncType gen_type, void* gen_ptr)
428+
: DPNPFuncData(gen_type, gen_ptr, DPNPFuncType::DPNP_FT_NONE, nullptr)
429+
{
430+
}
431+
DPNPFuncData()
432+
: DPNPFuncData(DPNPFuncType::DPNP_FT_NONE, nullptr)
433+
{
434+
}
435+
436+
DPNPFuncType return_type; /**< return type identifier which expected by the @ref ptr function */
437+
void* ptr; /**< C++ backend function pointer */
438+
DPNPFuncType return_type_no_fp64; /**< alternative return type identifier when no fp64 support by device */
439+
void* ptr_no_fp64; /**< alternative C++ backend function pointer when no fp64 support by device */
422440
} DPNPFuncData_t;
423441

424442
/**

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 82 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,26 +1029,50 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
10291029
\
10301030
if (start + static_cast<size_t>(vec_sz) * max_sg_size < result_size) \
10311031
{ \
1032-
sycl::vec<_DataType_input1, vec_sz> x1 = \
1033-
sg.load<vec_sz>(sycl::multi_ptr<_DataType_input1, global_space>(&input1_data[start])); \
1034-
sycl::vec<_DataType_input2, vec_sz> x2 = \
1035-
sg.load<vec_sz>(sycl::multi_ptr<_DataType_input2, global_space>(&input2_data[start])); \
1032+
using input1_ptrT = sycl::multi_ptr<_DataType_input1, global_space>; \
1033+
using input2_ptrT = sycl::multi_ptr<_DataType_input2, global_space>; \
1034+
using result_ptrT = sycl::multi_ptr<_DataType_output, global_space>; \
1035+
\
10361036
sycl::vec<_DataType_output, vec_sz> res_vec; \
10371037
\
1038-
if constexpr (both_types_are_same<_DataType_input1, _DataType_input2, __vec_types__>) \
1038+
if constexpr (both_types_are_any_of<_DataType_input1, _DataType_input2, __vec_types__>) \
10391039
{ \
1040-
res_vec = __vec_operation__; \
1040+
if constexpr (both_types_are_same<_DataType_input1, _DataType_input2, _DataType_output>) \
1041+
{ \
1042+
sycl::vec<_DataType_input1, vec_sz> x1 = \
1043+
sg.load<vec_sz>(input1_ptrT(&input1_data[start])); \
1044+
sycl::vec<_DataType_input2, vec_sz> x2 = \
1045+
sg.load<vec_sz>(input2_ptrT(&input2_data[start])); \
1046+
\
1047+
res_vec = __vec_operation__; \
1048+
} \
1049+
else /* input types don't match result type, so explicit casting is required */ \
1050+
{ \
1051+
sycl::vec<_DataType_output, vec_sz> x1 = \
1052+
dpnp_vec_cast<_DataType_output, _DataType_input1, vec_sz>( \
1053+
sg.load<vec_sz>(input1_ptrT(&input1_data[start]))); \
1054+
sycl::vec<_DataType_output, vec_sz> x2 = \
1055+
dpnp_vec_cast<_DataType_output, _DataType_input2, vec_sz>( \
1056+
sg.load<vec_sz>(input2_ptrT(&input2_data[start]))); \
1057+
\
1058+
res_vec = __vec_operation__; \
1059+
} \
10411060
} \
10421061
else \
10431062
{ \
1063+
sycl::vec<_DataType_input1, vec_sz> x1 = \
1064+
sg.load<vec_sz>(input1_ptrT(&input1_data[start])); \
1065+
sycl::vec<_DataType_input2, vec_sz> x2 = \
1066+
sg.load<vec_sz>(input2_ptrT(&input2_data[start])); \
1067+
\
10441068
for (size_t k = 0; k < vec_sz; ++k) \
10451069
{ \
10461070
const _DataType_output input1_elem = x1[k]; \
10471071
const _DataType_output input2_elem = x2[k]; \
10481072
res_vec[k] = __operation__; \
10491073
} \
10501074
} \
1051-
sg.store<vec_sz>(sycl::multi_ptr<_DataType_output, global_space>(&result[start]), res_vec); \
1075+
sg.store<vec_sz>(result_ptrT(&result[start]), res_vec); \
10521076
} \
10531077
else \
10541078
{ \
@@ -1173,6 +1197,47 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
11731197

11741198
#include <dpnp_gen_2arg_3type_tbl.hpp>
11751199

1200+
template <DPNPFuncType FT1, DPNPFuncType FT2, typename has_fp64 = std::true_type>
1201+
static constexpr DPNPFuncType get_divide_res_type()
1202+
{
1203+
constexpr auto widest_type = populate_func_types<FT1, FT2>();
1204+
constexpr auto shortes_type = (widest_type == FT1) ? FT2 : FT1;
1205+
1206+
if constexpr (widest_type == DPNPFuncType::DPNP_FT_CMPLX128 || widest_type == DPNPFuncType::DPNP_FT_DOUBLE)
1207+
{
1208+
return widest_type;
1209+
}
1210+
else if constexpr (widest_type == DPNPFuncType::DPNP_FT_CMPLX64)
1211+
{
1212+
if constexpr (shortes_type == DPNPFuncType::DPNP_FT_DOUBLE)
1213+
{
1214+
return DPNPFuncType::DPNP_FT_CMPLX128;
1215+
}
1216+
else if constexpr (has_fp64::value &&
1217+
(shortes_type == DPNPFuncType::DPNP_FT_INT || shortes_type == DPNPFuncType::DPNP_FT_LONG))
1218+
{
1219+
return DPNPFuncType::DPNP_FT_CMPLX128;
1220+
}
1221+
}
1222+
else if constexpr (widest_type == DPNPFuncType::DPNP_FT_FLOAT)
1223+
{
1224+
if constexpr (has_fp64::value &&
1225+
(shortes_type == DPNPFuncType::DPNP_FT_INT || shortes_type == DPNPFuncType::DPNP_FT_LONG))
1226+
{
1227+
return DPNPFuncType::DPNP_FT_DOUBLE;
1228+
}
1229+
}
1230+
else if constexpr (has_fp64::value)
1231+
{
1232+
return DPNPFuncType::DPNP_FT_DOUBLE;
1233+
}
1234+
else
1235+
{
1236+
return DPNPFuncType::DPNP_FT_FLOAT;
1237+
}
1238+
return widest_type;
1239+
}
1240+
11761241
template <DPNPFuncType FT1, DPNPFuncType... FTs>
11771242
static void func_map_elemwise_2arg_3type_core(func_map_t& fmap)
11781243
{
@@ -1194,6 +1259,16 @@ static void func_map_elemwise_2arg_3type_core(func_map_t& fmap)
11941259
func_type_map_t::find_type<FT1>,
11951260
func_type_map_t::find_type<FTs>>}),
11961261
...);
1262+
((fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][FT1][FTs] =
1263+
{get_divide_res_type<FT1, FTs>(),
1264+
(void*)dpnp_divide_c_ext<func_type_map_t::find_type<get_divide_res_type<FT1, FTs>()>,
1265+
func_type_map_t::find_type<FT1>,
1266+
func_type_map_t::find_type<FTs>>,
1267+
get_divide_res_type<FT1, FTs, std::false_type>(),
1268+
(void*)dpnp_divide_c_ext<func_type_map_t::find_type<get_divide_res_type<FT1, FTs, std::false_type>()>,
1269+
func_type_map_t::find_type<FT1>,
1270+
func_type_map_t::find_type<FTs>>}),
1271+
...);
11971272
}
11981273

11991274
template <DPNPFuncType... FTs>
@@ -1402,39 +1477,6 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
14021477
fmap[DPNPFuncName::DPNP_FN_DIVIDE][eft_DBL][eft_DBL] = {eft_DBL,
14031478
(void*)dpnp_divide_c_default<double, double, double>};
14041479

1405-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_INT] = {eft_DBL,
1406-
(void*)dpnp_divide_c_ext<double, int32_t, int32_t>};
1407-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_LNG] = {eft_DBL,
1408-
(void*)dpnp_divide_c_ext<double, int32_t, int64_t>};
1409-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_FLT] = {eft_DBL,
1410-
(void*)dpnp_divide_c_ext<double, int32_t, float>};
1411-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_DBL] = {eft_DBL,
1412-
(void*)dpnp_divide_c_ext<double, int32_t, double>};
1413-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_INT] = {eft_DBL,
1414-
(void*)dpnp_divide_c_ext<double, int64_t, int32_t>};
1415-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_LNG] = {eft_DBL,
1416-
(void*)dpnp_divide_c_ext<double, int64_t, int64_t>};
1417-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_FLT] = {eft_DBL,
1418-
(void*)dpnp_divide_c_ext<double, int64_t, float>};
1419-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_DBL] = {eft_DBL,
1420-
(void*)dpnp_divide_c_ext<double, int64_t, double>};
1421-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_INT] = {eft_DBL,
1422-
(void*)dpnp_divide_c_ext<double, float, int32_t>};
1423-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_LNG] = {eft_DBL,
1424-
(void*)dpnp_divide_c_ext<double, float, int64_t>};
1425-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_FLT] = {eft_FLT,
1426-
(void*)dpnp_divide_c_ext<float, float, float>};
1427-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_DBL] = {eft_DBL,
1428-
(void*)dpnp_divide_c_ext<double, float, double>};
1429-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_INT] = {eft_DBL,
1430-
(void*)dpnp_divide_c_ext<double, double, int32_t>};
1431-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_LNG] = {eft_DBL,
1432-
(void*)dpnp_divide_c_ext<double, double, int64_t>};
1433-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_FLT] = {eft_DBL,
1434-
(void*)dpnp_divide_c_ext<double, double, float>};
1435-
fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_DBL] = {eft_DBL,
1436-
(void*)dpnp_divide_c_ext<double, double, double>};
1437-
14381480
fmap[DPNPFuncName::DPNP_FN_FMOD][eft_INT][eft_INT] = {eft_INT,
14391481
(void*)dpnp_fmod_c_default<int32_t, int32_t, int32_t>};
14401482
fmap[DPNPFuncName::DPNP_FN_FMOD][eft_INT][eft_LNG] = {eft_LNG,

dpnp/backend/src/dpnp_fptr.hpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include <map>
3636
#include <complex>
3737

38+
#include <CL/sycl.hpp>
39+
3840
#include <dpnp_iface_fptr.hpp>
3941

4042
/**
@@ -116,6 +118,31 @@ static constexpr DPNPFuncType populate_func_types()
116118
return (FT1 < FT2) ? FT2 : FT1;
117119
}
118120

121+
/**
122+
* @brief A helper function to cast SYCL vector between types.
123+
*/
124+
template <typename Op, typename Vec, std::size_t... I>
125+
static auto dpnp_vec_cast_impl(const Vec& v, std::index_sequence<I...>)
126+
{
127+
return Op{v[I]...};
128+
}
129+
130+
/**
131+
* @brief A casting function for SYCL vector.
132+
*
133+
* @tparam dstT A result type upon casting.
134+
* @tparam srcT An incoming type of the vector.
135+
* @tparam N A number of elements with the vector.
136+
* @tparam Indices A sequence of integers
137+
* @param s An incoming SYCL vector to cast.
138+
* @return SYCL vector casted to desctination type.
139+
*/
140+
template <typename dstT, typename srcT, std::size_t N, typename Indices = std::make_index_sequence<N>>
141+
static auto dpnp_vec_cast(const sycl::vec<srcT, N>& s)
142+
{
143+
return dpnp_vec_cast_impl<sycl::vec<dstT, N>, sycl::vec<srcT, N>>(s, Indices{});
144+
}
145+
119146
/**
120147
* Removes parentheses for a passed list of types separated by comma.
121148
* It's intended to be used in operations macro.
@@ -142,6 +169,12 @@ struct are_same : std::conjunction<std::is_same<T, Ts>...> {};
142169
template <typename T1, typename T2, typename... Ts>
143170
constexpr auto both_types_are_same = std::conjunction_v<is_any<T1, Ts...>, are_same<T1, T2>>;
144171

172+
/**
173+
* A template constat to check if both types T1 and T2 match any type from Ts.
174+
*/
175+
template <typename T1, typename T2, typename... Ts>
176+
constexpr auto both_types_are_any_of = std::conjunction_v<is_any<T1, Ts...>, is_any<T2, Ts...>>;
177+
145178
/**
146179
* A template constat to check if both types T1 and T2 don't match any type from Ts sequence.
147180
*/

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,8 @@ cdef extern from "dpnp_iface_fptr.hpp":
374374
struct DPNPFuncData:
375375
DPNPFuncType return_type
376376
void * ptr
377+
DPNPFuncType return_type_no_fp64
378+
void *ptr_no_fp64
377379

378380
DPNPFuncData get_dpnp_function_ptr(DPNPFuncName name, DPNPFuncType first_type, DPNPFuncType second_type) except +
379381

0 commit comments

Comments
 (0)