@@ -70,8 +70,8 @@ def test_where_basic():
70
70
71
71
out = dpt .where (
72
72
cond ,
73
- dpt .ones (cond .shape [0 ])[:, dpt .newaxis ],
74
- dpt .zeros (cond .shape [0 ])[:, dpt .newaxis ],
73
+ dpt .ones (cond .shape [0 ], dtype = "i4" )[:, dpt .newaxis ],
74
+ dpt .zeros (cond .shape [0 ], dtype = "i4" )[:, dpt .newaxis ],
75
75
)
76
76
assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
77
77
@@ -162,13 +162,13 @@ def test_where_empty():
162
162
# handling empty arrays
163
163
get_queue_or_skip ()
164
164
165
- empty = dpt .empty (0 )
165
+ empty = dpt .empty (0 , dtype = "i2" )
166
166
m = dpt .asarray (True )
167
- x1 = dpt .asarray (1 )
168
- x2 = dpt .asarray (2 )
167
+ x1 = dpt .asarray (1 , dtype = "i2" )
168
+ x2 = dpt .asarray (2 , dtype = "i2" )
169
169
res = dpt .where (empty , x1 , x2 )
170
170
171
- empty_np = np .empty (0 )
171
+ empty_np = np .empty (0 , dtype = "i2" )
172
172
m_np = dpt .asnumpy (m )
173
173
x1_np = dpt .asnumpy (x1 )
174
174
x2_np = dpt .asnumpy (x2 )
@@ -201,8 +201,8 @@ def test_where_contiguous(order):
201
201
order = order ,
202
202
)
203
203
204
- x1 = dpt .full (cond .shape , 2 , order = order )
205
- x2 = dpt .full (cond .shape , 3 , order = order )
204
+ x1 = dpt .full (cond .shape , 2 , dtype = "i4" , order = order )
205
+ x2 = dpt .full (cond .shape , 3 , dtype = "i4" , order = order )
206
206
expected = np .where (dpt .asnumpy (cond ), dpt .asnumpy (x1 ), dpt .asnumpy (x2 ))
207
207
res = dpt .where (cond , x1 , x2 )
208
208
@@ -214,11 +214,11 @@ def test_where_contiguous1D():
214
214
215
215
cond = dpt .asarray ([True , False , True , False , False , True ])
216
216
217
- x1 = dpt .full (cond .shape , 2 )
218
- x2 = dpt .full (cond .shape , 3 )
217
+ x1 = dpt .full (cond .shape , 2 , dtype = "i4" )
218
+ x2 = dpt .full (cond .shape , 3 , dtype = "i4" )
219
219
expected = np .where (dpt .asnumpy (cond ), dpt .asnumpy (x1 ), dpt .asnumpy (x2 ))
220
220
res = dpt .where (cond , x1 , x2 )
221
- assert _dtype_all_close (dpt .asnumpy (res ), expected )
221
+ assert_array_equal (dpt .asnumpy (res ), expected )
222
222
223
223
# test with complex dtype (branch in kernel)
224
224
x1 = dpt .astype (x1 , dpt .complex64 )
@@ -239,20 +239,39 @@ def test_where_strided():
239
239
(s0 , s1 ),
240
240
)[:, ::3 ]
241
241
242
- x1 = dpt .ones ((cond .shape [0 ], cond .shape [1 ] * 2 ))[:, ::2 ]
243
- x2 = dpt .zeros ((cond .shape [0 ], cond .shape [1 ] * 3 ))[:, ::3 ]
242
+ x1 = dpt .reshape (
243
+ dpt .arange (cond .shape [0 ] * cond .shape [1 ] * 2 , dtype = "i4" ),
244
+ (cond .shape [0 ], cond .shape [1 ] * 2 ),
245
+ )[:, ::2 ]
246
+ x2 = dpt .reshape (
247
+ dpt .arange (cond .shape [0 ] * cond .shape [1 ] * 3 , dtype = "i4" ),
248
+ (cond .shape [0 ], cond .shape [1 ] * 3 ),
249
+ )[:, ::3 ]
244
250
expected = np .where (dpt .asnumpy (cond ), dpt .asnumpy (x1 ), dpt .asnumpy (x2 ))
245
251
res = dpt .where (cond , x1 , x2 )
246
252
247
- assert _dtype_all_close (dpt .asnumpy (res ), expected )
253
+ assert_array_equal (dpt .asnumpy (res ), expected )
254
+
255
+ # negative strides
256
+ res = dpt .where (cond , dpt .flip (x1 ), x2 )
257
+ expected = np .where (
258
+ dpt .asnumpy (cond ), np .flip (dpt .asnumpy (x1 )), dpt .asnumpy (x2 )
259
+ )
260
+ assert_array_equal (dpt .asnumpy (res ), expected )
261
+
262
+ res = dpt .where (dpt .flip (cond ), x1 , x2 )
263
+ expected = np .where (
264
+ np .flip (dpt .asnumpy (cond )), dpt .asnumpy (x1 ), dpt .asnumpy (x2 )
265
+ )
266
+ assert_array_equal (dpt .asnumpy (res ), expected )
248
267
249
268
250
269
def test_where_arg_validation ():
251
270
get_queue_or_skip ()
252
271
253
272
check = dict ()
254
- x1 = dpt .empty ((1 ,))
255
- x2 = dpt .empty ((1 ,))
273
+ x1 = dpt .empty ((1 ,), dtype = "i4" )
274
+ x2 = dpt .empty ((1 ,), dtype = "i4" )
256
275
257
276
with pytest .raises (TypeError ):
258
277
dpt .where (check , x1 , x2 )
@@ -267,12 +286,12 @@ def test_where_compute_follows_data():
267
286
q2 = get_queue_or_skip ()
268
287
q3 = get_queue_or_skip ()
269
288
270
- x1 = dpt .empty ((1 ,), sycl_queue = q1 )
271
- x2 = dpt .empty ((1 ,), sycl_queue = q2 )
289
+ x1 = dpt .empty ((1 ,), dtype = "i4" , sycl_queue = q1 )
290
+ x2 = dpt .empty ((1 ,), dtype = "i4" , sycl_queue = q2 )
272
291
273
292
with pytest .raises (ExecutionPlacementError ):
274
- dpt .where (dpt .empty ((1 ,), sycl_queue = q1 ), x1 , x2 )
293
+ dpt .where (dpt .empty ((1 ,), dtype = "i4" , sycl_queue = q1 ), x1 , x2 )
275
294
with pytest .raises (ExecutionPlacementError ):
276
- dpt .where (dpt .empty ((1 ,), sycl_queue = q3 ), x1 , x2 )
295
+ dpt .where (dpt .empty ((1 ,), dtype = "i4" , sycl_queue = q3 ), x1 , x2 )
277
296
with pytest .raises (ExecutionPlacementError ):
278
297
dpt .where (x1 , x1 , x2 )
0 commit comments