Skip to content

Commit d07de32

Browse files
Changed argsort tests to use take_along_axis
1 parent 0e69998 commit d07de32

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

dpctl/tests/test_usm_ndarray_sorting.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,24 @@ def test_argsort_axis0():
177177
x = dpt.reshape(xf, (n, m))
178178
idx = dpt.argsort(x, axis=0)
179179

180-
conseq_idx = dpt.arange(m, dtype=idx.dtype)
181-
s = x[idx, conseq_idx[dpt.newaxis, :]]
180+
s = dpt.take_along_axis(x, idx, axis=0)
182181

183182
assert dpt.all(s[:-1, :] <= s[1:, :])
184183

185184

185+
def test_argsort_axis1():
186+
get_queue_or_skip()
187+
188+
n, m = 200, 30
189+
xf = dpt.arange(n * m, 0, step=-1, dtype="i4")
190+
x = dpt.reshape(xf, (n, m))
191+
idx = dpt.argsort(x, axis=1)
192+
193+
s = dpt.take_along_axis(x, idx, axis=1)
194+
195+
assert dpt.all(s[:, :-1] <= s[:, 1:])
196+
197+
186198
def test_sort_strided():
187199
get_queue_or_skip()
188200

@@ -199,8 +211,9 @@ def test_argsort_strided():
199211
x_orig = dpt.arange(100, dtype="i4")
200212
x_flipped = dpt.flip(x_orig, axis=0)
201213
idx = dpt.argsort(x_flipped)
214+
s = dpt.take_along_axis(x_flipped, idx, axis=0)
202215

203-
assert dpt.all(x_flipped[idx] == x_orig)
216+
assert dpt.all(s == x_orig)
204217

205218

206219
def test_sort_0d_array():

0 commit comments

Comments
 (0)