Skip to content

Commit 02288b4

Browse files
max-sixtydcherian
authored andcommitted
Allow ellipsis (...) in transpose (#3421)
* infix_dims function * implement transpose with ellipsis * also infix in dataarray * check errors centrally, remove boilerplate from transpose methods * whatsnew * docs * remove old comments * generator->iterator * test for differently ordered dimensions
1 parent fb0cf7b commit 02288b4

File tree

12 files changed

+100
-12
lines changed

12 files changed

+100
-12
lines changed

doc/reshaping.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ Reordering dimensions
1818
---------------------
1919

2020
To reorder dimensions on a :py:class:`~xarray.DataArray` or across all variables
21-
on a :py:class:`~xarray.Dataset`, use :py:meth:`~xarray.DataArray.transpose`:
21+
on a :py:class:`~xarray.Dataset`, use :py:meth:`~xarray.DataArray.transpose`. An
22+
ellipsis (`...`) can be use to represent all other dimensions:
2223

2324
.. ipython:: python
2425
2526
ds = xr.Dataset({'foo': (('x', 'y', 'z'), [[[42]]]), 'bar': (('y', 'z'), [[24]])})
2627
ds.transpose('y', 'z', 'x')
28+
ds.transpose(..., 'x') # equivalent
2729
ds.transpose() # reverses all dimensions
2830
2931
Expand and squeeze dimensions

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ Breaking changes
2525

2626
New Features
2727
~~~~~~~~~~~~
28+
- :py:meth:`Dataset.transpose` and :py:meth:`DataArray.transpose` now support an ellipsis (`...`)
29+
to represent all 'other' dimensions. For example, to move one dimension to the front,
30+
use `.transpose('x', ...)`. (:pull:`3421`)
31+
By `Maximilian Roos <https://github.com/max-sixty>`_
2832
- Changed `xr.ALL_DIMS` to equal python's `Ellipsis` (`...`), and changed internal usages to use
2933
`...` directly. As before, you can use this to instruct a `groupby` operation
3034
to reduce over all dimensions. While we have no plans to remove `xr.ALL_DIMS`, we suggest

setup.cfg

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,7 @@ tag_prefix = v
117117
parentdir_prefix = xarray-
118118

119119
[aliases]
120-
test = pytest
120+
test = pytest
121+
122+
[pytest-watch]
123+
nobeep = True

xarray/core/dataarray.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,12 +1863,7 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = None) -> "DataArra
18631863
Dataset.transpose
18641864
"""
18651865
if dims:
1866-
if set(dims) ^ set(self.dims):
1867-
raise ValueError(
1868-
"arguments to transpose (%s) must be "
1869-
"permuted array dimensions (%s)" % (dims, tuple(self.dims))
1870-
)
1871-
1866+
dims = tuple(utils.infix_dims(dims, self.dims))
18721867
variable = self.variable.transpose(*dims)
18731868
if transpose_coords:
18741869
coords: Dict[Hashable, Variable] = {}

xarray/core/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3712,14 +3712,14 @@ def transpose(self, *dims: Hashable) -> "Dataset":
37123712
DataArray.transpose
37133713
"""
37143714
if dims:
3715-
if set(dims) ^ set(self.dims):
3715+
if set(dims) ^ set(self.dims) and ... not in dims:
37163716
raise ValueError(
37173717
"arguments to transpose (%s) must be "
37183718
"permuted dataset dimensions (%s)" % (dims, tuple(self.dims))
37193719
)
37203720
ds = self.copy()
37213721
for name, var in self._variables.items():
3722-
var_dims = tuple(dim for dim in dims if dim in var.dims)
3722+
var_dims = tuple(dim for dim in dims if dim in (var.dims + (...,)))
37233723
ds._variables[name] = var.transpose(*var_dims)
37243724
return ds
37253725

xarray/core/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
AbstractSet,
1111
Any,
1212
Callable,
13+
Collection,
1314
Container,
1415
Dict,
1516
Hashable,
@@ -660,6 +661,30 @@ def __len__(self) -> int:
660661
return len(self._data) - num_hidden
661662

662663

664+
def infix_dims(dims_supplied: Collection, dims_all: Collection) -> Iterator:
665+
"""
666+
Resolves a supplied list containing an ellispsis representing other items, to
667+
a generator with the 'realized' list of all items
668+
"""
669+
if ... in dims_supplied:
670+
if len(set(dims_all)) != len(dims_all):
671+
raise ValueError("Cannot use ellipsis with repeated dims")
672+
if len([d for d in dims_supplied if d == ...]) > 1:
673+
raise ValueError("More than one ellipsis supplied")
674+
other_dims = [d for d in dims_all if d not in dims_supplied]
675+
for d in dims_supplied:
676+
if d == ...:
677+
yield from other_dims
678+
else:
679+
yield d
680+
else:
681+
if set(dims_supplied) ^ set(dims_all):
682+
raise ValueError(
683+
f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
684+
)
685+
yield from dims_supplied
686+
687+
663688
def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
664689
""" Get an new dimension name based on new_dim, that is not used in dims.
665690
If the same name exists, we add an underscore(s) in the head.

xarray/core/variable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
OrderedSet,
2626
decode_numpy_dict_values,
2727
either_dict_or_kwargs,
28+
infix_dims,
2829
ensure_us_time_resolution,
2930
)
3031

@@ -1228,6 +1229,7 @@ def transpose(self, *dims) -> "Variable":
12281229
"""
12291230
if len(dims) == 0:
12301231
dims = self.dims[::-1]
1232+
dims = tuple(infix_dims(dims, self.dims))
12311233
axes = self.get_axis_num(dims)
12321234
if len(dims) < 2: # no need to transpose if only one dimension
12331235
return self.copy(deep=False)

xarray/tests/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,18 +158,21 @@ def source_ndarray(array):
158158

159159

160160
def assert_equal(a, b):
161+
__tracebackhide__ = True
161162
xarray.testing.assert_equal(a, b)
162163
xarray.testing._assert_internal_invariants(a)
163164
xarray.testing._assert_internal_invariants(b)
164165

165166

166167
def assert_identical(a, b):
168+
__tracebackhide__ = True
167169
xarray.testing.assert_identical(a, b)
168170
xarray.testing._assert_internal_invariants(a)
169171
xarray.testing._assert_internal_invariants(b)
170172

171173

172174
def assert_allclose(a, b, **kwargs):
175+
__tracebackhide__ = True
173176
xarray.testing.assert_allclose(a, b, **kwargs)
174177
xarray.testing._assert_internal_invariants(a)
175178
xarray.testing._assert_internal_invariants(b)

xarray/tests/test_dataarray.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,6 +2068,10 @@ def test_transpose(self):
20682068
)
20692069
assert_equal(expected, actual)
20702070

2071+
# same as previous but with ellipsis
2072+
actual = da.transpose("z", ..., "x", transpose_coords=True)
2073+
assert_equal(expected, actual)
2074+
20712075
with pytest.raises(ValueError):
20722076
da.transpose("x", "y")
20732077

xarray/tests/test_dataset.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4675,6 +4675,10 @@ def test_dataset_transpose(self):
46754675
)
46764676
assert_identical(expected, actual)
46774677

4678+
actual = ds.transpose(...)
4679+
expected = ds
4680+
assert_identical(expected, actual)
4681+
46784682
actual = ds.transpose("x", "y")
46794683
expected = ds.apply(lambda x: x.transpose("x", "y", transpose_coords=True))
46804684
assert_identical(expected, actual)
@@ -4690,13 +4694,32 @@ def test_dataset_transpose(self):
46904694
expected_dims = tuple(d for d in new_order if d in ds[k].dims)
46914695
assert actual[k].dims == expected_dims
46924696

4693-
with raises_regex(ValueError, "arguments to transpose"):
4697+
# same as above but with ellipsis
4698+
new_order = ("dim2", "dim3", "dim1", "time")
4699+
actual = ds.transpose("dim2", "dim3", ...)
4700+
for k in ds.variables:
4701+
expected_dims = tuple(d for d in new_order if d in ds[k].dims)
4702+
assert actual[k].dims == expected_dims
4703+
4704+
with raises_regex(ValueError, "permuted"):
46944705
ds.transpose("dim1", "dim2", "dim3")
4695-
with raises_regex(ValueError, "arguments to transpose"):
4706+
with raises_regex(ValueError, "permuted"):
46964707
ds.transpose("dim1", "dim2", "dim3", "time", "extra_dim")
46974708

46984709
assert "T" not in dir(ds)
46994710

4711+
def test_dataset_ellipsis_transpose_different_ordered_vars(self):
4712+
# https://github.com/pydata/xarray/issues/1081#issuecomment-544350457
4713+
ds = Dataset(
4714+
dict(
4715+
a=(("w", "x", "y", "z"), np.ones((2, 3, 4, 5))),
4716+
b=(("x", "w", "y", "z"), np.zeros((3, 2, 4, 5))),
4717+
)
4718+
)
4719+
result = ds.transpose(..., "z", "y")
4720+
assert list(result["a"].dims) == list("wxzy")
4721+
assert list(result["b"].dims) == list("xwzy")
4722+
47004723
def test_dataset_retains_period_index_on_transpose(self):
47014724

47024725
ds = create_test_data()

xarray/tests/test_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,27 @@ def test_either_dict_or_kwargs():
275275

276276
with pytest.raises(ValueError, match=r"foo"):
277277
result = either_dict_or_kwargs(dict(a=1), dict(a=1), "foo")
278+
279+
280+
@pytest.mark.parametrize(
281+
["supplied", "all_", "expected"],
282+
[
283+
(list("abc"), list("abc"), list("abc")),
284+
(["a", ..., "c"], list("abc"), list("abc")),
285+
(["a", ...], list("abc"), list("abc")),
286+
(["c", ...], list("abc"), list("cab")),
287+
([..., "b"], list("abc"), list("acb")),
288+
([...], list("abc"), list("abc")),
289+
],
290+
)
291+
def test_infix_dims(supplied, all_, expected):
292+
result = list(utils.infix_dims(supplied, all_))
293+
assert result == expected
294+
295+
296+
@pytest.mark.parametrize(
297+
["supplied", "all_"], [([..., ...], list("abc")), ([...], list("aac"))]
298+
)
299+
def test_infix_dims_errors(supplied, all_):
300+
with pytest.raises(ValueError):
301+
list(utils.infix_dims(supplied, all_))

xarray/tests/test_variable.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,9 @@ def test_transpose(self):
12801280
w2 = Variable(["d", "b", "c", "a"], np.einsum("abcd->dbca", x))
12811281
assert w2.shape == (5, 3, 4, 2)
12821282
assert_identical(w2, w.transpose("d", "b", "c", "a"))
1283+
assert_identical(w2, w.transpose("d", ..., "a"))
1284+
assert_identical(w2, w.transpose("d", "b", "c", ...))
1285+
assert_identical(w2, w.transpose(..., "b", "c", "a"))
12831286
assert_identical(w, w2.transpose("a", "b", "c", "d"))
12841287
w3 = Variable(["b", "c", "d", "a"], np.einsum("abcd->bcda", x))
12851288
assert_identical(w, w3.transpose("a", "b", "c", "d"))

0 commit comments

Comments
 (0)