Skip to content

Commit 786bec7

Browse files
Added tests for abs/not_equal with not-aligned start of array
1 parent 904b634 commit 786bec7

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

dpctl/tests/elementwise/test_abs.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,18 @@ def test_abs_complex_fp_special_values(dtype):
187187
tol = dpt.finfo(r.dtype).resolution
188188

189189
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)
190+
191+
192+
@pytest.mark.parametrize("dtype", _all_dtypes)
193+
def test_abs_alignment(dtype):
194+
q = get_queue_or_skip()
195+
skip_if_dtype_not_supported(dtype, q)
196+
197+
x = dpt.ones(512, dtype=dtype)
198+
r = dpt.abs(x)
199+
200+
r2 = dpt.abs(x[1:])
201+
assert np.allclose(dpt.asnumpy(r[1:]), dpt.asnumpy(r2))
202+
203+
dpt.abs(x[:-1], out=r[1:])
204+
assert np.allclose(dpt.asnumpy(r[1:]), dpt.asnumpy(r2))

dpctl/tests/elementwise/test_not_equal.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,21 @@ def __sycl_usm_array_interface__(self):
188188
c = Canary()
189189
with pytest.raises(ValueError):
190190
dpt.not_equal(a, c)
191+
192+
193+
@pytest.mark.parametrize("dtype", _all_dtypes)
194+
def test_not_equal_alignment(dtype):
195+
q = get_queue_or_skip()
196+
skip_if_dtype_not_supported(dtype, q)
197+
198+
n = 256
199+
s = dpt.concat((dpt.zeros(n, dtype=dtype), dpt.zeros(n, dtype=dtype)))
200+
201+
mask = s[:-1] != s[1:]
202+
(pos,) = dpt.nonzero(mask)
203+
assert dpt.all(pos == n)
204+
205+
out_arr = dpt.zeros(2 * n, dtype=mask.dtype)
206+
dpt.not_equal(s[:-1], s[1:], out=out_arr[1:])
207+
(pos,) = dpt.nonzero(mask)
208+
assert dpt.all(pos == (n + 1))

0 commit comments

Comments
 (0)