@@ -370,3 +370,39 @@ def test_where_compute_follows_data():
370
370
dpt .where (dpt .empty ((1 ,), dtype = "i4" , sycl_queue = q3 ), x1 , x2 )
371
371
with pytest .raises (ExecutionPlacementError ):
372
372
dpt .where (x1 , x1 , x2 )
373
+
374
+
375
+ def test_where_order ():
376
+ get_queue_or_skip ()
377
+
378
+ test_sh = (
379
+ 20 ,
380
+ 20 ,
381
+ )
382
+ test_sh2 = tuple (2 * dim for dim in test_sh )
383
+ n = test_sh [- 1 ]
384
+
385
+ for dt1 , dt2 in zip (["i4" , "i4" , "f4" ], ["i4" , "f4" , "i4" ]):
386
+ ar1 = dpt .zeros (test_sh , dtype = dt1 , order = "C" )
387
+ ar2 = dpt .ones (test_sh , dtype = dt2 , order = "C" )
388
+ condition = dpt .zeros (test_sh , dtype = "?" , order = "C" )
389
+ res = dpt .where (condition , ar1 , ar2 )
390
+ assert res .flags .c_contiguous
391
+
392
+ ar1 = dpt .ones (test_sh , dtype = dt1 , order = "F" )
393
+ ar2 = dpt .ones (test_sh , dtype = dt2 , order = "F" )
394
+ condition = dpt .zeros (test_sh , dtype = "?" , order = "F" )
395
+ res = dpt .where (condition , ar1 , ar2 )
396
+ assert res .flags .f_contiguous
397
+
398
+ ar1 = dpt .ones (test_sh2 , dtype = dt1 , order = "C" )[:20 , ::- 2 ]
399
+ ar2 = dpt .ones (test_sh2 , dtype = dt2 , order = "C" )[:20 , ::- 2 ]
400
+ condition = dpt .zeros (test_sh2 , dtype = "?" , order = "C" )[:20 , ::- 2 ]
401
+ res = dpt .where (condition , ar1 , ar2 )
402
+ assert res .strides == (n , - 1 )
403
+
404
+ ar1 = dpt .ones (test_sh2 , dtype = dt1 , order = "C" )[:20 , ::- 2 ].mT
405
+ ar2 = dpt .ones (test_sh2 , dtype = dt2 , order = "C" )[:20 , ::- 2 ].mT
406
+ condition = dpt .zeros (test_sh2 , dtype = "?" , order = "C" )[:20 , ::- 2 ].mT
407
+ res = dpt .where (condition , ar1 , ar2 )
408
+ assert res .strides == (- 1 , n )
0 commit comments