Skip to content

Commit

Permalink
Support parameter out in dpnp.multiply()
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed Apr 3, 2023
1 parent 20c262e commit 1f6619a
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 30 deletions.
5 changes: 4 additions & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,10 @@ def __ilshift__(self, other):

# '__imatmul__',
# '__imod__',
# '__imul__',

def __imul__(self, other):
dpnp.multiply(self, other, out=self)
return self

def __index__(self):
return self._array_obj.__index__()
Expand Down
26 changes: 2 additions & 24 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,7 @@ def multiply(x1,
-----------
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`, but both `x1` and `x2` can not be scalars at the same time.
Parameters `out`, `where`, `dtype` and `subok` are supported with their default values.
Parameters `where`, `dtype` and `subok` are supported with their default values.
Keyword arguments ``kwargs`` are currently unsupported.
Otherwise the functions will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.
Expand All @@ -1170,29 +1170,7 @@ def multiply(x1,
"""

if out is not None:
pass
elif where is not True:
pass
elif dtype is not None:
pass
elif subok is not True:
pass
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
# at least either x1 or x2 has to be an array
pass
else:
# get USM type and queue to copy scalar from the host memory into a USM allocation
usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
if x1_desc and x2_desc:
return dpnp_multiply(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()

return call_origin(numpy.multiply, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
return _check_nd_call(numpy.multiply, dpnp_multiply, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)


def nancumprod(x1, **kwargs):
Expand Down
83 changes: 83 additions & 0 deletions tests/test_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,89 @@ def test_invalid_out(self, out):
assert_raises(TypeError, numpy.add, a.asnumpy(), 2, out)


class TestMultiply:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
def test_multiply(self, dtype):
array1_data = numpy.arange(10)
array2_data = numpy.arange(5, 15)
out = numpy.empty(10, dtype=dtype)

# DPNP
dp_array1 = dpnp.array(array1_data, dtype=dtype)
dp_array2 = dpnp.array(array2_data, dtype=dtype)
dp_out = dpnp.array(out, dtype=dtype)
result = dpnp.multiply(dp_array1, dp_array2, out=dp_out)

# original
np_array1 = numpy.array(array1_data, dtype=dtype)
np_array2 = numpy.array(array2_data, dtype=dtype)
expected = numpy.multiply(np_array1, np_array2, out=out)

assert_allclose(expected, result)
assert_allclose(out, dp_out)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
def test_out_dtypes(self, dtype):
size = 2 if dtype == dpnp.bool else 10

np_array1 = numpy.arange(size, 2 * size, dtype=dtype)
np_array2 = numpy.arange(size, dtype=dtype)
np_out = numpy.empty(size, dtype=numpy.complex64)
expected = numpy.multiply(np_array1, np_array2, out=np_out)

dp_array1 = dpnp.arange(size, 2 * size, dtype=dtype)
dp_array2 = dpnp.arange(size, dtype=dtype)
dp_out = dpnp.empty(size, dtype=dpnp.complex64)
result = dpnp.multiply(dp_array1, dp_array2, out=dp_out)

assert_array_equal(expected, result)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
def test_out_overlap(self, dtype):
size = 1 if dtype == dpnp.bool else 15

np_a = numpy.arange(2 * size, dtype=dtype)
expected = numpy.multiply(np_a[size::], np_a[::2], out=np_a[:size:])

dp_a = dpnp.arange(2 * size, dtype=dtype)
result = dpnp.multiply(dp_a[size::], dp_a[::2], out=dp_a[:size:])

assert_allclose(expected, result)
assert_allclose(dp_a, np_a)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_none=True))
def test_inplace_strided_out(self, dtype):
size = 21

np_a = numpy.arange(size, dtype=dtype)
np_a[::3] *= 4

dp_a = dpnp.arange(size, dtype=dtype)
dp_a[::3] *= 4

assert_allclose(dp_a, np_a)

@pytest.mark.parametrize("shape",
[(0,), (15, ), (2, 2)],
ids=['(0,)', '(15, )', '(2,2)'])
def test_invalid_shape(self, shape):
dp_array1 = dpnp.arange(10, dtype=dpnp.float64)
dp_array2 = dpnp.arange(5, 15, dtype=dpnp.float64)
dp_out = dpnp.empty(shape, dtype=dpnp.float64)

with pytest.raises(ValueError):
dpnp.multiply(dp_array1, dp_array2, out=dp_out)

@pytest.mark.parametrize("out",
[4, (), [], (3, 7), [2, 4]],
ids=['4', '()', '[]', '(3, 7)', '[2, 4]'])
def test_invalid_out(self, out):
a = dpnp.arange(10)

assert_raises(TypeError, dpnp.multiply, a, 2, out)
assert_raises(TypeError, numpy.multiply, a.asnumpy(), 2, out)


class TestPower:
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
def test_power(self, dtype):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_strides.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_strides_true_devide(dtype, shape):


@pytest.mark.parametrize("func_name",
["add", "power"])
["add", "multiply", "power"])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
def test_strided_out_2args(func_name, dtype):
np_out = numpy.ones((5, 3, 2))[::3]
Expand All @@ -236,7 +236,7 @@ def test_strided_out_2args(func_name, dtype):


@pytest.mark.parametrize("func_name",
["add", "power"])
["add", "multiply", "power"])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
def test_strided_in_out_2args(func_name, dtype):
sh = (3, 4, 2)
Expand All @@ -258,7 +258,7 @@ def test_strided_in_out_2args(func_name, dtype):


@pytest.mark.parametrize("func_name",
["add", "power"])
["add", "multiply", "power"])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
def test_strided_in_out_2args_diff_out_dtype(func_name, dtype):
sh = (3, 3, 2)
Expand All @@ -280,7 +280,7 @@ def test_strided_in_out_2args_diff_out_dtype(func_name, dtype):


@pytest.mark.parametrize("func_name",
["add", "power"])
["add", "multiply", "power"])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True))
def test_strided_in_2args_overlap(func_name, dtype):
size = 5
Expand All @@ -296,7 +296,7 @@ def test_strided_in_2args_overlap(func_name, dtype):


@pytest.mark.parametrize("func_name",
["add", "power"])
["add", "multiply", "power"])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True))
def test_strided_in_out_2args_overlap(func_name, dtype):
sh = (4, 3, 2)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def test_coerced_usm_types_mul(usm_type_x, usm_type_y):
y = dp.arange(10, usm_type = usm_type_y)

z = 3 * x * y * 1.5
z *= x
z *= 4.8

assert x.usm_type == usm_type_x
assert y.usm_type == usm_type_y
Expand Down

0 comments on commit 1f6619a

Please sign in to comment.