Skip to content

Commit 691fd86

Browse files
Added test for argmin with keepdims=True
1 parent 8240c76 commit 691fd86

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,16 @@ def test_max_min_axis():
4242
def test_reduction_keepdims():
4343
get_queue_or_skip()
4444

45-
x = dpt.ones((3, 4, 5, 6, 7), dtype="i4")
45+
n0, n1 = 3, 6
46+
x = dpt.ones((n0, 4, 5, n1, 7), dtype="i4")
4647
m = dpt.max(x, axis=(1, 2, -1), keepdims=True)
4748

48-
assert m.shape == (3, 1, 1, 6, 1)
49+
xx = dpt.reshape(dpt.permute_dims(x, (0, 3, 1, 2, -1)), (n0, n1, -1))
50+
p = dpt.argmax(xx, axis=-1, keepdims=True)
51+
52+
assert m.shape == (n0, 1, 1, n1, 1)
4953
assert dpt.all(m == dpt.reshape(x[:, 0, 0, :, 0], m.shape))
54+
assert dpt.all(p == 0)
5055

5156

5257
def test_max_scalar():

0 commit comments

Comments
 (0)