Skip to content

Commit ece6b00

Browse files
committed
Support parameter out in dpnp.multiply()
1 parent b4ec5ea commit ece6b00

File tree

5 files changed

+96
-30
lines changed

5 files changed

+96
-30
lines changed

dpnp/dpnp_array.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,10 @@ def __ilshift__(self, other):
211211

212212
# '__imatmul__',
213213
# '__imod__',
214-
# '__imul__',
214+
215+
def __imul__(self, other):
216+
dpnp.multiply(self, other, out=self)
217+
return self
215218

216219
def __index__(self):
217220
return self._array_obj.__index__()

dpnp/dpnp_iface_mathematical.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,7 +1155,7 @@ def multiply(x1,
11551155
-----------
11561156
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
11571157
or :class:`dpctl.tensor.usm_ndarray`, but both `x1` and `x2` can not be scalars at the same time.
1158-
Parameters `out`, `where`, `dtype` and `subok` are supported with their default values.
1158+
Parameters `where`, `dtype` and `subok` are supported with their default values.
11591159
Keyword arguments ``kwargs`` are currently unsupported.
11601160
Otherwise the functions will be executed sequentially on CPU.
11611161
Input array data types are limited by supported DPNP :ref:`Data types`.
@@ -1170,29 +1170,7 @@ def multiply(x1,
11701170
11711171
"""
11721172

1173-
if out is not None:
1174-
pass
1175-
elif where is not True:
1176-
pass
1177-
elif dtype is not None:
1178-
pass
1179-
elif subok is not True:
1180-
pass
1181-
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
1182-
# at least either x1 or x2 has to be an array
1183-
pass
1184-
else:
1185-
# get USM type and queue to copy scalar from the host memory into a USM allocation
1186-
usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)
1187-
1188-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
1189-
alloc_usm_type=usm_type, alloc_queue=queue)
1190-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
1191-
alloc_usm_type=usm_type, alloc_queue=queue)
1192-
if x1_desc and x2_desc:
1193-
return dpnp_multiply(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()
1194-
1195-
return call_origin(numpy.multiply, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
1173+
return _check_nd_call(numpy.multiply, dpnp_multiply, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
11961174

11971175

11981176
def nancumprod(x1, **kwargs):

tests/test_mathematical.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,89 @@ def test_invalid_out(self, out):
718718
assert_raises(TypeError, numpy.add, a.asnumpy(), 2, out)
719719

720720

721+
class TestMultiply:
722+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
723+
def test_multiply(self, dtype):
724+
array1_data = numpy.arange(10)
725+
array2_data = numpy.arange(5, 15)
726+
out = numpy.empty(10, dtype=dtype)
727+
728+
# DPNP
729+
dp_array1 = dpnp.array(array1_data, dtype=dtype)
730+
dp_array2 = dpnp.array(array2_data, dtype=dtype)
731+
dp_out = dpnp.array(out, dtype=dtype)
732+
result = dpnp.multiply(dp_array1, dp_array2, out=dp_out)
733+
734+
# original
735+
np_array1 = numpy.array(array1_data, dtype=dtype)
736+
np_array2 = numpy.array(array2_data, dtype=dtype)
737+
expected = numpy.multiply(np_array1, np_array2, out=out)
738+
739+
assert_allclose(expected, result)
740+
assert_allclose(out, dp_out)
741+
742+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
743+
def test_out_dtypes(self, dtype):
744+
size = 2 if dtype == dpnp.bool else 10
745+
746+
np_array1 = numpy.arange(size, 2 * size, dtype=dtype)
747+
np_array2 = numpy.arange(size, dtype=dtype)
748+
np_out = numpy.empty(size, dtype=numpy.complex64)
749+
expected = numpy.multiply(np_array1, np_array2, out=np_out)
750+
751+
dp_array1 = dpnp.arange(size, 2 * size, dtype=dtype)
752+
dp_array2 = dpnp.arange(size, dtype=dtype)
753+
dp_out = dpnp.empty(size, dtype=dpnp.complex64)
754+
result = dpnp.multiply(dp_array1, dp_array2, out=dp_out)
755+
756+
assert_array_equal(expected, result)
757+
758+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
759+
def test_out_overlap(self, dtype):
760+
size = 1 if dtype == dpnp.bool else 15
761+
762+
np_a = numpy.arange(2 * size, dtype=dtype)
763+
expected = numpy.multiply(np_a[size::], np_a[::2], out=np_a[:size:])
764+
765+
dp_a = dpnp.arange(2 * size, dtype=dtype)
766+
result = dpnp.multiply(dp_a[size::], dp_a[::2], out=dp_a[:size:])
767+
768+
assert_allclose(expected, result)
769+
assert_allclose(dp_a, np_a)
770+
771+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_none=True))
772+
def test_inplace_strided_out(self, dtype):
773+
size = 21
774+
775+
np_a = numpy.arange(size, dtype=dtype)
776+
np_a[::3] *= 4
777+
778+
dp_a = dpnp.arange(size, dtype=dtype)
779+
dp_a[::3] *= 4
780+
781+
assert_allclose(dp_a, np_a)
782+
783+
@pytest.mark.parametrize("shape",
784+
[(0,), (15, ), (2, 2)],
785+
ids=['(0,)', '(15, )', '(2,2)'])
786+
def test_invalid_shape(self, shape):
787+
dp_array1 = dpnp.arange(10, dtype=dpnp.float64)
788+
dp_array2 = dpnp.arange(5, 15, dtype=dpnp.float64)
789+
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
790+
791+
with pytest.raises(ValueError):
792+
dpnp.multiply(dp_array1, dp_array2, out=dp_out)
793+
794+
@pytest.mark.parametrize("out",
795+
[4, (), [], (3, 7), [2, 4]],
796+
ids=['4', '()', '[]', '(3, 7)', '[2, 4]'])
797+
def test_invalid_out(self, out):
798+
a = dpnp.arange(10)
799+
800+
assert_raises(TypeError, dpnp.multiply, a, 2, out)
801+
assert_raises(TypeError, numpy.multiply, a.asnumpy(), 2, out)
802+
803+
721804
class TestPower:
722805
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
723806
def test_power(self, dtype):

tests/test_strides.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def test_strides_true_devide(dtype, shape):
217217

218218

219219
@pytest.mark.parametrize("func_name",
220-
["add", "power"])
220+
["add", "multiply", "power"])
221221
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
222222
def test_strided_out_2args(func_name, dtype):
223223
np_out = numpy.ones((5, 3, 2))[::3]
@@ -236,7 +236,7 @@ def test_strided_out_2args(func_name, dtype):
236236

237237

238238
@pytest.mark.parametrize("func_name",
239-
["add", "power"])
239+
["add", "multiply", "power"])
240240
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
241241
def test_strided_in_out_2args(func_name, dtype):
242242
sh = (3, 4, 2)
@@ -258,7 +258,7 @@ def test_strided_in_out_2args(func_name, dtype):
258258

259259

260260
@pytest.mark.parametrize("func_name",
261-
["add", "power"])
261+
["add", "multiply", "power"])
262262
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
263263
def test_strided_in_out_2args_diff_out_dtype(func_name, dtype):
264264
sh = (3, 3, 2)
@@ -280,7 +280,7 @@ def test_strided_in_out_2args_diff_out_dtype(func_name, dtype):
280280

281281

282282
@pytest.mark.parametrize("func_name",
283-
["add", "power"])
283+
["add", "multiply", "power"])
284284
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True))
285285
def test_strided_in_2args_overlap(func_name, dtype):
286286
size = 5
@@ -296,7 +296,7 @@ def test_strided_in_2args_overlap(func_name, dtype):
296296

297297

298298
@pytest.mark.parametrize("func_name",
299-
["add", "power"])
299+
["add", "multiply", "power"])
300300
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True))
301301
def test_strided_in_out_2args_overlap(func_name, dtype):
302302
sh = (4, 3, 2)

tests/test_usm_type.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def test_coerced_usm_types_mul(usm_type_x, usm_type_y):
3333
y = dp.arange(10, usm_type = usm_type_y)
3434

3535
z = 3 * x * y * 1.5
36+
z *= x
37+
z *= 4.8
3638

3739
assert x.usm_type == usm_type_x
3840
assert y.usm_type == usm_type_y

0 commit comments

Comments
 (0)