Skip to content

Commit e9a4f30

Browse files
committed
Adds tests for inplace addition
1 parent 429c368 commit e9a4f30

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

dpctl/tests/elementwise/test_add.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,58 @@ def test_add_dtype_error(
341341
assert_raises_regex(
342342
TypeError, "Output array of type.*is needed", dpt.add, ar1, ar2, y
343343
)
344+
345+
346+
@pytest.mark.parametrize("dtype", _all_dtypes)
347+
def test_add_inplace_python_scalar(dtype):
348+
q = get_queue_or_skip()
349+
skip_if_dtype_not_supported(dtype, q)
350+
X = dpt.zeros((10, 10), dtype=dtype, sycl_queue=q)
351+
dt_kind = X.dtype.kind
352+
if dt_kind in "ui":
353+
X += int(0)
354+
elif dt_kind == "f":
355+
X += float(0)
356+
elif dt_kind == "c":
357+
X += complex(0)
358+
elif dt_kind == "b":
359+
X += bool(0)
360+
361+
362+
@pytest.mark.parametrize("op1_dtype", _all_dtypes)
363+
@pytest.mark.parametrize("op2_dtype", _all_dtypes)
364+
def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
365+
q = get_queue_or_skip()
366+
skip_if_dtype_not_supported(op1_dtype, q)
367+
skip_if_dtype_not_supported(op2_dtype, q)
368+
369+
if dpt.can_cast(op2_dtype, op1_dtype, casting="safe"):
370+
sz = 127
371+
ar1 = dpt.ones(sz, dtype=op1_dtype)
372+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
373+
374+
ar1 += ar2
375+
assert (
376+
dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)
377+
).all()
378+
379+
ar3 = dpt.ones(sz, dtype=op1_dtype)
380+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
381+
382+
ar3[::-1] += ar4[::2]
383+
assert (
384+
dpt.asnumpy(ar3) == np.full(ar3.shape, 2, dtype=ar3.dtype)
385+
).all()
386+
387+
else:
388+
assert pytest.raises(TypeError)
389+
390+
391+
def test_add_inplace_broadcasting():
392+
get_queue_or_skip()
393+
394+
m = dpt.ones((100, 5), dtype="i4")
395+
v = dpt.arange(5, dtype="i4")
396+
397+
m += v
398+
assert (dpt.asnumpy(m) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()

0 commit comments

Comments
 (0)