@@ -177,12 +177,24 @@ def test_argsort_axis0():
177
177
x = dpt .reshape (xf , (n , m ))
178
178
idx = dpt .argsort (x , axis = 0 )
179
179
180
- conseq_idx = dpt .arange (m , dtype = idx .dtype )
181
- s = x [idx , conseq_idx [dpt .newaxis , :]]
180
+ s = dpt .take_along_axis (x , idx , axis = 0 )
182
181
183
182
assert dpt .all (s [:- 1 , :] <= s [1 :, :])
184
183
185
184
185
+ def test_argsort_axis1 ():
186
+ get_queue_or_skip ()
187
+
188
+ n , m = 200 , 30
189
+ xf = dpt .arange (n * m , 0 , step = - 1 , dtype = "i4" )
190
+ x = dpt .reshape (xf , (n , m ))
191
+ idx = dpt .argsort (x , axis = 1 )
192
+
193
+ s = dpt .take_along_axis (x , idx , axis = 1 )
194
+
195
+ assert dpt .all (s [:, :- 1 ] <= s [:, 1 :])
196
+
197
+
186
198
def test_sort_strided ():
187
199
get_queue_or_skip ()
188
200
@@ -199,8 +211,9 @@ def test_argsort_strided():
199
211
x_orig = dpt .arange (100 , dtype = "i4" )
200
212
x_flipped = dpt .flip (x_orig , axis = 0 )
201
213
idx = dpt .argsort (x_flipped )
214
+ s = dpt .take_along_axis (x_flipped , idx , axis = 0 )
202
215
203
- assert dpt .all (x_flipped [ idx ] == x_orig )
216
+ assert dpt .all (s == x_orig )
204
217
205
218
206
219
def test_sort_0d_array ():
0 commit comments