Skip to content

Commit aa7071f

Browse files
Implementation of squeeze function (#790)
* Add squeeze func * Add tests for squeeze func
1 parent e30e7a4 commit aa7071f

File tree

3 files changed

+139
-1
lines changed

3 files changed

+139
-1
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from dpctl.tensor._ctors import asarray, empty
2626
from dpctl.tensor._device import Device
2727
from dpctl.tensor._dlpack import from_dlpack
28-
from dpctl.tensor._manipulation_functions import expand_dims, permute_dims
28+
from dpctl.tensor._manipulation_functions import (
29+
expand_dims,
30+
permute_dims,
31+
squeeze,
32+
)
2933
from dpctl.tensor._reshape import reshape
3034
from dpctl.tensor._usmarray import usm_ndarray
3135

@@ -39,6 +43,7 @@
3943
"reshape",
4044
"permute_dims",
4145
"expand_dims",
46+
"squeeze",
4247
"from_numpy",
4348
"to_numpy",
4449
"asnumpy",

dpctl/tensor/_manipulation_functions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,37 @@ def expand_dims(X, axes):
6868
shape = tuple(1 if ax in axes else next(shape_it) for ax in range(out_ndim))
6969

7070
return dpt.reshape(X, shape)
71+
72+
73+
def squeeze(X, axes=None):
74+
"""
75+
squeeze(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
76+
77+
Removes singleton dimensions (axes) from X; returns a view, if possible,
78+
a copy otherwise, but with all or a subset of the dimensions
79+
of length 1 removed.
80+
"""
81+
if not isinstance(X, dpt.usm_ndarray):
82+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
83+
X_shape = X.shape
84+
if axes is not None:
85+
if not isinstance(axes, (tuple, list)):
86+
axes = (axes,)
87+
axes = normalize_axis_tuple(axes, X.ndim if X.ndim != 0 else X.ndim + 1)
88+
new_shape = []
89+
for i, x in enumerate(X_shape):
90+
if i not in axes:
91+
new_shape.append(x)
92+
else:
93+
if x != 1:
94+
raise ValueError(
95+
"Cannot select an axis to squeeze out "
96+
"which has size not equal to one."
97+
)
98+
new_shape = tuple(new_shape)
99+
else:
100+
new_shape = tuple(axis for axis in X_shape if axis != 1)
101+
if new_shape == X.shape:
102+
return X
103+
else:
104+
return dpt.reshape(X, new_shape)

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,102 @@ def test_expand_dims_incorrect_tuple():
166166
pytest.raises(np.AxisError, dpt.expand_dims, X, (0, 5))
167167

168168
pytest.raises(ValueError, dpt.expand_dims, X, (1, 1))
169+
170+
171+
def test_squeeze_incorrect_type():
172+
X_list = list([1, 2, 3, 4, 5])
173+
X_tuple = tuple(X_list)
174+
Xnp = np.array(X_list)
175+
176+
pytest.raises(TypeError, dpt.permute_dims, X_list, 1)
177+
pytest.raises(TypeError, dpt.permute_dims, X_tuple, 1)
178+
pytest.raises(TypeError, dpt.permute_dims, Xnp, 1)
179+
180+
181+
def test_squeeze_0d():
182+
try:
183+
q = dpctl.SyclQueue()
184+
except dpctl.SyclQueueCreationError:
185+
pytest.skip("Queue could not be created")
186+
187+
Xnp = np.array(1)
188+
X = dpt.asarray(Xnp, sycl_queue=q)
189+
Y = dpt.squeeze(X)
190+
Ynp = Xnp.squeeze()
191+
assert_array_equal(Ynp, dpt.asnumpy(Y))
192+
193+
Y = dpt.squeeze(X, 0)
194+
Ynp = Xnp.squeeze(0)
195+
assert_array_equal(Ynp, dpt.asnumpy(Y))
196+
197+
Y = dpt.squeeze(X, (0))
198+
Ynp = Xnp.squeeze((0))
199+
assert_array_equal(Ynp, dpt.asnumpy(Y))
200+
201+
Y = dpt.squeeze(X, -1)
202+
Ynp = Xnp.squeeze(-1)
203+
assert_array_equal(Ynp, dpt.asnumpy(Y))
204+
205+
pytest.raises(np.AxisError, dpt.squeeze, X, 1)
206+
pytest.raises(np.AxisError, dpt.squeeze, X, -2)
207+
pytest.raises(np.AxisError, dpt.squeeze, X, (1))
208+
pytest.raises(np.AxisError, dpt.squeeze, X, (-2))
209+
pytest.raises(ValueError, dpt.squeeze, X, (0, 0))
210+
211+
212+
@pytest.mark.parametrize(
213+
"shapes",
214+
[
215+
(0),
216+
(1),
217+
(1, 2),
218+
(2, 1),
219+
(1, 1),
220+
(2, 2),
221+
(1, 0),
222+
(0, 1),
223+
(1, 2, 1),
224+
(2, 1, 2),
225+
(2, 2, 2),
226+
(1, 1, 1),
227+
(1, 0, 1),
228+
(0, 1, 0),
229+
],
230+
)
231+
def test_squeeze_without_axes(shapes):
232+
try:
233+
q = dpctl.SyclQueue()
234+
except dpctl.SyclQueueCreationError:
235+
pytest.skip("Queue could not be created")
236+
237+
Xnp = np.empty(shapes)
238+
X = dpt.asarray(Xnp, sycl_queue=q)
239+
Y = dpt.squeeze(X)
240+
Ynp = Xnp.squeeze()
241+
assert_array_equal(Ynp, dpt.asnumpy(Y))
242+
243+
244+
@pytest.mark.parametrize("axes", [0, 2, (0), (2), (0, 2)])
245+
def test_squeeze_axes_arg(axes):
246+
try:
247+
q = dpctl.SyclQueue()
248+
except dpctl.SyclQueueCreationError:
249+
pytest.skip("Queue could not be created")
250+
251+
Xnp = np.array([[[1], [2], [3]]])
252+
X = dpt.asarray(Xnp, sycl_queue=q)
253+
Y = dpt.squeeze(X, axes)
254+
Ynp = Xnp.squeeze(axes)
255+
assert_array_equal(Ynp, dpt.asnumpy(Y))
256+
257+
258+
@pytest.mark.parametrize("axes", [1, -2, (1), (-2), (0, 0), (1, 1)])
259+
def test_squeeze_axes_arg_error(axes):
260+
try:
261+
q = dpctl.SyclQueue()
262+
except dpctl.SyclQueueCreationError:
263+
pytest.skip("Queue could not be created")
264+
265+
Xnp = np.array([[[1], [2], [3]]])
266+
X = dpt.asarray(Xnp, sycl_queue=q)
267+
pytest.raises(ValueError, dpt.squeeze, X, axes)

0 commit comments

Comments
 (0)