20
20
from numpy .testing import assert_array_equal
21
21
22
22
import dpctl .tensor as dpt
23
+ from dpctl .tensor ._search_functions import _where_result_type
24
+ from dpctl .tensor ._type_utils import _all_data_types
25
+ from dpctl .utils import ExecutionPlacementError
23
26
24
27
_all_dtypes = [
25
28
"u1" ,
38
41
]
39
42
40
43
44
+ class mock_device :
45
+ def __init__ (self , fp16 , fp64 ):
46
+ self .has_aspect_fp16 = fp16
47
+ self .has_aspect_fp64 = fp64
48
+
49
+
41
50
def test_where_basic ():
42
51
get_queue_or_skip ()
43
52
@@ -54,7 +63,16 @@ def test_where_basic():
54
63
out_expected = dpt .asarray (
55
64
[[1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ], [0 , 0 , 0 ], [1 , 1 , 1 ]]
56
65
)
66
+ assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
67
+
68
+ out = dpt .where (cond , dpt .ones (cond .shape ), dpt .zeros (cond .shape ))
69
+ assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
57
70
71
+ out = dpt .where (
72
+ cond ,
73
+ dpt .ones (cond .shape [0 ])[:, dpt .newaxis ],
74
+ dpt .zeros (cond .shape [0 ])[:, dpt .newaxis ],
75
+ )
58
76
assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
59
77
60
78
@@ -70,6 +88,31 @@ def _dtype_all_close(x1, x2):
70
88
return np .allclose (x1 , x2 )
71
89
72
90
91
+ @pytest .mark .parametrize ("dt1" , _all_dtypes )
92
+ @pytest .mark .parametrize ("dt2" , _all_dtypes )
93
+ @pytest .mark .parametrize ("fp16" , [True , False ])
94
+ @pytest .mark .parametrize ("fp64" , [True , False ])
95
+ def test_where_result_types (dt1 , dt2 , fp16 , fp64 ):
96
+ dev = mock_device (fp16 , fp64 )
97
+
98
+ dt1 = dpt .dtype (dt1 )
99
+ dt2 = dpt .dtype (dt2 )
100
+ res_t = _where_result_type (dt1 , dt2 , dev )
101
+
102
+ if fp16 and fp64 :
103
+ assert res_t == dpt .result_type (dt1 , dt2 )
104
+ else :
105
+ if res_t :
106
+ assert res_t .kind == dpt .result_type (dt1 , dt2 ).kind
107
+ else :
108
+ # some illegal cases are covered above, but
109
+ # this guarantees that _where_result_type
110
+ # produces None only when one of the dtypes
111
+ # is illegal given fp aspects of device
112
+ all_dts = _all_data_types (fp16 , fp64 )
113
+ assert dt1 not in all_dts or dt2 not in all_dts
114
+
115
+
73
116
@pytest .mark .parametrize ("dt1" , _all_dtypes )
74
117
@pytest .mark .parametrize ("dt2" , _all_dtypes )
75
118
def test_where_all_dtypes (dt1 , dt2 ):
@@ -78,17 +121,39 @@ def test_where_all_dtypes(dt1, dt2):
78
121
skip_if_dtype_not_supported (dt2 , q )
79
122
80
123
cond = dpt .asarray ([False , False , False , True , True ], sycl_queue = q )
81
- x1 = dpt .asarray (2 , sycl_queue = q )
82
- x2 = dpt .asarray (3 , sycl_queue = q )
83
-
124
+ x1 = dpt .asarray (2 , dtype = dt1 , sycl_queue = q )
125
+ x2 = dpt .asarray (3 , dtype = dt2 , sycl_queue = q )
84
126
res = dpt .where (cond , x1 , x2 )
127
+
85
128
res_check = np .asarray ([3 , 3 , 3 , 2 , 2 ], dtype = res .dtype )
129
+ assert _dtype_all_close (dpt .asnumpy (res ), res_check )
86
130
87
- dev = q .sycl_device
131
+ # contiguous cases
132
+ x1 = dpt .full (cond .shape , 2 , dtype = dt1 , sycl_queue = q )
133
+ x2 = dpt .full (cond .shape , 3 , dtype = dt2 , sycl_queue = q )
134
+ res = dpt .where (cond , x1 , x2 )
135
+ assert _dtype_all_close (dpt .asnumpy (res ), res_check )
88
136
89
- if not dev .has_aspect_fp16 or not dev .has_aspect_fp64 :
90
- assert res .dtype .kind == dpt .result_type (x1 .dtype , x2 .dtype ).kind
91
137
138
+ @pytest .mark .parametrize ("dt1" , _all_dtypes )
139
+ @pytest .mark .parametrize ("dt2" , _all_dtypes )
140
+ def test_where_mask_dtypes (dt1 , dt2 ):
141
+ q = get_queue_or_skip ()
142
+ skip_if_dtype_not_supported (dt1 , q )
143
+ skip_if_dtype_not_supported (dt2 , q )
144
+
145
+ cond = dpt .asarray ([0 , 1 , 3 , 0 , 10 ], dtype = dt1 , sycl_queue = q )
146
+ x1 = dpt .asarray (2 , dtype = dt2 , sycl_queue = q )
147
+ x2 = dpt .asarray (3 , dtype = dt2 , sycl_queue = q )
148
+ res = dpt .where (cond , x1 , x2 )
149
+
150
+ res_check = np .asarray ([3 , 2 , 2 , 3 , 2 ], dtype = res .dtype )
151
+ assert _dtype_all_close (dpt .asnumpy (res ), res_check )
152
+
153
+ # contiguous cases
154
+ x1 = dpt .full (cond .shape , 2 , dtype = dt2 , sycl_queue = q )
155
+ x2 = dpt .full (cond .shape , 3 , dtype = dt2 , sycl_queue = q )
156
+ res = dpt .where (cond , x1 , x2 )
92
157
assert _dtype_all_close (dpt .asnumpy (res ), res_check )
93
158
94
159
@@ -116,12 +181,14 @@ def test_where_empty():
116
181
117
182
assert_array_equal (dpt .asnumpy (res ), res_np )
118
183
184
+ # check that broadcasting is performed
185
+ with pytest .raises (ValueError ):
186
+ dpt .where (empty , x1 , dpt .empty ((1 , 2 )))
187
+
119
188
120
- @pytest .mark .parametrize ("dt" , _all_dtypes )
121
189
@pytest .mark .parametrize ("order" , ["C" , "F" ])
122
- def test_where_contiguous (dt , order ):
123
- q = get_queue_or_skip ()
124
- skip_if_dtype_not_supported (dt , q )
190
+ def test_where_contiguous (order ):
191
+ get_queue_or_skip ()
125
192
126
193
cond = dpt .asarray (
127
194
[
@@ -131,14 +198,81 @@ def test_where_contiguous(dt, order):
131
198
[[False , False , False ], [True , False , True ]],
132
199
[[True , True , True ], [True , False , True ]],
133
200
],
134
- sycl_queue = q ,
135
201
order = order ,
136
202
)
137
203
138
- x1 = dpt .full (cond .shape , 2 , dtype = dt , order = order , sycl_queue = q )
139
- x2 = dpt .full (cond .shape , 3 , dtype = dt , order = order , sycl_queue = q )
204
+ x1 = dpt .full (cond .shape , 2 , order = order )
205
+ x2 = dpt .full (cond .shape , 3 , order = order )
206
+ expected = np .where (dpt .asnumpy (cond ), dpt .asnumpy (x1 ), dpt .asnumpy (x2 ))
207
+ res = dpt .where (cond , x1 , x2 )
208
+
209
+ assert _dtype_all_close (dpt .asnumpy (res ), expected )
210
+
211
+
212
+ def test_where_contiguous1D ():
213
+ get_queue_or_skip ()
214
+
215
+ cond = dpt .asarray ([True , False , True , False , False , True ])
216
+
217
+ x1 = dpt .full (cond .shape , 2 )
218
+ x2 = dpt .full (cond .shape , 3 )
219
+ expected = np .where (dpt .asnumpy (cond ), dpt .asnumpy (x1 ), dpt .asnumpy (x2 ))
220
+ res = dpt .where (cond , x1 , x2 )
221
+ assert _dtype_all_close (dpt .asnumpy (res ), expected )
222
+
223
+ # test with complex dtype (branch in kernel)
224
+ x1 = dpt .astype (x1 , dpt .complex64 )
225
+ x2 = dpt .astype (x2 , dpt .complex64 )
226
+ expected = np .where (dpt .asnumpy (cond ), dpt .asnumpy (x1 ), dpt .asnumpy (x2 ))
227
+ res = dpt .where (cond , x1 , x2 )
228
+ assert _dtype_all_close (dpt .asnumpy (res ), expected )
229
+
230
+
231
+ def test_where_strided ():
232
+ get_queue_or_skip ()
233
+
234
+ s0 , s1 = 4 , 9
235
+ cond = dpt .reshape (
236
+ dpt .asarray (
237
+ [True , False , False , False , True , True , False , True , False ] * s0
238
+ ),
239
+ (s0 , s1 ),
240
+ )[:, ::3 ]
140
241
242
+ x1 = dpt .ones ((cond .shape [0 ], cond .shape [1 ] * 2 ))[:, ::2 ]
243
+ x2 = dpt .zeros ((cond .shape [0 ], cond .shape [1 ] * 3 ))[:, ::3 ]
141
244
expected = np .where (dpt .asnumpy (cond ), dpt .asnumpy (x1 ), dpt .asnumpy (x2 ))
142
245
res = dpt .where (cond , x1 , x2 )
143
246
144
247
assert _dtype_all_close (dpt .asnumpy (res ), expected )
248
+
249
+
250
+ def test_where_arg_validation ():
251
+ get_queue_or_skip ()
252
+
253
+ check = dict ()
254
+ x1 = dpt .empty ((1 ,))
255
+ x2 = dpt .empty ((1 ,))
256
+
257
+ with pytest .raises (TypeError ):
258
+ dpt .where (check , x1 , x2 )
259
+ with pytest .raises (TypeError ):
260
+ dpt .where (x1 , check , x2 )
261
+ with pytest .raises (TypeError ):
262
+ dpt .where (x1 , x2 , check )
263
+
264
+
265
+ def test_where_compute_follows_data ():
266
+ q1 = get_queue_or_skip ()
267
+ q2 = get_queue_or_skip ()
268
+ q3 = get_queue_or_skip ()
269
+
270
+ x1 = dpt .empty ((1 ,), sycl_queue = q1 )
271
+ x2 = dpt .empty ((1 ,), sycl_queue = q2 )
272
+
273
+ with pytest .raises (ExecutionPlacementError ):
274
+ dpt .where (dpt .empty ((1 ,), sycl_queue = q1 ), x1 , x2 )
275
+ with pytest .raises (ExecutionPlacementError ):
276
+ dpt .where (dpt .empty ((1 ,), sycl_queue = q3 ), x1 , x2 )
277
+ with pytest .raises (ExecutionPlacementError ):
278
+ dpt .where (x1 , x1 , x2 )
0 commit comments