From ce3cc9d4008dfa6becb1354917d36c744f8b32fc Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 6 Mar 2023 19:03:18 +0100 Subject: [PATCH] Add support out parameter in dpnp.sqrt() --- dpnp/dpnp_algo/dpnp_algo.pxd | 2 +- dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx | 4 +- dpnp/dpnp_iface_trigonometric.py | 17 ++++++-- tests/test_umath.py | 51 ++++++++++++++++++++++ 4 files changed, 67 insertions(+), 7 deletions(-) diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index da1efddd3cc..cd8ebea666b 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -604,7 +604,7 @@ cpdef dpnp_descriptor dpnp_radians(dpnp_descriptor array1) cpdef dpnp_descriptor dpnp_recip(dpnp_descriptor array1) cpdef dpnp_descriptor dpnp_sin(dpnp_descriptor array1, dpnp_descriptor out) cpdef dpnp_descriptor dpnp_sinh(dpnp_descriptor array1) -cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1) +cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1, dpnp_descriptor out) cpdef dpnp_descriptor dpnp_square(dpnp_descriptor array1) cpdef dpnp_descriptor dpnp_tan(dpnp_descriptor array1, dpnp_descriptor out) cpdef dpnp_descriptor dpnp_tanh(dpnp_descriptor array1) diff --git a/dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx b/dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx index bf9c4d5e0ed..81c6f3cfc0d 100644 --- a/dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx +++ b/dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx @@ -148,8 +148,8 @@ cpdef utils.dpnp_descriptor dpnp_sinh(utils.dpnp_descriptor x1): return call_fptr_1in_1out_strides(DPNP_FN_SINH_EXT, x1) -cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1): - return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1) +cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1, utils.dpnp_descriptor out): + return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1, dtype=None, out=out, where=True, func_name='sqrt') cpdef utils.dpnp_descriptor dpnp_square(utils.dpnp_descriptor x1): diff --git a/dpnp/dpnp_iface_trigonometric.py b/dpnp/dpnp_iface_trigonometric.py index 098dd19648f..b680c6f6467 100644 --- a/dpnp/dpnp_iface_trigonometric.py +++ b/dpnp/dpnp_iface_trigonometric.py @@ -41,6 +41,7 @@ import numpy +import dpctl.tensor as dpt from dpnp.dpnp_algo import * from dpnp.dpnp_utils import * @@ -906,7 +907,7 @@ def sinh(x1): return call_origin(numpy.sinh, x1, **kwargs) -def sqrt(x1): +def sqrt(x1, /, out = None, **kwargs): """ Return the positive square-root of an array, element-wise. @@ -916,6 +917,8 @@ def sqrt(x1): ----------- Input array is supported as :obj:`dpnp.ndarray`. Otherwise the function will be executed sequentially on CPU. + Parameter ``out`` is supported as :obj:`dpnp.ndarray`, obj:`dpt.usm_ndarray` and as default value ``None``. + Keyword arguments ``kwargs`` are currently unsupported. Input array data types are limited by supported DPNP :ref:`Data types`. Examples @@ -929,10 +932,16 @@ def sqrt(x1): """ x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False) - if x1_desc: - return dpnp_sqrt(x1_desc).get_pyobj() + if x1_desc and not kwargs: + if out is not None: + if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)): + raise TypeError("return array must be of supported array type") + out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) + else: + out_desc = None + return dpnp_sqrt(x1_desc, out = out_desc).get_pyobj() - return call_origin(numpy.sqrt, x1) + return call_origin(numpy.sqrt, x1, out=out, **kwargs) def square(x1): diff --git a/tests/test_umath.py b/tests/test_umath.py index 6122b253ca3..58231fe1cb8 100644 --- a/tests/test_umath.py +++ b/tests/test_umath.py @@ -1,4 +1,5 @@ import pytest +from .helper import get_all_dtypes, get_float_dtypes import numpy import dpnp @@ -392,3 +393,53 @@ def test_invalid_shape(self, shape): with pytest.raises(ValueError): dpnp.arctan2(dp_array, dp_array, out=dp_out) + + +class TestSqrt: + @pytest.mark.parametrize("dtype", get_float_dtypes()) + def test_sqrt_ordinary(self, dtype): + array_data = numpy.arange(10) + out = numpy.empty(10, dtype=dtype) + + # DPNP + dp_array = dpnp.array(array_data, dtype=dtype) + dp_out = dpnp.array(out, dtype=dtype) + result = dpnp.sqrt(dp_array, out=dp_out) + + # original + np_array = numpy.array(array_data, dtype=dtype) + expected = numpy.sqrt(np_array, out=out) + + numpy.testing.assert_array_equal(expected, result) + numpy.testing.assert_array_equal(out, dp_out) + + @pytest.mark.parametrize("dtype", + [numpy.int64, numpy.int32], + ids=['numpy.int64', 'numpy.int32']) + def test_invalid_dtype(self, dtype): + + dp_array = dpnp.arange(10, dtype=dpnp.float32) + dp_out = dpnp.empty(10, dtype=dtype) + + with pytest.raises(ValueError): + dpnp.sqrt(dp_array, out=dp_out) + + @pytest.mark.parametrize("shape", + [(0,), (15, ), (2, 2)], + ids=['(0,)', '(15, )', '(2,2)']) + def test_invalid_shape(self, shape): + + dp_array = dpnp.arange(10, dtype=dpnp.float32) + dp_out = dpnp.empty(shape, dtype=dpnp.float32) + + with pytest.raises(ValueError): + dpnp.sqrt(dp_array, out=dp_out) + + @pytest.mark.parametrize("out", + [4, (), [], (3, 7), [2, 4]], + ids=['4', '()', '[]', '(3, 7)', '[2, 4]']) + def test_invalid_out(self, out): + a = dpnp.arange(10) + + numpy.testing.assert_raises(TypeError, dpnp.sqrt, a, out) + numpy.testing.assert_raises(TypeError, numpy.sqrt, a.asnumpy(), out)