Skip to content

Commit 1cda64c

Browse files
committed
Added a test for raised errors in reductions
Also removed unused `_usm_types` in `test_tensor_sum`
1 parent 691fd86 commit 1cda64c

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

dpctl/tests/test_tensor_sum.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
"c8",
3737
"c16",
3838
]
39-
_usm_types = ["device", "shared", "host"]
4039

4140

4241
@pytest.mark.parametrize("arg_dtype", _all_dtypes)

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,21 @@ def test_argmax_argmin_identities():
216216
assert dpt.argmax(x) == 0
217217
x = dpt.full(3, dpt.iinfo(dpt.int32).max, dtype="i4")
218218
assert dpt.argmin(x) == 0
219+
220+
221+
def test_reduction_arg_validation():
222+
get_queue_or_skip()
223+
224+
x = dict()
225+
with pytest.raises(TypeError):
226+
dpt.sum(x)
227+
with pytest.raises(TypeError):
228+
dpt.max(x)
229+
with pytest.raises(TypeError):
230+
dpt.argmax(x)
231+
232+
x = dpt.zeros((0,), dtype="i4")
233+
with pytest.raises(ValueError):
234+
dpt.max(x)
235+
with pytest.raises(ValueError):
236+
dpt.argmax(x)

0 commit comments

Comments
 (0)