From c389b9d301dec24b18101e60f7f9142631e6e4dc Mon Sep 17 00:00:00 2001 From: vlad-perevezentsev Date: Wed, 5 Apr 2023 18:43:43 +0200 Subject: [PATCH] Support `out` parameter in dpnp.sqrt() (#1332) * Add support out parameter in dpnp.sqrt() --- dpnp/dpnp_algo/dpnp_algo.pxd | 2 +- dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx | 4 +- dpnp/dpnp_iface_trigonometric.py | 26 +++++++++-- tests/test_sycl_queue.py | 36 +++++++++++++-- tests/test_umath.py | 53 +++++++++++++++++++++- tests/test_usm_type.py | 17 +++++++ 6 files changed, 126 insertions(+), 12 deletions(-) diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index c2bb15102cf..09af5667f8c 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -603,7 +603,7 @@ cpdef dpnp_descriptor dpnp_radians(dpnp_descriptor array1) cpdef dpnp_descriptor dpnp_recip(dpnp_descriptor array1) cpdef dpnp_descriptor dpnp_sin(dpnp_descriptor array1, dpnp_descriptor out) cpdef dpnp_descriptor dpnp_sinh(dpnp_descriptor array1) -cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1) +cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1, dpnp_descriptor out) cpdef dpnp_descriptor dpnp_square(dpnp_descriptor array1) cpdef dpnp_descriptor dpnp_tan(dpnp_descriptor array1, dpnp_descriptor out) cpdef dpnp_descriptor dpnp_tanh(dpnp_descriptor array1) diff --git a/dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx b/dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx index bf9c4d5e0ed..81c6f3cfc0d 100644 --- a/dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx +++ b/dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx @@ -148,8 +148,8 @@ cpdef utils.dpnp_descriptor dpnp_sinh(utils.dpnp_descriptor x1): return call_fptr_1in_1out_strides(DPNP_FN_SINH_EXT, x1) -cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1): - return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1) +cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1, utils.dpnp_descriptor out): + return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1, dtype=None, out=out, where=True, func_name='sqrt') cpdef utils.dpnp_descriptor dpnp_square(utils.dpnp_descriptor x1): diff --git a/dpnp/dpnp_iface_trigonometric.py b/dpnp/dpnp_iface_trigonometric.py index 098dd19648f..47340107164 100644 --- a/dpnp/dpnp_iface_trigonometric.py +++ b/dpnp/dpnp_iface_trigonometric.py @@ -41,6 +41,7 @@ import numpy +import dpctl.tensor as dpt from dpnp.dpnp_algo import * from dpnp.dpnp_utils import * @@ -906,7 +907,7 @@ def sinh(x1): return call_origin(numpy.sinh, x1, **kwargs) -def sqrt(x1): +def sqrt(x1, /, out = None, **kwargs): """ Return the positive square-root of an array, element-wise. @@ -914,8 +915,11 @@ def sqrt(x1): Limitations ----------- - Input array is supported as :obj:`dpnp.ndarray`. + Input array is supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`. + Parameter `out` is supported as class:`dpnp.ndarray`, class:`dpctl.tensor.usm_ndarray` or + with default value ``None``. Otherwise the function will be executed sequentially on CPU. + Keyword arguments ``kwargs`` are currently unsupported. Input array data types are limited by supported DPNP :ref:`Data types`. Examples @@ -928,11 +932,23 @@ def sqrt(x1): """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False) + x1_desc = ( + dpnp.get_dpnp_descriptor( + x1, copy_when_strides=False, copy_when_nondefault_queue=False + ) + if not kwargs + else None + ) if x1_desc: - return dpnp_sqrt(x1_desc).get_pyobj() + if out is not None: + if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)): + raise TypeError("return array must be of supported array type") + out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) or None + else: + out_desc = None + return dpnp_sqrt(x1_desc, out=out_desc).get_pyobj() - return call_origin(numpy.sqrt, x1) + return call_origin(numpy.sqrt, x1, out=out, **kwargs) def square(x1): diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index ab974e426f9..fcea0d82eb8 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -9,6 +9,7 @@ import numpy from numpy.testing import ( + assert_allclose, assert_array_equal, assert_raises ) @@ -218,7 +219,7 @@ def test_array_creation_cross_device(func, args, kwargs, device_x, device_y): dpnp_kwargs = dict(kwargs) dpnp_kwargs['device'] = device_y - + y = getattr(dpnp, func)(*dpnp_args, **dpnp_kwargs) numpy.testing.assert_allclose(y_orig, y) @@ -279,6 +280,8 @@ def test_meshgrid(device_x, device_y): [1., 2.]), pytest.param("sign", [-5., 4.5]), + pytest.param("sqrt", + [1., 3., 9.]), pytest.param("sum", [1., 2.]), pytest.param("trapz", @@ -297,7 +300,7 @@ def test_1in_1out(func, data, device): x = dpnp.array(data, device=device) result = getattr(dpnp, func)(x) - assert_array_equal(result, expected) + assert_allclose(result, expected) expected_queue = x.get_array().sycl_queue result_queue = result.get_array().sycl_queue @@ -529,6 +532,33 @@ def test_random_state(func, args, kwargs, device, usm_type): assert_sycl_queue_equal(res_array.sycl_queue, sycl_queue) +@pytest.mark.usefixtures("allow_fall_back_on_numpy") +@pytest.mark.parametrize( + "func,data", + [ + pytest.param("sqrt", + [0., 1., 2., 3., 4., 5., 6., 7., 8.]), + ], +) +@pytest.mark.parametrize("device", + valid_devices, + ids=[device.filter_string for device in valid_devices]) +def test_out_1in_1out(func, data, device): + x_orig = numpy.array(data) + np_out = getattr(numpy, func)(x_orig) + expected = numpy.empty_like(np_out) + getattr(numpy, func)(x_orig, out=expected) + + x = dpnp.array(data, device=device) + dp_out = getattr(dpnp, func)(x) + result = dpnp.empty_like(dp_out) + getattr(dpnp, func)(x, out=result) + + assert_allclose(result, expected) + + assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue) + + @pytest.mark.usefixtures("allow_fall_back_on_numpy") @pytest.mark.parametrize( "func,data1,data2", @@ -574,7 +604,7 @@ def test_random_state(func, args, kwargs, device, usm_type): @pytest.mark.parametrize("device", valid_devices, ids=[device.filter_string for device in valid_devices]) -def test_out(func, data1, data2, device): +def test_out_2in_1out(func, data1, data2, device): x1_orig = numpy.array(data1) x2_orig = numpy.array(data2) np_out = getattr(numpy, func)(x1_orig, x2_orig) diff --git a/tests/test_umath.py b/tests/test_umath.py index 3a1f4467dce..7b5c4b762d8 100644 --- a/tests/test_umath.py +++ b/tests/test_umath.py @@ -1,6 +1,7 @@ import pytest from .helper import ( - get_all_dtypes + get_all_dtypes, + get_float_dtypes ) import numpy @@ -402,3 +403,53 @@ def test_invalid_shape(self, shape): with pytest.raises(ValueError): dpnp.arctan2(dp_array, dp_array, out=dp_out) + + +class TestSqrt: + @pytest.mark.parametrize("dtype", get_float_dtypes()) + def test_sqrt_ordinary(self, dtype): + array_data = numpy.arange(10) + out = numpy.empty(10, dtype=dtype) + + # DPNP + dp_array = dpnp.array(array_data, dtype=dtype) + dp_out = dpnp.array(out, dtype=dtype) + result = dpnp.sqrt(dp_array, out=dp_out) + + # original + np_array = numpy.array(array_data, dtype=dtype) + expected = numpy.sqrt(np_array, out=out) + + numpy.testing.assert_allclose(expected, result) + numpy.testing.assert_allclose(out, dp_out) + + @pytest.mark.parametrize("dtype", + [numpy.int64, numpy.int32], + ids=['numpy.int64', 'numpy.int32']) + def test_invalid_dtype(self, dtype): + + dp_array = dpnp.arange(10, dtype=dpnp.float32) + dp_out = dpnp.empty(10, dtype=dtype) + + with pytest.raises(ValueError): + dpnp.sqrt(dp_array, out=dp_out) + + @pytest.mark.parametrize("shape", + [(0,), (15, ), (2, 2)], + ids=['(0,)', '(15, )', '(2,2)']) + def test_invalid_shape(self, shape): + + dp_array = dpnp.arange(10, dtype=dpnp.float32) + dp_out = dpnp.empty(shape, dtype=dpnp.float32) + + with pytest.raises(ValueError): + dpnp.sqrt(dp_array, out=dp_out) + + @pytest.mark.parametrize("out", + [4, (), [], (3, 7), [2, 4]], + ids=['4', '()', '[]', '(3, 7)', '[2, 4]']) + def test_invalid_out(self, out): + a = dpnp.arange(10) + + numpy.testing.assert_raises(TypeError, dpnp.sqrt, a, out) + numpy.testing.assert_raises(TypeError, numpy.sqrt, a.asnumpy(), out) diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 9c48a20fa26..df8575197b3 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -178,6 +178,23 @@ def test_meshgrid(usm_type_x, usm_type_y): assert z[1].usm_type == usm_type_y +@pytest.mark.parametrize( + "func,data", + [ + pytest.param( + "sqrt", + [1.0, 3.0, 9.0], + ), + ], +) +@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) +def test_1in_1out(func, data, usm_type): + x = dp.array(data, usm_type=usm_type) + res = getattr(dp, func)(x) + assert x.usm_type == usm_type + assert res.usm_type == usm_type + + @pytest.mark.parametrize( "func,data1,data2", [