Skip to content

Commit 0183179

Browse files
Add tests for squeeze func
1 parent feda1b1 commit 0183179

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,102 @@ def test_expand_dims_incorrect_tuple():
166166
pytest.raises(np.AxisError, dpt.expand_dims, X, (0, 5))
167167

168168
pytest.raises(ValueError, dpt.expand_dims, X, (1, 1))
169+
170+
171+
def test_squeeze_incorrect_type():
172+
X_list = list([1, 2, 3, 4, 5])
173+
X_tuple = tuple(X_list)
174+
Xnp = np.array(X_list)
175+
176+
pytest.raises(TypeError, dpt.permute_dims, X_list, 1)
177+
pytest.raises(TypeError, dpt.permute_dims, X_tuple, 1)
178+
pytest.raises(TypeError, dpt.permute_dims, Xnp, 1)
179+
180+
181+
def test_squeeze_0d():
182+
try:
183+
q = dpctl.SyclQueue()
184+
except dpctl.SyclQueueCreationError:
185+
pytest.skip("Queue could not be created")
186+
187+
Xnp = np.array(1)
188+
X = dpt.asarray(Xnp, sycl_queue=q)
189+
Y = dpt.squeeze(X)
190+
Ynp = Xnp.squeeze()
191+
assert_array_equal(Ynp, dpt.asnumpy(Y))
192+
193+
Y = dpt.squeeze(X, 0)
194+
Ynp = Xnp.squeeze(0)
195+
assert_array_equal(Ynp, dpt.asnumpy(Y))
196+
197+
Y = dpt.squeeze(X, (0))
198+
Ynp = Xnp.squeeze((0))
199+
assert_array_equal(Ynp, dpt.asnumpy(Y))
200+
201+
Y = dpt.squeeze(X, -1)
202+
Ynp = Xnp.squeeze(-1)
203+
assert_array_equal(Ynp, dpt.asnumpy(Y))
204+
205+
pytest.raises(np.AxisError, dpt.squeeze, X, 1)
206+
pytest.raises(np.AxisError, dpt.squeeze, X, -2)
207+
pytest.raises(np.AxisError, dpt.squeeze, X, (1))
208+
pytest.raises(np.AxisError, dpt.squeeze, X, (-2))
209+
pytest.raises(ValueError, dpt.squeeze, X, (0, 0))
210+
211+
212+
@pytest.mark.parametrize(
213+
"shapes",
214+
[
215+
(0),
216+
(1),
217+
(1, 2),
218+
(2, 1),
219+
(1, 1),
220+
(2, 2),
221+
(1, 0),
222+
(0, 1),
223+
(1, 2, 1),
224+
(2, 1, 2),
225+
(2, 2, 2),
226+
(1, 1, 1),
227+
(1, 0, 1),
228+
(0, 1, 0),
229+
],
230+
)
231+
def test_squeeze_without_axes(shapes):
232+
try:
233+
q = dpctl.SyclQueue()
234+
except dpctl.SyclQueueCreationError:
235+
pytest.skip("Queue could not be created")
236+
237+
Xnp = np.empty(shapes)
238+
X = dpt.asarray(Xnp, sycl_queue=q)
239+
Y = dpt.squeeze(X)
240+
Ynp = Xnp.squeeze()
241+
assert_array_equal(Ynp, dpt.asnumpy(Y))
242+
243+
244+
@pytest.mark.parametrize("axes", [0, 2, (0), (2), (0, 2)])
245+
def test_squeeze_axes_arg(axes):
246+
try:
247+
q = dpctl.SyclQueue()
248+
except dpctl.SyclQueueCreationError:
249+
pytest.skip("Queue could not be created")
250+
251+
Xnp = np.array([[[1], [2], [3]]])
252+
X = dpt.asarray(Xnp, sycl_queue=q)
253+
Y = dpt.squeeze(X, axes)
254+
Ynp = Xnp.squeeze(axes)
255+
assert_array_equal(Ynp, dpt.asnumpy(Y))
256+
257+
258+
@pytest.mark.parametrize("axes", [1, -2, (1), (-2), (0, 0), (1, 1)])
259+
def test_squeeze_axes_arg_error(axes):
260+
try:
261+
q = dpctl.SyclQueue()
262+
except dpctl.SyclQueueCreationError:
263+
pytest.skip("Queue could not be created")
264+
265+
Xnp = np.array([[[1], [2], [3]]])
266+
X = dpt.asarray(Xnp, sycl_queue=q)
267+
pytest.raises(ValueError, dpt.squeeze, X, axes)

0 commit comments

Comments
 (0)