Skip to content

Commit 5d94ca8

Browse files
authored
Implement dpnp.nancumprod() through dpnp.cumprod call (#1812)
* Implement dpnp.cumprod through dpctl.tensor * Implement dpnp.nancumprod() through existing calls * Applied review comments
1 parent a079815 commit 5d94ca8

File tree

9 files changed

+50
-157
lines changed

9 files changed

+50
-157
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ enum class DPNPFuncName : size_t
105105
DPNP_FN_COV, /**< Used in numpy.cov() impl */
106106
DPNP_FN_CROSS, /**< Used in numpy.cross() impl */
107107
DPNP_FN_CUMPROD, /**< Used in numpy.cumprod() impl */
108-
DPNP_FN_CUMPROD_EXT, /**< Used in numpy.cumprod() impl, requires extra
109-
parameters */
110108
DPNP_FN_CUMSUM, /**< Used in numpy.cumsum() impl */
111109
DPNP_FN_DEGREES, /**< Used in numpy.degrees() impl */
112110
DPNP_FN_DEGREES_EXT, /**< Used in numpy.degrees() impl, requires extra

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -364,14 +364,6 @@ template <typename _DataType_input, typename _DataType_output>
364364
void (*dpnp_cumprod_default_c)(void *, void *, size_t) =
365365
dpnp_cumprod_c<_DataType_input, _DataType_output>;
366366

367-
template <typename _DataType_input, typename _DataType_output>
368-
DPCTLSyclEventRef (*dpnp_cumprod_ext_c)(DPCTLSyclQueueRef,
369-
void *,
370-
void *,
371-
size_t,
372-
const DPCTLEventVectorRef) =
373-
dpnp_cumprod_c<_DataType_input, _DataType_output>;
374-
375367
template <typename _KernelNameSpecialization1,
376368
typename _KernelNameSpecialization2>
377369
class dpnp_cumsum_c_kernel;
@@ -1153,15 +1145,6 @@ void func_map_init_mathematical(func_map_t &fmap)
11531145
fmap[DPNPFuncName::DPNP_FN_CUMPROD][eft_DBL][eft_DBL] = {
11541146
eft_DBL, (void *)dpnp_cumprod_default_c<double, double>};
11551147

1156-
fmap[DPNPFuncName::DPNP_FN_CUMPROD_EXT][eft_INT][eft_INT] = {
1157-
eft_LNG, (void *)dpnp_cumprod_ext_c<int32_t, int64_t>};
1158-
fmap[DPNPFuncName::DPNP_FN_CUMPROD_EXT][eft_LNG][eft_LNG] = {
1159-
eft_LNG, (void *)dpnp_cumprod_ext_c<int64_t, int64_t>};
1160-
fmap[DPNPFuncName::DPNP_FN_CUMPROD_EXT][eft_FLT][eft_FLT] = {
1161-
eft_FLT, (void *)dpnp_cumprod_ext_c<float, float>};
1162-
fmap[DPNPFuncName::DPNP_FN_CUMPROD_EXT][eft_DBL][eft_DBL] = {
1163-
eft_DBL, (void *)dpnp_cumprod_ext_c<double, double>};
1164-
11651148
fmap[DPNPFuncName::DPNP_FN_CUMSUM][eft_INT][eft_INT] = {
11661149
eft_LNG, (void *)dpnp_cumsum_default_c<int32_t, int64_t>};
11671150
fmap[DPNPFuncName::DPNP_FN_CUMSUM][eft_LNG][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
3737
DPNP_FN_CHOOSE_EXT
3838
DPNP_FN_COPY_EXT
3939
DPNP_FN_CORRELATE_EXT
40-
DPNP_FN_CUMPROD_EXT
4140
DPNP_FN_DEGREES_EXT
4241
DPNP_FN_DIAG_INDICES_EXT
4342
DPNP_FN_DIAGONAL_EXT
@@ -127,9 +126,6 @@ cdef extern from "dpnp_iface.hpp":
127126

128127

129128
# C function pointer to the C library template functions
130-
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_1in_1out_t)(c_dpctl.DPCTLSyclQueueRef,
131-
void *, void * , size_t,
132-
const c_dpctl.DPCTLEventVectorRef)
133129
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_1in_1out_strides_t)(c_dpctl.DPCTLSyclQueueRef,
134130
void *, const size_t, const size_t,
135131
const shape_elem_type * , const shape_elem_type * ,

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -147,56 +147,6 @@ cdef dpnp_DPNPFuncType_to_dtype(size_t type):
147147
utils.checker_throw_type_error("dpnp_DPNPFuncType_to_dtype", type)
148148

149149

150-
cdef utils.dpnp_descriptor call_fptr_1in_1out(DPNPFuncName fptr_name,
151-
utils.dpnp_descriptor x1,
152-
shape_type_c result_shape,
153-
utils.dpnp_descriptor out=None,
154-
func_name=None):
155-
156-
""" Convert type (x1.dtype) to C enum DPNPFuncType """
157-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
158-
159-
""" get the FPTR data structure """
160-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, param1_type, param1_type)
161-
162-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
163-
164-
cdef utils.dpnp_descriptor result
165-
166-
if out is None:
167-
""" Create result array with type given by FPTR data """
168-
x1_obj = x1.get_array()
169-
result = utils.create_output_descriptor(result_shape,
170-
kernel_data.return_type,
171-
None,
172-
device=x1_obj.sycl_device,
173-
usm_type=x1_obj.usm_type,
174-
sycl_queue=x1_obj.sycl_queue)
175-
else:
176-
if out.dtype != result_type:
177-
utils.checker_throw_value_error(func_name, 'out.dtype', out.dtype, result_type)
178-
if out.shape != result_shape:
179-
utils.checker_throw_value_error(func_name, 'out.shape', out.shape, result_shape)
180-
181-
result = out
182-
183-
utils.get_common_usm_allocation(x1, result) # check USM allocation is common
184-
185-
result_sycl_queue = result.get_array().sycl_queue
186-
187-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
188-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
189-
190-
cdef fptr_1in_1out_t func = <fptr_1in_1out_t > kernel_data.ptr
191-
192-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, x1.get_data(), result.get_data(), x1.size, NULL)
193-
194-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
195-
c_dpctl.DPCTLEvent_Delete(event_ref)
196-
197-
return result
198-
199-
200150
cdef utils.dpnp_descriptor call_fptr_1in_1out_strides(DPNPFuncName fptr_name,
201151
utils.dpnp_descriptor x1,
202152
object dtype=None,

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ __all__ += [
4343
"dpnp_fmax",
4444
"dpnp_fmin",
4545
"dpnp_modf",
46-
"dpnp_nancumprod",
4746
"dpnp_trapz",
4847
]
4948

@@ -56,18 +55,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*ftpr_custom_trapz_2in_1out_with_2size_t)(c_d
5655
const c_dpctl.DPCTLEventVectorRef)
5756

5857

59-
cpdef utils.dpnp_descriptor dpnp_cumprod(utils.dpnp_descriptor x1):
60-
# instead of x1.shape, (x1.size, ) is passed to the function
61-
# due to the following:
62-
# >>> import numpy
63-
# >>> a = numpy.array([[1, 2], [2, 3]])
64-
# >>> res = numpy.cumprod(a)
65-
# >>> res.shape
66-
# (4,)
67-
68-
return call_fptr_1in_1out(DPNP_FN_CUMPROD_EXT, x1, (x1.size,))
69-
70-
7158
cpdef utils.dpnp_descriptor dpnp_ediff1d(utils.dpnp_descriptor x1):
7259

7360
if x1.size <= 1:
@@ -226,19 +213,6 @@ cpdef tuple dpnp_modf(utils.dpnp_descriptor x1):
226213
return (result1.get_pyobj(), result2.get_pyobj())
227214

228215

229-
cpdef utils.dpnp_descriptor dpnp_nancumprod(utils.dpnp_descriptor x1):
230-
cur_x1 = x1.get_pyobj().copy()
231-
232-
cur_x1_flatiter = cur_x1.flat
233-
234-
for i in range(cur_x1.size):
235-
if dpnp.isnan(cur_x1_flatiter[i]):
236-
cur_x1_flatiter[i] = 1
237-
238-
x1_desc = dpnp.get_dpnp_descriptor(cur_x1, copy_when_nondefault_queue=False)
239-
return dpnp_cumprod(x1_desc)
240-
241-
242216
cpdef utils.dpnp_descriptor dpnp_trapz(utils.dpnp_descriptor y1, utils.dpnp_descriptor x1, double dx):
243217

244218
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(y1.dtype)

dpnp/dpnp_iface_nanfunctions.py

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,8 @@
3939

4040
import warnings
4141

42-
import numpy
43-
4442
import dpnp
4543

46-
# pylint: disable=no-name-in-module
47-
from .dpnp_algo import (
48-
dpnp_nancumprod,
49-
)
50-
from .dpnp_utils import (
51-
call_origin,
52-
)
53-
5444
__all__ = [
5545
"nanargmax",
5646
"nanargmin",
@@ -249,19 +239,40 @@ def nanargmin(a, axis=None, out=None, *, keepdims=False):
249239
return dpnp.argmin(a, axis=axis, out=out, keepdims=keepdims)
250240

251241

252-
def nancumprod(x1, **kwargs):
242+
def nancumprod(a, axis=None, dtype=None, out=None):
253243
"""
254244
Return the cumulative product of array elements over a given axis treating
255-
Not a Numbers (NaNs) as one.
245+
Not a Numbers (NaNs) as zero. The cumulative product does not change when
246+
NaNs are encountered and leading NaNs are replaced by ones.
256247
257248
For full documentation refer to :obj:`numpy.nancumprod`.
258249
259-
Limitations
260-
-----------
261-
Parameter `x` is supported as :class:`dpnp.ndarray`.
262-
Keyword argument `kwargs` is currently unsupported.
263-
Otherwise the function will be executed sequentially on CPU.
264-
Input array data types are limited by supported DPNP :ref:`Data types`.
250+
Parameters
251+
----------
252+
a : {dpnp.ndarray, usm_ndarray}
253+
Input array.
254+
axis : {None, int}, optional
255+
Axis along which the cumulative product is computed. The default
256+
(``None``) is to compute the cumulative product over the flattened
257+
array.
258+
dtype : {None, dtype}, optional
259+
Type of the returned array and of the accumulator in which the elements
260+
are summed. If `dtype` is not specified, it defaults to the dtype of
261+
`a`, unless `a` has an integer dtype with a precision less than that of
262+
the default platform integer. In that case, the default platform
263+
integer is used.
264+
out : {None, dpnp.ndarray, usm_ndarray}, optional
265+
Alternative output array in which to place the result. It must have the
266+
same shape and buffer length as the expected output but the type will
267+
be cast if necessary.
268+
269+
Returns
270+
-------
271+
out : dpnp.ndarray
272+
A new array holding the result is returned unless `out` is specified as
273+
:class:`dpnp.ndarray`, in which case a reference to `out` is returned.
274+
The result has the same size as `a`, and the same shape as `a` if `axis`
275+
is not ``None`` or `a` is a 1-d array.
265276
266277
See Also
267278
--------
@@ -271,22 +282,26 @@ def nancumprod(x1, **kwargs):
271282
Examples
272283
--------
273284
>>> import dpnp as np
274-
>>> a = np.array([1., np.nan])
275-
>>> result = np.nancumprod(a)
276-
>>> [x for x in result]
277-
[1.0, 1.0]
278-
>>> b = np.array([[1., 2., np.nan], [4., np.nan, 6.]])
279-
>>> result = np.nancumprod(b)
280-
>>> [x for x in result]
281-
[1.0, 2.0, 2.0, 8.0, 8.0, 48.0]
285+
>>> np.nancumprod(np.array(1))
286+
array(1)
287+
>>> np.nancumprod(np.array([1]))
288+
array([1])
289+
>>> np.nancumprod(np.array([1, np.nan]))
290+
array([1., 1.])
291+
>>> a = np.array([[1, 2], [3, np.nan]])
292+
>>> np.nancumprod(a)
293+
array([1., 2., 6., 6.])
294+
>>> np.nancumprod(a, axis=0)
295+
array([[1., 2.],
296+
[3., 2.]])
297+
>>> np.nancumprod(a, axis=1)
298+
array([[1., 2.],
299+
[3., 3.]])
282300
283301
"""
284302

285-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
286-
if x1_desc and not kwargs:
287-
return dpnp_nancumprod(x1_desc).get_pyobj()
288-
289-
return call_origin(numpy.nancumprod, x1, **kwargs)
303+
a, _ = _replace_nan(a, 1)
304+
return dpnp.cumprod(a, axis=axis, dtype=dtype, out=out)
290305

291306

292307
def nancumsum(a, axis=None, dtype=None, out=None):
@@ -332,7 +347,7 @@ def nancumsum(a, axis=None, dtype=None, out=None):
332347
--------
333348
>>> import dpnp as np
334349
>>> np.nancumsum(np.array(1))
335-
array([1])
350+
array(1)
336351
>>> np.nancumsum(np.array([1]))
337352
array([1])
338353
>>> np.nancumsum(np.array([1, np.nan]))

tests/test_nanfunctions.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,6 @@
2525
)
2626
)
2727
class TestNanCumSumProd:
28-
@pytest.fixture(autouse=True)
29-
def setUp(self):
30-
if self.func == "nancumprod":
31-
pytest.skip("nancumprod() is not implemented yet")
32-
pass
33-
3428
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
3529
@pytest.mark.parametrize(
3630
"array",

tests/test_usm_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ def test_norm(usm_type, ord, axis):
542542
pytest.param("min", [1.0, 2.0, 4.0, 7.0]),
543543
pytest.param("nanargmax", [1.0, 2.0, 4.0, dp.nan]),
544544
pytest.param("nanargmin", [1.0, 2.0, 4.0, dp.nan]),
545+
pytest.param("nancumprod", [3.0, dp.nan]),
545546
pytest.param("nancumsum", [3.0, dp.nan]),
546547
pytest.param("nanmax", [1.0, 2.0, 4.0, dp.nan]),
547548
pytest.param("nanmean", [1.0, 2.0, 4.0, dp.nan]),

tests/third_party/cupy/math_tests/test_sumprod.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -235,17 +235,6 @@ def _numpy_nanprod_implemented(self):
235235
)
236236

237237
def _test(self, xp, dtype):
238-
if (
239-
self.func == "nanprod"
240-
and self.shape == (20, 30, 40)
241-
and has_support_aspect64()
242-
):
243-
# If input type is float, NumPy returns the same data type but
244-
# dpctl (and dpnp) returns default platform float following array api.
245-
# When input is `float32` and output is a very large number, dpnp returns
246-
# the number because it is `float64` but NumPy returns `inf` since it is `float32`.
247-
pytest.skip("Output is a very large number.")
248-
249238
a = testing.shaped_arange(self.shape, xp, dtype)
250239
if self.transpose_axes:
251240
a = a.transpose(2, 0, 1)
@@ -265,9 +254,7 @@ def test_nansum_all(self, xp, dtype):
265254
return self._test(xp, dtype)
266255

267256
@testing.for_all_dtypes(no_bool=True, no_float16=True)
268-
@testing.numpy_cupy_allclose(
269-
contiguous_check=False, type_check=has_support_aspect64()
270-
)
257+
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
271258
def test_nansum_axis_transposed(self, xp, dtype):
272259
if (
273260
not self._numpy_nanprod_implemented()
@@ -579,6 +566,7 @@ def test_cumproduct_alias(self, xp):
579566
return xp.cumproduct(a)
580567

581568

569+
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
582570
@testing.parameterize(
583571
*testing.product(
584572
{
@@ -591,12 +579,6 @@ def test_cumproduct_alias(self, xp):
591579
class TestNanCumSumProd:
592580
zero_density = 0.25
593581

594-
@pytest.fixture(autouse=True)
595-
def setUp(self):
596-
if self.func == "nancumprod":
597-
pytest.skip("nancumprod() is not implemented yet")
598-
pass
599-
600582
def _make_array(self, dtype):
601583
dtype = numpy.dtype(dtype)
602584
if dtype.char in "efdFD":

0 commit comments

Comments
 (0)