Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow callables to .drop_vars #8511

Merged
merged 6 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ New Features
- :py:meth:`~xarray.DataArray.rank` now operates on dask-backed arrays, assuming
the core dim has exactly one chunk. (:pull:`8475`).
By `Maximilian Roos <https://github.com/max-sixty>`_.
- :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` allow passing a
callable, similar to :py:meth:`Dataset.where` & :py:meth:`Dataset.sortby` & others.
(:pull:`8511`).
By `Maximilian Roos <https://github.com/max-sixty>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
17 changes: 14 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3041,16 +3041,17 @@ def T(self) -> Self:

def drop_vars(
self,
names: Hashable | Iterable[Hashable],
names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]],
*,
errors: ErrorOptions = "raise",
) -> Self:
"""Returns an array with dropped variables.

Parameters
----------
names : Hashable or iterable of Hashable
Name(s) of variables to drop.
names : Hashable or iterable of Hashable or Callable
Name(s) of variables to drop. If a Callable, this object is passed as its
only argument and its result is used.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', raises a ValueError if any of the variables
passed are not in the dataset. If 'ignore', any given names that are in the
Expand Down Expand Up @@ -3100,7 +3101,17 @@ def drop_vars(
[ 6, 7, 8],
[ 9, 10, 11]])
Dimensions without coordinates: x, y

>>> da.drop_vars(lambda x: x.coords)
<xarray.DataArray (x: 4, y: 3)>
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
Dimensions without coordinates: x, y
"""
if callable(names):
names = names(self)
ds = self._to_temp_dataset().drop_vars(names, errors=errors)
return self._from_temp_dataset(ds)

Expand Down
49 changes: 33 additions & 16 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5723,16 +5723,17 @@ def _assert_all_in_dataset(

def drop_vars(
self,
names: Hashable | Iterable[Hashable],
names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]],
*,
errors: ErrorOptions = "raise",
) -> Self:
"""Drop variables from this dataset.

Parameters
----------
names : hashable or iterable of hashable
Name(s) of variables to drop.
names : Hashable or iterable of Hashable or Callable
Name(s) of variables to drop. If a Callable, this object is passed as its
only argument and its result is used.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', raises a ValueError if any of the variables
passed are not in the dataset. If 'ignore', any given names that are in the
Expand Down Expand Up @@ -5774,7 +5775,7 @@ def drop_vars(
humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8

# Drop the 'humidity' variable
Drop the 'humidity' variable

>>> dataset.drop_vars(["humidity"])
<xarray.Dataset>
Expand All @@ -5787,7 +5788,7 @@ def drop_vars(
temperature (time, latitude, longitude) float64 25.5 26.3 27.1 28.0
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8

# Drop the 'humidity', 'temperature' variables
Drop the 'humidity', 'temperature' variables

>>> dataset.drop_vars(["humidity", "temperature"])
<xarray.Dataset>
Expand All @@ -5799,7 +5800,18 @@ def drop_vars(
Data variables:
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8

# Attempt to drop non-existent variable with errors="ignore"
Drop all indexes

>>> dataset.drop_vars(lambda x: x.indexes)
<xarray.Dataset>
Dimensions: (time: 1, latitude: 2, longitude: 2)
Dimensions without coordinates: time, latitude, longitude
Data variables:
temperature (time, latitude, longitude) float64 25.5 26.3 27.1 28.0
humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8

Attempt to drop non-existent variable with errors="ignore"

>>> dataset.drop_vars(["pressure"], errors="ignore")
<xarray.Dataset>
Expand All @@ -5813,7 +5825,7 @@ def drop_vars(
humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8

# Attempt to drop non-existent variable with errors="raise"
Attempt to drop non-existent variable with errors="raise"

>>> dataset.drop_vars(["pressure"], errors="raise")
Traceback (most recent call last):
Expand All @@ -5833,36 +5845,38 @@ def drop_vars(
DataArray.drop_vars

"""
if callable(names):
names = names(self)
# the Iterable check is required for mypy
if is_scalar(names) or not isinstance(names, Iterable):
names = {names}
names_set = {names}
else:
names = set(names)
names_set = set(names)
if errors == "raise":
self._assert_all_in_dataset(names)
self._assert_all_in_dataset(names_set)

# GH6505
other_names = set()
for var in names:
for var in names_set:
maybe_midx = self._indexes.get(var, None)
if isinstance(maybe_midx, PandasMultiIndex):
idx_coord_names = set(maybe_midx.index.names + [maybe_midx.dim])
idx_other_names = idx_coord_names - set(names)
idx_other_names = idx_coord_names - set(names_set)
other_names.update(idx_other_names)
if other_names:
names |= set(other_names)
names_set |= set(other_names)
warnings.warn(
f"Deleting a single level of a MultiIndex is deprecated. Previously, this deleted all levels of a MultiIndex. "
f"Please also drop the following variables: {other_names!r} to avoid an error in the future.",
DeprecationWarning,
stacklevel=2,
)

assert_no_index_corrupted(self.xindexes, names)
assert_no_index_corrupted(self.xindexes, names_set)

variables = {k: v for k, v in self._variables.items() if k not in names}
variables = {k: v for k, v in self._variables.items() if k not in names_set}
coord_names = {k for k in self._coord_names if k in variables}
indexes = {k: v for k, v in self._indexes.items() if k not in names}
indexes = {k: v for k, v in self._indexes.items() if k not in names_set}
return self._replace_with_new_dims(
variables, coord_names=coord_names, indexes=indexes
)
Expand Down Expand Up @@ -5962,6 +5976,9 @@ def drop(
PendingDeprecationWarning,
stacklevel=2,
)
# for mypy
if is_scalar(labels):
labels = [labels]
return self.drop_vars(labels, errors=errors)
if dim is not None:
warnings.warn(
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _drop_coords(self) -> T_Xarray:
obj = self._obj
for k, v in obj.coords.items():
if k != self._dim and self._dim in v.dims:
obj = obj.drop_vars(k)
obj = obj.drop_vars([k])
return obj

def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray:
Expand Down Expand Up @@ -244,7 +244,7 @@ def map(
# dimension, then we need to do so before we can rename the proxy
# dimension we used.
if self._dim in combined.coords:
combined = combined.drop_vars(self._dim)
combined = combined.drop_vars([self._dim])

if RESAMPLE_DIM in combined.dims:
combined = combined.rename({RESAMPLE_DIM: self._dim})
Expand Down
8 changes: 8 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2652,6 +2652,14 @@ def test_drop_coordinates(self) -> None:
actual = renamed.drop_vars("foo", errors="ignore")
assert_identical(actual, renamed)

# Check that DataArray.drop_vars accepts a callable in place of explicit names.
def test_drop_vars_callable(self) -> None:
# Random 2x3 array with coordinates on both dimensions "x" and "y".
A = DataArray(
np.random.randn(2, 3), dims=["x", "y"], coords={"x": [1, 2], "y": [3, 4, 5]}
)
# Reference result: drop both coordinates by naming them explicitly.
expected = A.drop_vars(["x", "y"])
# The callable is invoked with the array itself and its return value is
# used as the names to drop; here it returns the array's indexes, which
# this test expects to correspond to the "x" and "y" coordinates.
actual = A.drop_vars(lambda x: x.indexes)
assert_identical(expected, actual)

def test_drop_multiindex_level(self) -> None:
# GH6505
expected = self.mda.drop_vars(["x", "level_1", "level_2"])
Expand Down
Loading