Skip to content

Commit

Permalink
Merge branch 'master' into fix-gh-1352
Browse files Browse the repository at this point in the history
  • Loading branch information
npolina4 authored Apr 5, 2023
2 parents bc628d0 + c389b9d commit 17f9977
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 12 deletions.
2 changes: 1 addition & 1 deletion dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ cpdef dpnp_descriptor dpnp_radians(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_recip(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_sin(dpnp_descriptor array1, dpnp_descriptor out)
cpdef dpnp_descriptor dpnp_sinh(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1, dpnp_descriptor out)
cpdef dpnp_descriptor dpnp_square(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_tan(dpnp_descriptor array1, dpnp_descriptor out)
cpdef dpnp_descriptor dpnp_tanh(dpnp_descriptor array1)
4 changes: 2 additions & 2 deletions dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ cpdef utils.dpnp_descriptor dpnp_sinh(utils.dpnp_descriptor x1):
return call_fptr_1in_1out_strides(DPNP_FN_SINH_EXT, x1)


cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1):
return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1)
cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1, utils.dpnp_descriptor out):
return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1, dtype=None, out=out, where=True, func_name='sqrt')


cpdef utils.dpnp_descriptor dpnp_square(utils.dpnp_descriptor x1):
Expand Down
26 changes: 21 additions & 5 deletions dpnp/dpnp_iface_trigonometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@


import numpy
import dpctl.tensor as dpt

from dpnp.dpnp_algo import *
from dpnp.dpnp_utils import *
Expand Down Expand Up @@ -906,16 +907,19 @@ def sinh(x1):
return call_origin(numpy.sinh, x1, **kwargs)


def sqrt(x1):
def sqrt(x1, /, out = None, **kwargs):
"""
Return the positive square-root of an array, element-wise.
For full documentation refer to :obj:`numpy.sqrt`.
Limitations
-----------
Input array is supported as :obj:`dpnp.ndarray`.
Input array is supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
Parameter `out` is supported as class:`dpnp.ndarray`, class:`dpctl.tensor.usm_ndarray` or
with default value ``None``.
Otherwise the function will be executed sequentially on CPU.
Keyword arguments ``kwargs`` are currently unsupported.
Input array data types are limited by supported DPNP :ref:`Data types`.
Examples
Expand All @@ -928,11 +932,23 @@ def sqrt(x1):
"""

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
x1_desc = (
dpnp.get_dpnp_descriptor(
x1, copy_when_strides=False, copy_when_nondefault_queue=False
)
if not kwargs
else None
)
if x1_desc:
return dpnp_sqrt(x1_desc).get_pyobj()
if out is not None:
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
raise TypeError("return array must be of supported array type")
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) or None
else:
out_desc = None
return dpnp_sqrt(x1_desc, out=out_desc).get_pyobj()

return call_origin(numpy.sqrt, x1)
return call_origin(numpy.sqrt, x1, out=out, **kwargs)


def square(x1):
Expand Down
36 changes: 33 additions & 3 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy

from numpy.testing import (
assert_allclose,
assert_array_equal,
assert_raises
)
Expand Down Expand Up @@ -218,7 +219,7 @@ def test_array_creation_cross_device(func, args, kwargs, device_x, device_y):

dpnp_kwargs = dict(kwargs)
dpnp_kwargs['device'] = device_y

y = getattr(dpnp, func)(*dpnp_args, **dpnp_kwargs)
numpy.testing.assert_allclose(y_orig, y)

Expand Down Expand Up @@ -279,6 +280,8 @@ def test_meshgrid(device_x, device_y):
[1., 2.]),
pytest.param("sign",
[-5., 4.5]),
pytest.param("sqrt",
[1., 3., 9.]),
pytest.param("sum",
[1., 2.]),
pytest.param("trapz",
Expand All @@ -297,7 +300,7 @@ def test_1in_1out(func, data, device):
x = dpnp.array(data, device=device)
result = getattr(dpnp, func)(x)

assert_array_equal(result, expected)
assert_allclose(result, expected)

expected_queue = x.get_array().sycl_queue
result_queue = result.get_array().sycl_queue
Expand Down Expand Up @@ -529,6 +532,33 @@ def test_random_state(func, args, kwargs, device, usm_type):
assert_sycl_queue_equal(res_array.sycl_queue, sycl_queue)


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@pytest.mark.parametrize(
"func,data",
[
pytest.param("sqrt",
[0., 1., 2., 3., 4., 5., 6., 7., 8.]),
],
)
@pytest.mark.parametrize("device",
valid_devices,
ids=[device.filter_string for device in valid_devices])
def test_out_1in_1out(func, data, device):
x_orig = numpy.array(data)
np_out = getattr(numpy, func)(x_orig)
expected = numpy.empty_like(np_out)
getattr(numpy, func)(x_orig, out=expected)

x = dpnp.array(data, device=device)
dp_out = getattr(dpnp, func)(x)
result = dpnp.empty_like(dp_out)
getattr(dpnp, func)(x, out=result)

assert_allclose(result, expected)

assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@pytest.mark.parametrize(
"func,data1,data2",
Expand Down Expand Up @@ -574,7 +604,7 @@ def test_random_state(func, args, kwargs, device, usm_type):
@pytest.mark.parametrize("device",
valid_devices,
ids=[device.filter_string for device in valid_devices])
def test_out(func, data1, data2, device):
def test_out_2in_1out(func, data1, data2, device):
x1_orig = numpy.array(data1)
x2_orig = numpy.array(data2)
np_out = getattr(numpy, func)(x1_orig, x2_orig)
Expand Down
53 changes: 52 additions & 1 deletion tests/test_umath.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from .helper import (
get_all_dtypes
get_all_dtypes,
get_float_dtypes
)

import numpy
Expand Down Expand Up @@ -402,3 +403,53 @@ def test_invalid_shape(self, shape):

with pytest.raises(ValueError):
dpnp.arctan2(dp_array, dp_array, out=dp_out)


class TestSqrt:
@pytest.mark.parametrize("dtype", get_float_dtypes())
def test_sqrt_ordinary(self, dtype):
array_data = numpy.arange(10)
out = numpy.empty(10, dtype=dtype)

# DPNP
dp_array = dpnp.array(array_data, dtype=dtype)
dp_out = dpnp.array(out, dtype=dtype)
result = dpnp.sqrt(dp_array, out=dp_out)

# original
np_array = numpy.array(array_data, dtype=dtype)
expected = numpy.sqrt(np_array, out=out)

numpy.testing.assert_allclose(expected, result)
numpy.testing.assert_allclose(out, dp_out)

@pytest.mark.parametrize("dtype",
[numpy.int64, numpy.int32],
ids=['numpy.int64', 'numpy.int32'])
def test_invalid_dtype(self, dtype):

dp_array = dpnp.arange(10, dtype=dpnp.float32)
dp_out = dpnp.empty(10, dtype=dtype)

with pytest.raises(ValueError):
dpnp.sqrt(dp_array, out=dp_out)

@pytest.mark.parametrize("shape",
[(0,), (15, ), (2, 2)],
ids=['(0,)', '(15, )', '(2,2)'])
def test_invalid_shape(self, shape):

dp_array = dpnp.arange(10, dtype=dpnp.float32)
dp_out = dpnp.empty(shape, dtype=dpnp.float32)

with pytest.raises(ValueError):
dpnp.sqrt(dp_array, out=dp_out)

@pytest.mark.parametrize("out",
[4, (), [], (3, 7), [2, 4]],
ids=['4', '()', '[]', '(3, 7)', '[2, 4]'])
def test_invalid_out(self, out):
a = dpnp.arange(10)

numpy.testing.assert_raises(TypeError, dpnp.sqrt, a, out)
numpy.testing.assert_raises(TypeError, numpy.sqrt, a.asnumpy(), out)
17 changes: 17 additions & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,23 @@ def test_meshgrid(usm_type_x, usm_type_y):
assert z[1].usm_type == usm_type_y


@pytest.mark.parametrize(
"func,data",
[
pytest.param(
"sqrt",
[1.0, 3.0, 9.0],
),
],
)
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_1in_1out(func, data, usm_type):
x = dp.array(data, usm_type=usm_type)
res = getattr(dp, func)(x)
assert x.usm_type == usm_type
assert res.usm_type == usm_type


@pytest.mark.parametrize(
"func,data1,data2",
[
Expand Down

0 comments on commit 17f9977

Please sign in to comment.