Skip to content

Commit db9ee1e

Browse files
committed
dpnp.fmod() doesn't work properly with a scalar
1 parent 4de4ef9 commit db9ee1e

File tree

3 files changed

+103
-75
lines changed

3 files changed

+103
-75
lines changed

dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ MACRO_2ARG_3TYPES_OP(dpnp_divide_c,
138138
MACRO_UNPACK_TYPES(float, double, std::complex<float>, std::complex<double>))
139139

140140
MACRO_2ARG_3TYPES_OP(dpnp_fmod_c,
141-
sycl::fmod((double)input1_elem, (double)input2_elem),
142-
nullptr,
143-
std::false_type,
141+
dispatch_fmod_op(input1_elem, input2_elem),
142+
x1 % x2,
143+
MACRO_UNPACK_TYPES(bool, std::int32_t, std::int64_t),
144144
oneapi::mkl::vm::fmod,
145145
MACRO_UNPACK_TYPES(float, double))
146146

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,18 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
848848
return;
849849
}
850850

851+
template <typename T>
852+
constexpr auto dispatch_fmod_op(T elem1, T elem2)
853+
{
854+
if constexpr (is_any_v<T, std::int32_t, std::int64_t>)
855+
{
856+
return elem1 % elem2;
857+
}
858+
else
859+
{
860+
return sycl::fmod(elem1, elem2);
861+
}
862+
}
851863

852864
#define MACRO_2ARG_3TYPES_OP( \
853865
__name__, __operation__, __vec_operation__, __vec_types__, __mkl_operation__, __mkl_types__) \
@@ -995,8 +1007,8 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
9951007
const size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ \
9961008
{ \
9971009
const shape_elem_type* result_strides_data = &dev_strides_data[0]; \
998-
const shape_elem_type* input1_strides_data = &dev_strides_data[1]; \
999-
const shape_elem_type* input2_strides_data = &dev_strides_data[2]; \
1010+
const shape_elem_type* input1_strides_data = &dev_strides_data[result_ndim]; \
1011+
const shape_elem_type* input2_strides_data = &dev_strides_data[2 * result_ndim]; \
10001012
\
10011013
size_t input1_id = 0; \
10021014
size_t input2_id = 0; \
@@ -1261,6 +1273,19 @@ static constexpr DPNPFuncType get_divide_res_type()
12611273
return widest_type;
12621274
}
12631275

1276+
template <DPNPFuncType FT1, DPNPFuncType FT2>
1277+
static constexpr DPNPFuncType get_fmod_res_type()
1278+
{
1279+
constexpr auto widest_type = populate_func_types<FT1, FT2>();
1280+
constexpr auto shortes_type = (widest_type == FT1) ? FT2 : FT1;
1281+
1282+
if constexpr (shortes_type == DPNPFuncType::DPNP_FT_BOOL)
1283+
{
1284+
return DPNPFuncType::DPNP_FT_INT;
1285+
}
1286+
return widest_type;
1287+
}
1288+
12641289
template <DPNPFuncType FT1, DPNPFuncType... FTs>
12651290
static void func_map_elemwise_2arg_3type_core(func_map_t& fmap)
12661291
{
@@ -1300,12 +1325,29 @@ static void func_map_elemwise_2arg_3type_core(func_map_t& fmap)
13001325
...);
13011326
}
13021327

1328+
template <DPNPFuncType FT1, DPNPFuncType... FTs>
1329+
static void func_map_elemwise_2arg_3type_core_no_complex(func_map_t& fmap)
1330+
{
1331+
((fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][FT1][FTs] =
1332+
{get_fmod_res_type<FT1, FTs>(),
1333+
(void*)dpnp_fmod_c_ext<func_type_map_t::find_type<get_fmod_res_type<FT1, FTs>()>,
1334+
func_type_map_t::find_type<FT1>,
1335+
func_type_map_t::find_type<FTs>>}),
1336+
...);
1337+
}
1338+
13031339
template <DPNPFuncType... FTs>
13041340
static void func_map_elemwise_2arg_3type_helper(func_map_t& fmap)
13051341
{
13061342
((func_map_elemwise_2arg_3type_core<FTs, FTs...>(fmap)), ...);
13071343
}
13081344

1345+
template <DPNPFuncType... FTs>
1346+
static void func_map_elemwise_2arg_3type_helper_no_complex(func_map_t& fmap)
1347+
{
1348+
((func_map_elemwise_2arg_3type_core_no_complex<FTs, FTs...>(fmap)), ...);
1349+
}
1350+
13091351
static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
13101352
{
13111353
fmap[DPNPFuncName::DPNP_FN_ADD][eft_INT][eft_INT] = {eft_INT,
@@ -1539,39 +1581,6 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
15391581
fmap[DPNPFuncName::DPNP_FN_FMOD][eft_DBL][eft_DBL] = {eft_DBL,
15401582
(void*)dpnp_fmod_c_default<double, double, double>};
15411583

1542-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_INT][eft_INT] = {eft_INT,
1543-
(void*)dpnp_fmod_c_ext<int32_t, int32_t, int32_t>};
1544-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_INT][eft_LNG] = {eft_LNG,
1545-
(void*)dpnp_fmod_c_ext<int64_t, int32_t, int64_t>};
1546-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_INT][eft_FLT] = {eft_DBL,
1547-
(void*)dpnp_fmod_c_ext<double, int32_t, float>};
1548-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_INT][eft_DBL] = {eft_DBL,
1549-
(void*)dpnp_fmod_c_ext<double, int32_t, double>};
1550-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_LNG][eft_INT] = {eft_LNG,
1551-
(void*)dpnp_fmod_c_ext<int64_t, int64_t, int32_t>};
1552-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_LNG][eft_LNG] = {eft_LNG,
1553-
(void*)dpnp_fmod_c_ext<int64_t, int64_t, int64_t>};
1554-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_LNG][eft_FLT] = {eft_DBL,
1555-
(void*)dpnp_fmod_c_ext<double, int64_t, float>};
1556-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_LNG][eft_DBL] = {eft_DBL,
1557-
(void*)dpnp_fmod_c_ext<double, int64_t, double>};
1558-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_FLT][eft_INT] = {eft_DBL,
1559-
(void*)dpnp_fmod_c_ext<double, float, int32_t>};
1560-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_FLT][eft_LNG] = {eft_DBL,
1561-
(void*)dpnp_fmod_c_ext<double, float, int64_t>};
1562-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_FLT][eft_FLT] = {eft_FLT,
1563-
(void*)dpnp_fmod_c_ext<float, float, float>};
1564-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_FLT][eft_DBL] = {eft_DBL,
1565-
(void*)dpnp_fmod_c_ext<double, float, double>};
1566-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_DBL][eft_INT] = {eft_DBL,
1567-
(void*)dpnp_fmod_c_ext<double, double, int32_t>};
1568-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_DBL][eft_LNG] = {eft_DBL,
1569-
(void*)dpnp_fmod_c_ext<double, double, int64_t>};
1570-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_DBL][eft_FLT] = {eft_DBL,
1571-
(void*)dpnp_fmod_c_ext<double, double, float>};
1572-
fmap[DPNPFuncName::DPNP_FN_FMOD_EXT][eft_DBL][eft_DBL] = {eft_DBL,
1573-
(void*)dpnp_fmod_c_ext<double, double, double>};
1574-
15751584
fmap[DPNPFuncName::DPNP_FN_HYPOT][eft_INT][eft_INT] = {eft_DBL,
15761585
(void*)dpnp_hypot_c_default<double, int32_t, int32_t>};
15771586
fmap[DPNPFuncName::DPNP_FN_HYPOT][eft_INT][eft_LNG] = {eft_DBL,
@@ -1918,6 +1927,7 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
19181927
eft_DBL, (void*)dpnp_subtract_c_default<double, double, double>};
19191928

19201929
func_map_elemwise_2arg_3type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL, eft_C64, eft_C128>(fmap);
1930+
func_map_elemwise_2arg_3type_helper_no_complex<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL>(fmap);
19211931

19221932
return;
19231933
}

dpnp/dpnp_iface_mathematical.py

Lines changed: 55 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,41 @@
9595
]
9696

9797

98+
def _check_nd_call(origin_func, dpnp_func, x1, x2, out=None, where=True, dtype=None, subok=True, **kwargs):
99+
"""Choose function to call based on input and call chosen fucntion."""
100+
101+
if kwargs:
102+
pass
103+
elif where is not True:
104+
pass
105+
elif dtype is not None:
106+
pass
107+
elif subok is not True:
108+
pass
109+
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
110+
# at least either x1 or x2 has to be an array
111+
pass
112+
else:
113+
# get USM type and queue to copy scalar from the host memory into a USM allocation
114+
usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)
115+
116+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
117+
alloc_usm_type=usm_type, alloc_queue=queue)
118+
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
119+
alloc_usm_type=usm_type, alloc_queue=queue)
120+
if x1_desc and x2_desc:
121+
if out is not None:
122+
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
123+
raise TypeError("return array must be of supported array type")
124+
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) or None
125+
else:
126+
out_desc = None
127+
128+
return dpnp_func(x1_desc, x2_desc, dtype=dtype, out=out_desc, where=where).get_pyobj()
129+
130+
return call_origin(origin_func, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
131+
132+
98133
def abs(*args, **kwargs):
99134
"""
100135
Calculate the absolute value element-wise.
@@ -852,63 +887,46 @@ def fmin(*args, **kwargs):
852887
return dpnp.minimum(*args, **kwargs)
853888

854889

855-
def fmod(x1, x2, dtype=None, out=None, where=True, **kwargs):
890+
def fmod(x1,
891+
x2,
892+
/,
893+
out=None,
894+
*,
895+
where=True,
896+
dtype=None,
897+
subok=True,
898+
**kwargs):
856899
"""
857900
Calculate the element-wise remainder of division.
858901
859902
For full documentation refer to :obj:`numpy.fmod`.
860903
861904
Limitations
862905
-----------
863-
Parameters ``x1`` and ``x2`` are supported as either :obj:`dpnp.ndarray` or scalar.
864-
Parameters ``dtype``, ``out`` and ``where`` are supported with their default values.
865-
Keyword arguments ``kwargs`` are currently unsupported.
866-
Otherwise the functions will be executed sequentially on CPU.
906+
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
907+
or :class:`dpctl.tensor.usm_ndarray`, but both `x1` and `x2` can not be scalars at the same time.
908+
Parameters `where`, `dtype` and `subok` are supported with their default values.
909+
Keyword argument `kwargs` is currently unsupported.
910+
Otherwise the function will be executed sequentially on CPU.
867911
Input array data types are limited by supported DPNP :ref:`Data types`.
868912
869913
See Also
870914
--------
871-
:obj:`dpnp.reminder` : Remainder complementary to floor_divide.
915+
:obj:`dpnp.remainder` : Remainder complementary to floor_divide.
872916
:obj:`dpnp.divide` : Standard division.
873917
874918
Examples
875919
--------
876-
>>> import dpnp as np
877-
>>> a = np.array([2, -3, 4, 5, -4.5])
878-
>>> b = np.array([2, 2, 2, 2, 2])
879-
>>> result = np.fmod(a, b)
920+
>>> import dpnp as dp
921+
>>> a = dp.array([2, -3, 4, 5, -4.5])
922+
>>> b = dp.array([2, 2, 2, 2, 2])
923+
>>> result = dp.fmod(a, b)
880924
>>> [x for x in result]
881925
[0.0, -1.0, 0.0, 1.0, -0.5]
882926
883927
"""
884928

885-
x1_is_scalar = dpnp.isscalar(x1)
886-
x2_is_scalar = dpnp.isscalar(x2)
887-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
888-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False)
889-
890-
if x1_desc and x2_desc and not kwargs:
891-
if not x1_desc and not x1_is_scalar:
892-
pass
893-
elif not x2_desc and not x2_is_scalar:
894-
pass
895-
elif x1_is_scalar and x2_is_scalar:
896-
pass
897-
elif x1_desc and x1_desc.ndim == 0:
898-
pass
899-
elif x2_desc and x2_desc.ndim == 0:
900-
pass
901-
elif dtype is not None:
902-
pass
903-
elif out is not None:
904-
pass
905-
elif not where:
906-
pass
907-
else:
908-
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) if out is not None else None
909-
return dpnp_fmod(x1_desc, x2_desc, dtype, out_desc, where).get_pyobj()
910-
911-
return call_origin(numpy.fmod, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
929+
return _check_nd_call(numpy.fmod, dpnp_fmod, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
912930

913931

914932
def gradient(x1, *varargs, **kwargs):

0 commit comments

Comments
 (0)