diff --git a/dpnp/dpnp_algo/dpnp_algo_mathematical.pyx b/dpnp/dpnp_algo/dpnp_algo_mathematical.pyx index 5d937a7008b..b0a74ca587e 100644 --- a/dpnp/dpnp_algo/dpnp_algo_mathematical.pyx +++ b/dpnp/dpnp_algo/dpnp_algo_mathematical.pyx @@ -631,6 +631,9 @@ cpdef utils.dpnp_descriptor dpnp_sum(utils.dpnp_descriptor x1, usm_type=x1_obj.usm_type, sycl_queue=x1_obj.sycl_queue) + if x1.size == 0 and axis is None: + return result + result_sycl_queue = result.get_array().sycl_queue cdef c_dpctl.SyclQueue q = result_sycl_queue diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 525180e6106..98dcc71d31a 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -1629,10 +1629,16 @@ def sum(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where if where is not True: pass else: + if dpnp.isscalar(out): + raise TypeError("output must be an array") out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) if out is not None else None result_obj = dpnp_sum(x1_desc, axis, dtype, out_desc, keepdims, initial, where).get_pyobj() result = dpnp.convert_single_elem_array_to_scalar(result_obj, keepdims) + if x1_desc.size == 0 and axis is None: + result = dpnp.zeros_like(result) + if out is not None: + out[...] = result return result return call_origin(numpy.sum, x1, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index ad16a9c7555..5f0d73b23b7 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -923,3 +923,35 @@ def test_float_to_inf(self): dpnp_res = dpnp.array(a) ** dpnp.array(b) assert_allclose(numpy_res, dpnp_res.asnumpy()) + + +@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True, no_bool=True)) +@pytest.mark.parametrize("axis", [None, 0, 1, 2, 3]) +def test_sum_empty(dtype, axis): + a = numpy.empty((1, 2, 0, 4), dtype=dtype) + numpy_res = a.sum(axis=axis) + dpnp_res = dpnp.array(a).sum(axis=axis) + assert_array_equal(numpy_res, dpnp_res.asnumpy()) + + +@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True, no_bool=True)) +def test_sum_empty_out(dtype): + a = dpnp.empty((1, 2, 0, 4), dtype=dtype) + out = dpnp.ones(()) + res = a.sum(out=out) + assert_array_equal(out.asnumpy(), res.asnumpy()) + assert_array_equal(out.asnumpy(), numpy.array(0, dtype=dtype)) + + +@pytest.mark.parametrize("shape", [(), (1, 2, 3), (1, 0, 2), (10), (3, 3, 3), (5, 5), (0, 6)]) +@pytest.mark.parametrize("dtype_in", get_all_dtypes(no_complex=True, no_bool=True)) +@pytest.mark.parametrize("dtype_out", get_all_dtypes(no_complex=True, no_bool=True)) +def test_sum(shape, dtype_in, dtype_out): + a_np = numpy.ones(shape, dtype=dtype_in) + a = dpnp.ones(shape, dtype=dtype_in) + axes = [None, 0, 1, 2] + for axis in axes: + if axis is None or axis < a.ndim: + numpy_res = a_np.sum(axis=axis, dtype=dtype_out) + dpnp_res = a.sum(axis=axis, dtype=dtype_out) + assert_array_equal(numpy_res, dpnp_res.asnumpy())