Skip to content

Commit 17f9977

Browse files
authored
Merge branch 'master' into fix-gh-1352
2 parents bc628d0 + c389b9d commit 17f9977

File tree

6 files changed

+126
-12
lines changed

6 files changed

+126
-12
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ cpdef dpnp_descriptor dpnp_radians(dpnp_descriptor array1)
603603
cpdef dpnp_descriptor dpnp_recip(dpnp_descriptor array1)
604604
cpdef dpnp_descriptor dpnp_sin(dpnp_descriptor array1, dpnp_descriptor out)
605605
cpdef dpnp_descriptor dpnp_sinh(dpnp_descriptor array1)
606-
cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1)
606+
cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1, dpnp_descriptor out)
607607
cpdef dpnp_descriptor dpnp_square(dpnp_descriptor array1)
608608
cpdef dpnp_descriptor dpnp_tan(dpnp_descriptor array1, dpnp_descriptor out)
609609
cpdef dpnp_descriptor dpnp_tanh(dpnp_descriptor array1)

dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ cpdef utils.dpnp_descriptor dpnp_sinh(utils.dpnp_descriptor x1):
148148
return call_fptr_1in_1out_strides(DPNP_FN_SINH_EXT, x1)
149149

150150

151-
cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1):
152-
return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1)
151+
cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1, utils.dpnp_descriptor out):
152+
return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1, dtype=None, out=out, where=True, func_name='sqrt')
153153

154154

155155
cpdef utils.dpnp_descriptor dpnp_square(utils.dpnp_descriptor x1):

dpnp/dpnp_iface_trigonometric.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242

4343
import numpy
44+
import dpctl.tensor as dpt
4445

4546
from dpnp.dpnp_algo import *
4647
from dpnp.dpnp_utils import *
@@ -906,16 +907,19 @@ def sinh(x1):
906907
return call_origin(numpy.sinh, x1, **kwargs)
907908

908909

909-
def sqrt(x1):
910+
def sqrt(x1, /, out = None, **kwargs):
910911
"""
911912
Return the positive square-root of an array, element-wise.
912913
913914
For full documentation refer to :obj:`numpy.sqrt`.
914915
915916
Limitations
916917
-----------
917-
Input array is supported as :obj:`dpnp.ndarray`.
918+
Input array is supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
919+
Parameter `out` is supported as class:`dpnp.ndarray`, class:`dpctl.tensor.usm_ndarray` or
920+
with default value ``None``.
918921
Otherwise the function will be executed sequentially on CPU.
922+
Keyword arguments ``kwargs`` are currently unsupported.
919923
Input array data types are limited by supported DPNP :ref:`Data types`.
920924
921925
Examples
@@ -928,11 +932,23 @@ def sqrt(x1):
928932
929933
"""
930934

931-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
935+
x1_desc = (
936+
dpnp.get_dpnp_descriptor(
937+
x1, copy_when_strides=False, copy_when_nondefault_queue=False
938+
)
939+
if not kwargs
940+
else None
941+
)
932942
if x1_desc:
933-
return dpnp_sqrt(x1_desc).get_pyobj()
943+
if out is not None:
944+
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
945+
raise TypeError("return array must be of supported array type")
946+
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) or None
947+
else:
948+
out_desc = None
949+
return dpnp_sqrt(x1_desc, out=out_desc).get_pyobj()
934950

935-
return call_origin(numpy.sqrt, x1)
951+
return call_origin(numpy.sqrt, x1, out=out, **kwargs)
936952

937953

938954
def square(x1):

tests/test_sycl_queue.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy
1010

1111
from numpy.testing import (
12+
assert_allclose,
1213
assert_array_equal,
1314
assert_raises
1415
)
@@ -218,7 +219,7 @@ def test_array_creation_cross_device(func, args, kwargs, device_x, device_y):
218219

219220
dpnp_kwargs = dict(kwargs)
220221
dpnp_kwargs['device'] = device_y
221-
222+
222223
y = getattr(dpnp, func)(*dpnp_args, **dpnp_kwargs)
223224
numpy.testing.assert_allclose(y_orig, y)
224225

@@ -279,6 +280,8 @@ def test_meshgrid(device_x, device_y):
279280
[1., 2.]),
280281
pytest.param("sign",
281282
[-5., 4.5]),
283+
pytest.param("sqrt",
284+
[1., 3., 9.]),
282285
pytest.param("sum",
283286
[1., 2.]),
284287
pytest.param("trapz",
@@ -297,7 +300,7 @@ def test_1in_1out(func, data, device):
297300
x = dpnp.array(data, device=device)
298301
result = getattr(dpnp, func)(x)
299302

300-
assert_array_equal(result, expected)
303+
assert_allclose(result, expected)
301304

302305
expected_queue = x.get_array().sycl_queue
303306
result_queue = result.get_array().sycl_queue
@@ -529,6 +532,33 @@ def test_random_state(func, args, kwargs, device, usm_type):
529532
assert_sycl_queue_equal(res_array.sycl_queue, sycl_queue)
530533

531534

535+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
536+
@pytest.mark.parametrize(
537+
"func,data",
538+
[
539+
pytest.param("sqrt",
540+
[0., 1., 2., 3., 4., 5., 6., 7., 8.]),
541+
],
542+
)
543+
@pytest.mark.parametrize("device",
544+
valid_devices,
545+
ids=[device.filter_string for device in valid_devices])
546+
def test_out_1in_1out(func, data, device):
547+
x_orig = numpy.array(data)
548+
np_out = getattr(numpy, func)(x_orig)
549+
expected = numpy.empty_like(np_out)
550+
getattr(numpy, func)(x_orig, out=expected)
551+
552+
x = dpnp.array(data, device=device)
553+
dp_out = getattr(dpnp, func)(x)
554+
result = dpnp.empty_like(dp_out)
555+
getattr(dpnp, func)(x, out=result)
556+
557+
assert_allclose(result, expected)
558+
559+
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
560+
561+
532562
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
533563
@pytest.mark.parametrize(
534564
"func,data1,data2",
@@ -574,7 +604,7 @@ def test_random_state(func, args, kwargs, device, usm_type):
574604
@pytest.mark.parametrize("device",
575605
valid_devices,
576606
ids=[device.filter_string for device in valid_devices])
577-
def test_out(func, data1, data2, device):
607+
def test_out_2in_1out(func, data1, data2, device):
578608
x1_orig = numpy.array(data1)
579609
x2_orig = numpy.array(data2)
580610
np_out = getattr(numpy, func)(x1_orig, x2_orig)

tests/test_umath.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from .helper import (
3-
get_all_dtypes
3+
get_all_dtypes,
4+
get_float_dtypes
45
)
56

67
import numpy
@@ -402,3 +403,53 @@ def test_invalid_shape(self, shape):
402403

403404
with pytest.raises(ValueError):
404405
dpnp.arctan2(dp_array, dp_array, out=dp_out)
406+
407+
408+
class TestSqrt:
409+
@pytest.mark.parametrize("dtype", get_float_dtypes())
410+
def test_sqrt_ordinary(self, dtype):
411+
array_data = numpy.arange(10)
412+
out = numpy.empty(10, dtype=dtype)
413+
414+
# DPNP
415+
dp_array = dpnp.array(array_data, dtype=dtype)
416+
dp_out = dpnp.array(out, dtype=dtype)
417+
result = dpnp.sqrt(dp_array, out=dp_out)
418+
419+
# original
420+
np_array = numpy.array(array_data, dtype=dtype)
421+
expected = numpy.sqrt(np_array, out=out)
422+
423+
numpy.testing.assert_allclose(expected, result)
424+
numpy.testing.assert_allclose(out, dp_out)
425+
426+
@pytest.mark.parametrize("dtype",
427+
[numpy.int64, numpy.int32],
428+
ids=['numpy.int64', 'numpy.int32'])
429+
def test_invalid_dtype(self, dtype):
430+
431+
dp_array = dpnp.arange(10, dtype=dpnp.float32)
432+
dp_out = dpnp.empty(10, dtype=dtype)
433+
434+
with pytest.raises(ValueError):
435+
dpnp.sqrt(dp_array, out=dp_out)
436+
437+
@pytest.mark.parametrize("shape",
438+
[(0,), (15, ), (2, 2)],
439+
ids=['(0,)', '(15, )', '(2,2)'])
440+
def test_invalid_shape(self, shape):
441+
442+
dp_array = dpnp.arange(10, dtype=dpnp.float32)
443+
dp_out = dpnp.empty(shape, dtype=dpnp.float32)
444+
445+
with pytest.raises(ValueError):
446+
dpnp.sqrt(dp_array, out=dp_out)
447+
448+
@pytest.mark.parametrize("out",
449+
[4, (), [], (3, 7), [2, 4]],
450+
ids=['4', '()', '[]', '(3, 7)', '[2, 4]'])
451+
def test_invalid_out(self, out):
452+
a = dpnp.arange(10)
453+
454+
numpy.testing.assert_raises(TypeError, dpnp.sqrt, a, out)
455+
numpy.testing.assert_raises(TypeError, numpy.sqrt, a.asnumpy(), out)

tests/test_usm_type.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,23 @@ def test_meshgrid(usm_type_x, usm_type_y):
178178
assert z[1].usm_type == usm_type_y
179179

180180

181+
@pytest.mark.parametrize(
182+
"func,data",
183+
[
184+
pytest.param(
185+
"sqrt",
186+
[1.0, 3.0, 9.0],
187+
),
188+
],
189+
)
190+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
191+
def test_1in_1out(func, data, usm_type):
192+
x = dp.array(data, usm_type=usm_type)
193+
res = getattr(dp, func)(x)
194+
assert x.usm_type == usm_type
195+
assert res.usm_type == usm_type
196+
197+
181198
@pytest.mark.parametrize(
182199
"func,data1,data2",
183200
[

0 commit comments

Comments
 (0)