Skip to content

Commit

Permalink
Add support out parameter in dpnp.sqrt()
Browse files Browse the repository at this point in the history
  • Loading branch information
vlad-perevezentsev committed Mar 7, 2023
1 parent 29a2063 commit ce3cc9d
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 7 deletions.
2 changes: 1 addition & 1 deletion dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ cpdef dpnp_descriptor dpnp_radians(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_recip(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_sin(dpnp_descriptor array1, dpnp_descriptor out)
cpdef dpnp_descriptor dpnp_sinh(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1, dpnp_descriptor out)
cpdef dpnp_descriptor dpnp_square(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_tan(dpnp_descriptor array1, dpnp_descriptor out)
cpdef dpnp_descriptor dpnp_tanh(dpnp_descriptor array1)
4 changes: 2 additions & 2 deletions dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ cpdef utils.dpnp_descriptor dpnp_sinh(utils.dpnp_descriptor x1):
return call_fptr_1in_1out_strides(DPNP_FN_SINH_EXT, x1)


cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1):
return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1)
cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1, utils.dpnp_descriptor out):
return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1, dtype=None, out=out, where=True, func_name='sqrt')


cpdef utils.dpnp_descriptor dpnp_square(utils.dpnp_descriptor x1):
Expand Down
17 changes: 13 additions & 4 deletions dpnp/dpnp_iface_trigonometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@


import numpy
import dpctl.tensor as dpt

from dpnp.dpnp_algo import *
from dpnp.dpnp_utils import *
Expand Down Expand Up @@ -906,7 +907,7 @@ def sinh(x1):
return call_origin(numpy.sinh, x1, **kwargs)


def sqrt(x1):
def sqrt(x1, /, out = None, **kwargs):
"""
Return the positive square-root of an array, element-wise.
Expand All @@ -916,6 +917,8 @@ def sqrt(x1):
-----------
Input array is supported as :obj:`dpnp.ndarray`.
Otherwise the function will be executed sequentially on CPU.
Parameter ``out`` is supported as :obj:`dpnp.ndarray`, obj:`dpt.usm_ndarray` and as default value ``None``.
Keyword arguments ``kwargs`` are currently unsupported.
Input array data types are limited by supported DPNP :ref:`Data types`.
Examples
Expand All @@ -929,10 +932,16 @@ def sqrt(x1):
"""

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
if x1_desc:
return dpnp_sqrt(x1_desc).get_pyobj()
if x1_desc and not kwargs:
if out is not None:
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
raise TypeError("return array must be of supported array type")
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False)
else:
out_desc = None
return dpnp_sqrt(x1_desc, out = out_desc).get_pyobj()

return call_origin(numpy.sqrt, x1)
return call_origin(numpy.sqrt, x1, out=out, **kwargs)


def square(x1):
Expand Down
51 changes: 51 additions & 0 deletions tests/test_umath.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from .helper import get_all_dtypes, get_float_dtypes

import numpy
import dpnp
Expand Down Expand Up @@ -392,3 +393,53 @@ def test_invalid_shape(self, shape):

with pytest.raises(ValueError):
dpnp.arctan2(dp_array, dp_array, out=dp_out)


class TestSqrt:
@pytest.mark.parametrize("dtype", get_float_dtypes())
def test_sqrt_ordinary(self, dtype):
array_data = numpy.arange(10)
out = numpy.empty(10, dtype=dtype)

# DPNP
dp_array = dpnp.array(array_data, dtype=dtype)
dp_out = dpnp.array(out, dtype=dtype)
result = dpnp.sqrt(dp_array, out=dp_out)

# original
np_array = numpy.array(array_data, dtype=dtype)
expected = numpy.sqrt(np_array, out=out)

numpy.testing.assert_array_equal(expected, result)
numpy.testing.assert_array_equal(out, dp_out)

@pytest.mark.parametrize("dtype",
[numpy.int64, numpy.int32],
ids=['numpy.int64', 'numpy.int32'])
def test_invalid_dtype(self, dtype):

dp_array = dpnp.arange(10, dtype=dpnp.float32)
dp_out = dpnp.empty(10, dtype=dtype)

with pytest.raises(ValueError):
dpnp.sqrt(dp_array, out=dp_out)

@pytest.mark.parametrize("shape",
[(0,), (15, ), (2, 2)],
ids=['(0,)', '(15, )', '(2,2)'])
def test_invalid_shape(self, shape):

dp_array = dpnp.arange(10, dtype=dpnp.float32)
dp_out = dpnp.empty(shape, dtype=dpnp.float32)

with pytest.raises(ValueError):
dpnp.sqrt(dp_array, out=dp_out)

@pytest.mark.parametrize("out",
[4, (), [], (3, 7), [2, 4]],
ids=['4', '()', '[]', '(3, 7)', '[2, 4]'])
def test_invalid_out(self, out):
a = dpnp.arange(10)

numpy.testing.assert_raises(TypeError, dpnp.sqrt, a, out)
numpy.testing.assert_raises(TypeError, numpy.sqrt, a.asnumpy(), out)

0 comments on commit ce3cc9d

Please sign in to comment.