39
39
40
40
41
41
def test_where_basic ():
42
- get_queue_or_skip
42
+ get_queue_or_skip ()
43
43
44
44
cond = dpt .asarray (
45
45
[
@@ -58,27 +58,87 @@ def test_where_basic():
58
58
assert (dpt .asnumpy (out ) == dpt .asnumpy (out_expected )).all ()
59
59
60
60
61
+ def _dtype_all_close (x1 , x2 ):
62
+ if np .issubdtype (x2 .dtype , np .floating ) or np .issubdtype (
63
+ x2 .dtype , np .complexfloating
64
+ ):
65
+ x2_dtype = x2 .dtype
66
+ return np .allclose (
67
+ x1 , x2 , atol = np .finfo (x2_dtype ).eps , rtol = np .finfo (x2_dtype ).eps
68
+ )
69
+ else :
70
+ return np .allclose (x1 , x2 )
71
+
72
+
61
73
@pytest .mark .parametrize ("dt1" , _all_dtypes )
62
74
@pytest .mark .parametrize ("dt2" , _all_dtypes )
63
75
def test_where_all_dtypes (dt1 , dt2 ):
64
76
q = get_queue_or_skip ()
65
77
skip_if_dtype_not_supported (dt1 , q )
66
78
skip_if_dtype_not_supported (dt2 , q )
67
79
68
- cond_np = np .arange (5 ) > 2
69
- x1_np = np .asarray (2 , dtype = dt1 )
70
- x2_np = np .asarray (3 , dtype = dt2 )
71
-
72
- cond = dpt .asarray (cond_np , sycl_queue = q )
73
- x1 = dpt .asarray (x1_np , sycl_queue = q )
74
- x2 = dpt .asarray (x2_np , sycl_queue = q )
80
+ cond = dpt .asarray ([False , False , False , True , True ], sycl_queue = q )
81
+ x1 = dpt .asarray (2 , sycl_queue = q )
82
+ x2 = dpt .asarray (3 , sycl_queue = q )
75
83
76
84
res = dpt .where (cond , x1 , x2 )
77
- res_np = np .where ( cond_np , x1_np , x2_np )
85
+ res_check = np .asarray ([ 3 , 3 , 3 , 2 , 2 ], dtype = res . dtype )
78
86
79
- if res .dtype != res_np .dtype :
80
- assert res .dtype .kind == res_np .dtype .kind
81
- assert_array_equal (dpt .asnumpy (res ).astype (res_np .dtype ), res_np )
87
+ dev = q .sycl_device
82
88
83
- else :
84
- assert_array_equal (dpt .asnumpy (res ), res_np )
89
+ if not dev .has_aspect_fp16 or not dev .has_aspect_fp64 :
90
+ assert res .dtype .kind == dpt .result_type (x1 .dtype , x2 .dtype ).kind
91
+
92
+ assert _dtype_all_close (dpt .asnumpy (res ), res_check )
93
+
94
+
95
+ def test_where_empty ():
96
+ # check that numpy returns same results when
97
+ # handling empty arrays
98
+ get_queue_or_skip ()
99
+
100
+ empty = dpt .empty (0 )
101
+ m = dpt .asarray (True )
102
+ x1 = dpt .asarray (1 )
103
+ x2 = dpt .asarray (2 )
104
+ res = dpt .where (empty , x1 , x2 )
105
+
106
+ empty_np = np .empty (0 )
107
+ m_np = dpt .asnumpy (m )
108
+ x1_np = dpt .asnumpy (x1 )
109
+ x2_np = dpt .asnumpy (x2 )
110
+ res_np = np .where (empty_np , x1_np , x2_np )
111
+
112
+ assert_array_equal (dpt .asnumpy (res ), res_np )
113
+
114
+ res = dpt .where (m , empty , x2 )
115
+ res_np = np .where (m_np , empty_np , x2_np )
116
+
117
+ assert_array_equal (dpt .asnumpy (res ), res_np )
118
+
119
+
120
+ @pytest .mark .parametrize ("dt" , _all_dtypes )
121
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
122
+ def test_where_contiguous (dt , order ):
123
+ q = get_queue_or_skip ()
124
+ skip_if_dtype_not_supported (dt , q )
125
+
126
+ cond = dpt .asarray (
127
+ [
128
+ [[True , False , False ], [False , True , True ]],
129
+ [[False , True , False ], [True , False , True ]],
130
+ [[False , False , True ], [False , False , True ]],
131
+ [[False , False , False ], [True , False , True ]],
132
+ [[True , True , True ], [True , False , True ]],
133
+ ],
134
+ sycl_queue = q ,
135
+ order = order ,
136
+ )
137
+
138
+ x1 = dpt .full (cond .shape , 2 , dtype = dt , order = order , sycl_queue = q )
139
+ x2 = dpt .full (cond .shape , 3 , dtype = dt , order = order , sycl_queue = q )
140
+
141
+ expected = np .where (dpt .asnumpy (cond ), dpt .asnumpy (x1 ), dpt .asnumpy (x2 ))
142
+ res = dpt .where (cond , x1 , x2 )
143
+
144
+ assert _dtype_all_close (dpt .asnumpy (res ), expected )
0 commit comments