Skip to content

Commit

Permalink
Closes #4647 DataArray transpose inconsistent with Dataset Ellipsis u…
Browse files Browse the repository at this point in the history
…sage (#4767)

- Add missing_dims parameter to transpose to mimic isel behavior
- Add missing_dims to infix_dims to make function consistent
across different methods.

Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
  • Loading branch information
mesejo and max-sixty authored Jan 5, 2021
1 parent 7298df0 commit 31d540f
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 26 deletions.
2 changes: 1 addition & 1 deletion doc/internals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,4 @@ re-open it directly with Zarr:
zgroup = zarr.open("rasm.zarr")
print(zgroup.tree())
dict(zgroup["Tair"].attrs)
dict(zgroup["Tair"].attrs)
2 changes: 1 addition & 1 deletion doc/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -955,4 +955,4 @@ One can also make line plots with multidimensional coordinates. In this case, ``
f, ax = plt.subplots(2, 1)
da.plot.line(x="lon", hue="y", ax=ax[0])
@savefig plotting_example_2d_hue_xy.png
da.plot.line(x="lon", hue="x", ax=ax[1])
da.plot.line(x="lon", hue="x", ax=ax[1])
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Bug fixes
- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`).
By `Alessandro Amici <https://github.com/alexamici>`_
- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling <https://github.com/illviljan>`_.
- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo <https://github.com/mesejo>`_.

Documentation
~~~~~~~~~~~~~
Expand All @@ -76,6 +77,7 @@ Internal Changes
- Run the tests in parallel using pytest-xdist (:pull:`4694`).

By `Justus Magin <https://github.com/keewis>`_ and `Mathias Hauser <https://github.com/mathause>`_.

- Replace all usages of ``assert x.identical(y)`` with ``assert_identical(x, y)``
for clearer error messages.
(:pull:`4752`);
Expand Down
15 changes: 13 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2120,7 +2120,12 @@ def to_unstacked_dataset(self, dim, level=0):
# unstacked dataset
return Dataset(data_dict)

def transpose(self, *dims: Hashable, transpose_coords: bool = True) -> "DataArray":
def transpose(
self,
*dims: Hashable,
transpose_coords: bool = True,
missing_dims: str = "raise",
) -> "DataArray":
"""Return a new DataArray object with transposed dimensions.
Parameters
Expand All @@ -2130,6 +2135,12 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = True) -> "DataArra
dimensions to this order.
transpose_coords : bool, default: True
If True, also transpose the coordinates of this DataArray.
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
What to do if dimensions that should be selected from are not present in the
DataArray:
- "raise": raise an exception
- "warning": raise a warning, and ignore the missing dimensions
- "ignore": ignore the missing dimensions
Returns
-------
Expand All @@ -2148,7 +2159,7 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = True) -> "DataArra
Dataset.transpose
"""
if dims:
dims = tuple(utils.infix_dims(dims, self.dims))
dims = tuple(utils.infix_dims(dims, self.dims, missing_dims))
variable = self.variable.transpose(*dims)
if transpose_coords:
coords: Dict[Hashable, Variable] = {}
Expand Down
64 changes: 55 additions & 9 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,28 +744,32 @@ def __len__(self) -> int:
return len(self._data) - num_hidden


def infix_dims(dims_supplied: Collection, dims_all: Collection) -> Iterator:
def infix_dims(
dims_supplied: Collection, dims_all: Collection, missing_dims: str = "raise"
) -> Iterator:
"""
Resolves a supplied list containing an ellispsis representing other items, to
Resolves a supplied list containing an ellipsis representing other items, to
a generator with the 'realized' list of all items
"""
if ... in dims_supplied:
if len(set(dims_all)) != len(dims_all):
raise ValueError("Cannot use ellipsis with repeated dims")
if len([d for d in dims_supplied if d == ...]) > 1:
if list(dims_supplied).count(...) > 1:
raise ValueError("More than one ellipsis supplied")
other_dims = [d for d in dims_all if d not in dims_supplied]
for d in dims_supplied:
if d == ...:
existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
for d in existing_dims:
if d is ...:
yield from other_dims
else:
yield d
else:
if set(dims_supplied) ^ set(dims_all):
existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
if set(existing_dims) ^ set(dims_all):
raise ValueError(
f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
)
yield from dims_supplied
yield from existing_dims


def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
Expand Down Expand Up @@ -805,7 +809,7 @@ def drop_dims_from_indexers(
invalid = indexers.keys() - set(dims)
if invalid:
raise ValueError(
f"dimensions {invalid} do not exist. Expected one or more of {dims}"
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
)

return indexers
Expand All @@ -818,7 +822,7 @@ def drop_dims_from_indexers(
invalid = indexers.keys() - set(dims)
if invalid:
warnings.warn(
f"dimensions {invalid} do not exist. Expected one or more of {dims}"
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
)
for key in invalid:
indexers.pop(key)
Expand All @@ -834,6 +838,48 @@ def drop_dims_from_indexers(
)


def drop_missing_dims(
supplied_dims: Collection, dims: Collection, missing_dims: str
) -> Collection:
"""Depending on the setting of missing_dims, drop any dimensions from supplied_dims that
are not present in dims.
Parameters
----------
supplied_dims : dict
dims : sequence
missing_dims : {"raise", "warn", "ignore"}
"""

if missing_dims == "raise":
supplied_dims_set = set(val for val in supplied_dims if val is not ...)
invalid = supplied_dims_set - set(dims)
if invalid:
raise ValueError(
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
)

return supplied_dims

elif missing_dims == "warn":

invalid = set(supplied_dims) - set(dims)
if invalid:
warnings.warn(
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
)

return [val for val in supplied_dims if val in dims or val is ...]

elif missing_dims == "ignore":
return [val for val in supplied_dims if val in dims or val is ...]

else:
raise ValueError(
f"Unrecognised option {missing_dims} for missing_dims argument"
)


class UncachedAccessor:
"""Acts like a property, but on both classes and class instances
Expand Down
23 changes: 14 additions & 9 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,13 +797,13 @@ def test_isel(self):
assert_identical(self.dv[:3, :5], self.dv.isel(x=slice(3), y=slice(5)))
with raises_regex(
ValueError,
r"dimensions {'not_a_dim'} do not exist. Expected "
r"Dimensions {'not_a_dim'} do not exist. Expected "
r"one or more of \('x', 'y'\)",
):
self.dv.isel(not_a_dim=0)
with pytest.warns(
UserWarning,
match=r"dimensions {'not_a_dim'} do not exist. "
match=r"Dimensions {'not_a_dim'} do not exist. "
r"Expected one or more of \('x', 'y'\)",
):
self.dv.isel(not_a_dim=0, missing_dims="warn")
Expand Down Expand Up @@ -2231,9 +2231,21 @@ def test_transpose(self):
actual = da.transpose("z", ..., "x", transpose_coords=True)
assert_equal(expected, actual)

# same as previous but with a missing dimension
actual = da.transpose(
"z", "y", "x", "not_a_dim", transpose_coords=True, missing_dims="ignore"
)
assert_equal(expected, actual)

with pytest.raises(ValueError):
da.transpose("x", "y")

with pytest.raises(ValueError):
da.transpose("not_a_dim", "z", "x", ...)

with pytest.warns(UserWarning):
da.transpose("not_a_dim", "y", "x", ..., missing_dims="warn")

def test_squeeze(self):
assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable)

Expand Down Expand Up @@ -6227,7 +6239,6 @@ def da_dask(seed=123):

@pytest.mark.parametrize("da", ("repeating_ints",), indirect=True)
def test_isin(da):

expected = DataArray(
np.asarray([[0, 0, 0], [1, 0, 0]]),
dims=list("yx"),
Expand Down Expand Up @@ -6277,7 +6288,6 @@ def test_coarsen_keep_attrs():

@pytest.mark.parametrize("da", (1, 2), indirect=True)
def test_rolling_iter(da):

rolling_obj = da.rolling(time=7)
rolling_obj_mean = rolling_obj.mean()

Expand Down Expand Up @@ -6452,7 +6462,6 @@ def test_rolling_construct(center, window):
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@pytest.mark.parametrize("name", ("sum", "mean", "std", "max"))
def test_rolling_reduce(da, center, min_periods, window, name):

if min_periods is not None and window < min_periods:
min_periods = window

Expand Down Expand Up @@ -6491,7 +6500,6 @@ def test_rolling_reduce_nonnumeric(center, min_periods, window, name):


def test_rolling_count_correct():

da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")

kwargs = [
Expand Down Expand Up @@ -6579,7 +6587,6 @@ def test_ndrolling_construct(center, fill_value):
],
)
def test_rolling_keep_attrs(funcname, argument):

attrs_da = {"da_attr": "test"}

data = np.linspace(10, 15, 100)
Expand Down Expand Up @@ -6623,7 +6630,6 @@ def test_rolling_keep_attrs(funcname, argument):


def test_rolling_keep_attrs_deprecated():

attrs_da = {"da_attr": "test"}

data = np.linspace(10, 15, 100)
Expand Down Expand Up @@ -6957,7 +6963,6 @@ def test_rolling_exp(da, dim, window_type, window):

@requires_numbagg
def test_rolling_exp_keep_attrs(da):

attrs = {"attrs": "da"}
da.attrs = attrs

Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,14 +1024,14 @@ def test_isel(self):
data.isel(not_a_dim=slice(0, 2))
with raises_regex(
ValueError,
r"dimensions {'not_a_dim'} do not exist. Expected "
r"Dimensions {'not_a_dim'} do not exist. Expected "
r"one or more of "
r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*",
):
data.isel(not_a_dim=slice(0, 2))
with pytest.warns(
UserWarning,
match=r"dimensions {'not_a_dim'} do not exist. "
match=r"Dimensions {'not_a_dim'} do not exist. "
r"Expected one or more of "
r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*",
):
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,13 +1270,13 @@ def test_isel(self):
assert_identical(v.isel(time=[]), v[[]])
with raises_regex(
ValueError,
r"dimensions {'not_a_dim'} do not exist. Expected one or more of "
r"Dimensions {'not_a_dim'} do not exist. Expected one or more of "
r"\('time', 'x'\)",
):
v.isel(not_a_dim=0)
with pytest.warns(
UserWarning,
match=r"dimensions {'not_a_dim'} do not exist. Expected one or more of "
match=r"Dimensions {'not_a_dim'} do not exist. Expected one or more of "
r"\('time', 'x'\)",
):
v.isel(not_a_dim=0, missing_dims="warn")
Expand Down

0 comments on commit 31d540f

Please sign in to comment.