Skip to content

Commit d41ed51

Browse files
antonwolfyvlad-perevezentsevnpolina4
authored
Support parameter out in dpnp.add() (#1329)
* Suppport parameter out in dpnp.add() * Update tests/test_mathematical.py Co-authored-by: vlad-perevezentsev <vladislav.perevezentsev@intel.com> * Update tests/test_mathematical.py Co-authored-by: vlad-perevezentsev <vladislav.perevezentsev@intel.com> * Update tests/test_mathematical.py * Update tests/test_mathematical.py Co-authored-by: Natalia Polina <natalia.polina@intel.com> * Use internal _check_nd_call() function which is common for mathematical ones with 2 input arrays * Add more test for 'out' parameter --------- Co-authored-by: vlad-perevezentsev <vladislav.perevezentsev@intel.com> Co-authored-by: Natalia Polina <natalia.polina@intel.com>
1 parent ee8f15a commit d41ed51

File tree

5 files changed

+107
-39
lines changed

5 files changed

+107
-39
lines changed

dpnp/dpnp_array.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,10 @@ def __gt__(self, other):
194194
return dpnp.greater(self, other)
195195

196196
# '__hash__',
197-
# '__iadd__',
197+
198+
def __iadd__(self, other):
199+
dpnp.add(self, other, out=self)
200+
return self
198201

199202
def __iand__(self, other):
200203
dpnp.bitwise_and(self, other, out=self)

dpnp/dpnp_iface_mathematical.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def add(x1,
235235
-----------
236236
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
237237
or :class:`dpctl.tensor.usm_ndarray`, but both `x1` and `x2` can not be scalars at the same time.
238-
Parameters `out`, `where`, `dtype` and `subok` are supported with their default values.
238+
Parameters `where`, `dtype` and `subok` are supported with their default values.
239239
Keyword arguments ``kwargs`` are currently unsupported.
240240
Otherwise the function will be executed sequentially on CPU.
241241
Input array data types are limited by supported DPNP :ref:`Data types`.
@@ -251,29 +251,7 @@ def add(x1,
251251
252252
"""
253253

254-
if out is not None:
255-
pass
256-
elif where is not True:
257-
pass
258-
elif dtype is not None:
259-
pass
260-
elif subok is not True:
261-
pass
262-
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
263-
# at least either x1 or x2 has to be an array
264-
pass
265-
else:
266-
# get USM type and queue to copy scalar from the host memory into a USM allocation
267-
usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)
268-
269-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
270-
alloc_usm_type=usm_type, alloc_queue=queue)
271-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
272-
alloc_usm_type=usm_type, alloc_queue=queue)
273-
if x1_desc and x2_desc:
274-
return dpnp_add(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()
275-
276-
return call_origin(numpy.add, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
254+
return _check_nd_call(numpy.add, dpnp_add, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
277255

278256

279257
def around(x1, decimals=0, out=None):

tests/test_mathematical.py

Lines changed: 94 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from .helper import (
33
get_all_dtypes,
4+
get_float_complex_dtypes,
45
is_cpu_device,
56
is_win_platform
67
)
@@ -634,24 +635,108 @@ def test_invalid_shape(self, shape):
634635
dpnp.trunc(dp_array, out=dp_out)
635636

636637

638+
class TestAdd:
639+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
640+
def test_add(self, dtype):
641+
array1_data = numpy.arange(10)
642+
array2_data = numpy.arange(5, 15)
643+
out = numpy.empty(10, dtype=dtype)
644+
645+
# DPNP
646+
dp_array1 = dpnp.array(array1_data, dtype=dtype)
647+
dp_array2 = dpnp.array(array2_data, dtype=dtype)
648+
dp_out = dpnp.array(out, dtype=dtype)
649+
result = dpnp.add(dp_array1, dp_array2, out=dp_out)
650+
651+
# original
652+
np_array1 = numpy.array(array1_data, dtype=dtype)
653+
np_array2 = numpy.array(array2_data, dtype=dtype)
654+
expected = numpy.add(np_array1, np_array2, out=out)
655+
656+
assert_allclose(expected, result)
657+
assert_allclose(out, dp_out)
658+
659+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
660+
def test_out_dtypes(self, dtype):
661+
size = 2 if dtype == dpnp.bool else 10
662+
663+
np_array1 = numpy.arange(size, 2 * size, dtype=dtype)
664+
np_array2 = numpy.arange(size, dtype=dtype)
665+
np_out = numpy.empty(size, dtype=numpy.complex64)
666+
expected = numpy.add(np_array1, np_array2, out=np_out)
667+
668+
dp_array1 = dpnp.arange(size, 2 * size, dtype=dtype)
669+
dp_array2 = dpnp.arange(size, dtype=dtype)
670+
dp_out = dpnp.empty(size, dtype=dpnp.complex64)
671+
result = dpnp.add(dp_array1, dp_array2, out=dp_out)
672+
673+
assert_array_equal(expected, result)
674+
675+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
676+
def test_out_overlap(self, dtype):
677+
size = 1 if dtype == dpnp.bool else 15
678+
679+
np_a = numpy.arange(2 * size, dtype=dtype)
680+
expected = numpy.add(np_a[size::], np_a[::2], out=np_a[:size:])
681+
682+
dp_a = dpnp.arange(2 * size, dtype=dtype)
683+
result = dpnp.add(dp_a[size::], dp_a[::2], out=dp_a[:size:])
684+
685+
assert_allclose(expected, result)
686+
assert_allclose(dp_a, np_a)
687+
688+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_none=True))
689+
def test_inplace_strided_out(self, dtype):
690+
size = 21
691+
692+
np_a = numpy.arange(size, dtype=dtype)
693+
np_a[::3] += 4
694+
695+
dp_a = dpnp.arange(size, dtype=dtype)
696+
dp_a[::3] += 4
697+
698+
assert_allclose(dp_a, np_a)
699+
700+
@pytest.mark.parametrize("shape",
701+
[(0,), (15, ), (2, 2)],
702+
ids=['(0,)', '(15, )', '(2,2)'])
703+
def test_invalid_shape(self, shape):
704+
dp_array1 = dpnp.arange(10, dtype=dpnp.float64)
705+
dp_array2 = dpnp.arange(5, 15, dtype=dpnp.float64)
706+
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
707+
708+
with pytest.raises(ValueError):
709+
dpnp.add(dp_array1, dp_array2, out=dp_out)
710+
711+
@pytest.mark.parametrize("out",
712+
[4, (), [], (3, 7), [2, 4]],
713+
ids=['4', '()', '[]', '(3, 7)', '[2, 4]'])
714+
def test_invalid_out(self, out):
715+
a = dpnp.arange(10)
716+
717+
assert_raises(TypeError, dpnp.add, a, 2, out)
718+
assert_raises(TypeError, numpy.add, a.asnumpy(), 2, out)
719+
720+
637721
class TestPower:
638-
def test_power(self):
722+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
723+
def test_power(self, dtype):
639724
array1_data = numpy.arange(10)
640725
array2_data = numpy.arange(5, 15)
641-
out = numpy.empty(10, dtype=numpy.float64)
726+
out = numpy.empty(10, dtype=dtype)
642727

643728
# DPNP
644-
dp_array1 = dpnp.array(array1_data, dtype=dpnp.float64)
645-
dp_array2 = dpnp.array(array2_data, dtype=dpnp.float64)
646-
dp_out = dpnp.array(out, dtype=dpnp.float64)
729+
dp_array1 = dpnp.array(array1_data, dtype=dtype)
730+
dp_array2 = dpnp.array(array2_data, dtype=dtype)
731+
dp_out = dpnp.array(out, dtype=dtype)
647732
result = dpnp.power(dp_array1, dp_array2, out=dp_out)
648733

649734
# original
650-
np_array1 = numpy.array(array1_data, dtype=numpy.float64)
651-
np_array2 = numpy.array(array2_data, dtype=numpy.float64)
735+
np_array1 = numpy.array(array1_data, dtype=dtype)
736+
np_array2 = numpy.array(array2_data, dtype=dtype)
652737
expected = numpy.power(np_array1, np_array2, out=out)
653738

654-
assert_array_equal(expected, result)
739+
assert_allclose(expected, result)
655740

656741
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True, no_none=True))
657742
def test_out_dtypes(self, dtype):
@@ -662,7 +747,7 @@ def test_out_dtypes(self, dtype):
662747
np_out = numpy.empty(size, dtype=numpy.complex64)
663748
expected = numpy.power(np_array1, np_array2, out=np_out)
664749

665-
dp_array1 = dpnp.arange(size, 2*size, dtype=dtype)
750+
dp_array1 = dpnp.arange(size, 2 * size, dtype=dtype)
666751
dp_array2 = dpnp.arange(size, dtype=dtype)
667752
dp_out = dpnp.empty(size, dtype=dpnp.complex64)
668753
result = dpnp.power(dp_array1, dp_array2, out=dp_out)

tests/test_strides.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def test_strides_true_devide(dtype, shape):
217217

218218

219219
@pytest.mark.parametrize("func_name",
220-
["power"])
220+
["add", "power"])
221221
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
222222
def test_strided_out_2args(func_name, dtype):
223223
np_out = numpy.ones((5, 3, 2))[::3]
@@ -236,7 +236,7 @@ def test_strided_out_2args(func_name, dtype):
236236

237237

238238
@pytest.mark.parametrize("func_name",
239-
["power"])
239+
["add", "power"])
240240
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
241241
def test_strided_in_out_2args(func_name, dtype):
242242
sh = (3, 4, 2)
@@ -258,7 +258,7 @@ def test_strided_in_out_2args(func_name, dtype):
258258

259259

260260
@pytest.mark.parametrize("func_name",
261-
["power"])
261+
["add", "power"])
262262
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
263263
def test_strided_in_out_2args_diff_out_dtype(func_name, dtype):
264264
sh = (3, 3, 2)
@@ -280,7 +280,7 @@ def test_strided_in_out_2args_diff_out_dtype(func_name, dtype):
280280

281281

282282
@pytest.mark.parametrize("func_name",
283-
["power"])
283+
["add", "power"])
284284
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True))
285285
def test_strided_in_2args_overlap(func_name, dtype):
286286
size = 5
@@ -296,7 +296,7 @@ def test_strided_in_2args_overlap(func_name, dtype):
296296

297297

298298
@pytest.mark.parametrize("func_name",
299-
["power"])
299+
["add", "power"])
300300
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True))
301301
def test_strided_in_out_2args_overlap(func_name, dtype):
302302
sh = (4, 3, 2)

tests/test_usm_type.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def test_coerced_usm_types_sum(usm_type_x, usm_type_y):
1818
y = dp.arange(1000, usm_type = usm_type_y)
1919

2020
z = 1.3 + x + y + 2
21+
z += x
22+
z += 7.4
2123

2224
assert x.usm_type == usm_type_x
2325
assert y.usm_type == usm_type_y

0 commit comments

Comments
 (0)