Skip to content

Commit

Permalink
Adds tests for reduction out kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Apr 16, 2024
1 parent c643acb commit faa0e69
Showing 1 changed file with 171 additions and 0 deletions.
171 changes: 171 additions & 0 deletions dpctl/tests/test_usm_ndarray_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import dpctl.tensor as dpt
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
from dpctl.utils import ExecutionPlacementError

_no_complex_dtypes = [
"?",
Expand Down Expand Up @@ -497,3 +498,173 @@ def test_tree_reduction_axis1_axis0():
rtol=tol,
atol=tol,
)


def test_numeric_reduction_out_kwarg():
get_queue_or_skip()

n1, n2, n3 = 3, 4, 5
x = dpt.ones((n1, n2, n3), dtype="i8")
out = dpt.zeros((2 * n1, 3 * n2), dtype="i8")
res = dpt.sum(x, axis=-1, out=out[::-2, 1::3])
assert dpt.all(out[::-2, 0::3] == 0)
assert dpt.all(out[::-2, 2::3] == 0)
assert dpt.all(out[::-2, 1::3] == res)
assert dpt.all(out[::-2, 1::3] == 5)

out = dpt.zeros((2 * n1, 3 * n2, 1), dtype="i8")
res = dpt.sum(x, axis=-1, keepdims=True, out=out[::-2, 1::3])
assert res.shape == (n1, n2, 1)
assert dpt.all(out[::-2, 0::3] == 0)
assert dpt.all(out[::-2, 2::3] == 0)
assert dpt.all(out[::-2, 1::3] == res)
assert dpt.all(out[::-2, 1::3] == 5)

res = dpt.sum(x, axis=0, out=x[-1])
assert dpt.all(x[-1] == res)
assert dpt.all(x[-1] == 3)
assert dpt.all(x[0:-1] == 1)

# test no-op case
x = dpt.ones((n1, n2, n3), dtype="i8")
out = dpt.zeros((2 * n1, 3 * n2, n3), dtype="i8")
res = dpt.sum(x, axis=(), out=out[::-2, 1::3])
assert dpt.all(out[::-2, 0::3] == 0)
assert dpt.all(out[::-2, 2::3] == 0)
assert dpt.all(out[::-2, 1::3] == x)

# test with dtype kwarg
x = dpt.ones((n1, n2, n3), dtype="i4")
out = dpt.zeros((2 * n1, 3 * n2), dtype="f4")
res = dpt.sum(x, axis=-1, dtype="f4", out=out[::-2, 1::3])
assert dpt.allclose(out[::-2, 0::3], dpt.zeros_like(res))
assert dpt.allclose(out[::-2, 2::3], dpt.zeros_like(res))
assert dpt.allclose(out[::-2, 1::3], res)
assert dpt.allclose(out[::-2, 1::3], dpt.full_like(res, 5, dtype="f4"))


def test_comparison_reduction_out_kwarg():
get_queue_or_skip()

n1, n2, n3 = 3, 4, 5
x = dpt.reshape(dpt.arange(n1 * n2 * n3, dtype="i4"), (n1, n2, n3))
out = dpt.zeros((2 * n1, 3 * n2), dtype="i4")
res = dpt.max(x, axis=-1, out=out[::-2, 1::3])
assert dpt.all(out[::-2, 0::3] == 0)
assert dpt.all(out[::-2, 2::3] == 0)
assert dpt.all(out[::-2, 1::3] == res)
assert dpt.all(out[::-2, 1::3] == x[:, :, -1])

out = dpt.zeros((2 * n1, 3 * n2, 1), dtype="i4")
res = dpt.max(x, axis=-1, keepdims=True, out=out[::-2, 1::3])
assert res.shape == (n1, n2, 1)
assert dpt.all(out[::-2, 0::3] == 0)
assert dpt.all(out[::-2, 2::3] == 0)
assert dpt.all(out[::-2, 1::3] == res)
assert dpt.all(out[::-2, 1::3] == x[:, :, -1, dpt.newaxis])

# test no-op case
out = dpt.zeros((2 * n1, 3 * n2, n3), dtype="i4")
res = dpt.max(x, axis=(), out=out[::-2, 1::3])
assert dpt.all(out[::-2, 0::3] == 0)
assert dpt.all(out[::-2, 2::3] == 0)
assert dpt.all(out[::-2, 1::3] == x)

# test overlap
res = dpt.max(x, axis=0, out=x[0])
assert dpt.all(x[0] == res)
assert dpt.all(x[0] == x[-1])


def test_search_reduction_out_kwarg():
get_queue_or_skip()

n1, n2, n3 = 3, 4, 5
dt = dpt.__array_namespace_info__().default_dtypes()["indexing"]

x = dpt.reshape(dpt.arange(n1 * n2 * n3, dtype=dt), (n1, n2, n3))
out = dpt.zeros((2 * n1, 3 * n2), dtype=dt)
res = dpt.argmax(x, axis=-1, out=out[::-2, 1::3])
assert dpt.all(out[::-2, 0::3] == 0)
assert dpt.all(out[::-2, 2::3] == 0)
assert dpt.all(out[::-2, 1::3] == res)
assert dpt.all(out[::-2, 1::3] == n2)

out = dpt.zeros((2 * n1, 3 * n2, 1), dtype=dt)
res = dpt.argmax(x, axis=-1, keepdims=True, out=out[::-2, 1::3])
assert res.shape == (n1, n2, 1)
assert dpt.all(out[::-2, 0::3] == 0)
assert dpt.all(out[::-2, 2::3] == 0)
assert dpt.all(out[::-2, 1::3] == res)
assert dpt.all(out[::-2, 1::3] == n3 - 1)

# test no-op case
x = dpt.ones((), dtype=dt)
out = dpt.ones(2, dtype=dt)
res = dpt.argmax(x, axis=None, out=out[1])
assert dpt.all(out[0] == 1)
assert dpt.all(out[1] == 0)

# test overlap
x = dpt.reshape(dpt.arange(n1 * n2, dtype=dt), (n1, n2))
res = dpt.argmax(x, axis=0, out=x[0])
assert dpt.all(x[0] == res)
assert dpt.all(x[0] == n1 - 1)


def test_reduction_out_kwarg_arg_validation():
q1 = get_queue_or_skip()
q2 = get_queue_or_skip()

ind_dt = dpt.__array_namespace_info__().default_dtypes()["indexing"]

x = dpt.ones(10, dtype="f4")
out_wrong_queue = dpt.empty((), dtype="f4", sycl_queue=q2)
out_wrong_dtype = dpt.empty((), dtype="i4", sycl_queue=q1)
out_wrong_shape = dpt.empty(1, dtype="f4", sycl_queue=q1)
out_wrong_keepdims = dpt.empty((), dtype="f4", sycl_queue=q1)
out_not_writable = dpt.empty((), dtype="f4", sycl_queue=q1)
out_not_writable.flags["W"] = False

with pytest.raises(TypeError):
dpt.sum(x, out=dict())
with pytest.raises(TypeError):
dpt.max(x, out=dict())
with pytest.raises(TypeError):
dpt.argmax(x, out=dict())
with pytest.raises(ExecutionPlacementError):
dpt.sum(x, out=out_wrong_queue)
with pytest.raises(ExecutionPlacementError):
dpt.max(x, out=out_wrong_queue)
with pytest.raises(ExecutionPlacementError):
dpt.argmax(x, out=dpt.empty_like(out_wrong_queue, dtype=ind_dt))
with pytest.raises(ValueError):
dpt.sum(x, out=out_wrong_dtype)
with pytest.raises(ValueError):
dpt.max(x, out=out_wrong_dtype)
with pytest.raises(ValueError):
dpt.argmax(x, out=dpt.empty_like(out_wrong_dtype, dtype="f4"))
with pytest.raises(ValueError):
dpt.sum(x, out=out_wrong_shape)
with pytest.raises(ValueError):
dpt.max(x, out=out_wrong_shape)
with pytest.raises(ValueError):
dpt.argmax(x, out=dpt.empty_like(out_wrong_shape, dtype=ind_dt))
with pytest.raises(ValueError):
dpt.sum(x, out=out_not_writable)
with pytest.raises(ValueError):
dpt.max(x, out=out_not_writable)
with pytest.raises(ValueError):
search_not_writable = dpt.empty_like(out_not_writable, dtype=ind_dt)
search_not_writable.flags["W"] = False
dpt.argmax(x, out=search_not_writable)
with pytest.raises(ValueError):
dpt.sum(x, keepdims=True, out=out_wrong_keepdims)
with pytest.raises(ValueError):
dpt.max(x, keepdims=True, out=out_wrong_keepdims)
with pytest.raises(ValueError):
dpt.argmax(
x,
keepdims=True,
out=dpt.empty_like(out_wrong_keepdims, dtype=ind_dt),
)

0 comments on commit faa0e69

Please sign in to comment.