@@ -166,3 +166,102 @@ def test_expand_dims_incorrect_tuple():
166
166
pytest .raises (np .AxisError , dpt .expand_dims , X , (0 , 5 ))
167
167
168
168
pytest .raises (ValueError , dpt .expand_dims , X , (1 , 1 ))
169
+
170
+
171
+ def test_squeeze_incorrect_type ():
172
+ X_list = list ([1 , 2 , 3 , 4 , 5 ])
173
+ X_tuple = tuple (X_list )
174
+ Xnp = np .array (X_list )
175
+
176
+ pytest .raises (TypeError , dpt .permute_dims , X_list , 1 )
177
+ pytest .raises (TypeError , dpt .permute_dims , X_tuple , 1 )
178
+ pytest .raises (TypeError , dpt .permute_dims , Xnp , 1 )
179
+
180
+
181
+ def test_squeeze_0d ():
182
+ try :
183
+ q = dpctl .SyclQueue ()
184
+ except dpctl .SyclQueueCreationError :
185
+ pytest .skip ("Queue could not be created" )
186
+
187
+ Xnp = np .array (1 )
188
+ X = dpt .asarray (Xnp , sycl_queue = q )
189
+ Y = dpt .squeeze (X )
190
+ Ynp = Xnp .squeeze ()
191
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
192
+
193
+ Y = dpt .squeeze (X , 0 )
194
+ Ynp = Xnp .squeeze (0 )
195
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
196
+
197
+ Y = dpt .squeeze (X , (0 ))
198
+ Ynp = Xnp .squeeze ((0 ))
199
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
200
+
201
+ Y = dpt .squeeze (X , - 1 )
202
+ Ynp = Xnp .squeeze (- 1 )
203
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
204
+
205
+ pytest .raises (np .AxisError , dpt .squeeze , X , 1 )
206
+ pytest .raises (np .AxisError , dpt .squeeze , X , - 2 )
207
+ pytest .raises (np .AxisError , dpt .squeeze , X , (1 ))
208
+ pytest .raises (np .AxisError , dpt .squeeze , X , (- 2 ))
209
+ pytest .raises (ValueError , dpt .squeeze , X , (0 , 0 ))
210
+
211
+
212
+ @pytest .mark .parametrize (
213
+ "shapes" ,
214
+ [
215
+ (0 ),
216
+ (1 ),
217
+ (1 , 2 ),
218
+ (2 , 1 ),
219
+ (1 , 1 ),
220
+ (2 , 2 ),
221
+ (1 , 0 ),
222
+ (0 , 1 ),
223
+ (1 , 2 , 1 ),
224
+ (2 , 1 , 2 ),
225
+ (2 , 2 , 2 ),
226
+ (1 , 1 , 1 ),
227
+ (1 , 0 , 1 ),
228
+ (0 , 1 , 0 ),
229
+ ],
230
+ )
231
+ def test_squeeze_without_axes (shapes ):
232
+ try :
233
+ q = dpctl .SyclQueue ()
234
+ except dpctl .SyclQueueCreationError :
235
+ pytest .skip ("Queue could not be created" )
236
+
237
+ Xnp = np .empty (shapes )
238
+ X = dpt .asarray (Xnp , sycl_queue = q )
239
+ Y = dpt .squeeze (X )
240
+ Ynp = Xnp .squeeze ()
241
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
242
+
243
+
244
+ @pytest .mark .parametrize ("axes" , [0 , 2 , (0 ), (2 ), (0 , 2 )])
245
+ def test_squeeze_axes_arg (axes ):
246
+ try :
247
+ q = dpctl .SyclQueue ()
248
+ except dpctl .SyclQueueCreationError :
249
+ pytest .skip ("Queue could not be created" )
250
+
251
+ Xnp = np .array ([[[1 ], [2 ], [3 ]]])
252
+ X = dpt .asarray (Xnp , sycl_queue = q )
253
+ Y = dpt .squeeze (X , axes )
254
+ Ynp = Xnp .squeeze (axes )
255
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
256
+
257
+
258
+ @pytest .mark .parametrize ("axes" , [1 , - 2 , (1 ), (- 2 ), (0 , 0 ), (1 , 1 )])
259
+ def test_squeeze_axes_arg_error (axes ):
260
+ try :
261
+ q = dpctl .SyclQueue ()
262
+ except dpctl .SyclQueueCreationError :
263
+ pytest .skip ("Queue could not be created" )
264
+
265
+ Xnp = np .array ([[[1 ], [2 ], [3 ]]])
266
+ X = dpt .asarray (Xnp , sycl_queue = q )
267
+ pytest .raises (ValueError , dpt .squeeze , X , axes )
0 commit comments