Skip to content

Commit c95040b

Browse files
committed
WIP
1 parent 0f4bd65 commit c95040b

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,34 @@ def test_expand_dims(x, axis):
167167
)
168168
)
169169
def test_expand_dims_tuples(x, axes):
170-
print(x.shape, axes)
170+
x_ndim, y_ndim = x.ndim, x.ndim + len(axes)
171+
172+
# normalize
173+
n_axes = tuple(ax + y_ndim if ax < 0 else ax for ax in axes)
174+
unique_axes = set(n_axes)
175+
176+
# print(x.shape, axes, n_axes)
177+
178+
if any(ax < 0 or ax >= y_ndim for ax in n_axes) or len(n_axes) != len(unique_axes):
179+
# print("\t raises")
180+
with pytest.raises((IndexError, ValueError)):
181+
xp.expand_dims(x, axis=axes)
182+
return
183+
184+
185+
repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axes!r})")
186+
try:
187+
y = xp.expand_dims(x, axis=axes)
188+
189+
ye = x
190+
for ax in sorted(n_axes):
191+
ye = xp.expand_dims(ye, axis=ax)
192+
193+
assert y.shape == ye.shape
194+
195+
except Exception as exc:
196+
ph.add_note(exc, repro_snippet)
197+
raise
171198

172199

173200
@pytest.mark.min_version("2023.12")

0 commit comments

Comments
 (0)