Commit

support passing a function to combine_attrs (pydata#4896)
* add a test for passing a function to merge_attrs

* support a callable combine_attrs

* also check that callable combine_attrs works with variables

* update the docstrings of merge, concat and combine_*

* also test the other functions that support combine_attrs

* update whats-new.rst [skip-ci]

* add a context kwarg which will be None for now

* fix the bad merge [skip-ci]

* fix the merge and combine tests

* fix the concat tests

* update the docs to account for the context object

* expose Context as part of the public API

* move to the newest section

* fix whats-new.rst
keewis authored Jun 8, 2021
1 parent 9daf9b1 commit e87d65b
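
As an illustration of what this commit enables, here is a minimal sketch of a function that could be passed as ``combine_attrs``. The name ``keep_common_attrs`` and the sample attrs are hypothetical. The sketch assumes attribute values are plain Python objects that compare cleanly with ``==`` (NumPy arrays would need extra handling). Because ``merge_attrs`` in this commit passes the context as a keyword argument, the second parameter is named ``context``.

```python
def keep_common_attrs(attrs_list, context=None):
    """Keep only the keys whose values agree across every attrs dict.

    Hypothetical helper: it follows the calling convention described in the
    updated docstrings (a sequence of attrs dicts plus a context object).
    """
    if not attrs_list:
        return {}
    first, *rest = attrs_list
    return {
        key: value
        for key, value in first.items()
        # assumption: values are plain Python objects, so == yields a bool
        if all(key in other and other[key] == value for other in rest)
    }


# Called directly here, the same way the tests in this commit call their lambdas.
example_attrs = [
    {"units": "K", "source": "model A", "version": 2},
    {"units": "K", "source": "model B", "version": 2},
]
merged = keep_common_attrs(example_attrs, context=None)
print(merged)

# With xarray installed, the function would be passed like this (not run here):
#     xr.merge([ds1, ds2], combine_attrs=keep_common_attrs)
```

The commit message notes that the context is ``None`` for now, so a callback written against this version should not rely on it being set.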
Showing 9 changed files with 81 additions and 19 deletions.
1 change: 1 addition & 0 deletions doc/api.rst
@@ -899,6 +899,7 @@ Advanced API
Variable
IndexVariable
as_variable
Context
register_dataset_accessor
register_dataarray_accessor
Dataset.set_close
14 changes: 7 additions & 7 deletions doc/whats-new.rst
@@ -22,7 +22,8 @@ v0.18.3 (unreleased)
New Features
~~~~~~~~~~~~
- Allow assigning values to a subset of a dataset using positional or label-based
indexing (:issue:`3015`, :pull:`5362`). By `Matthias Göbel <https://github.com/matzegoebel>`_.
indexing (:issue:`3015`, :pull:`5362`).
By `Matthias Göbel <https://github.com/matzegoebel>`_.
- Attempting to reduce a weighted object over missing dimensions now raises an error (:pull:`5362`).
By `Mattia Almansi <https://github.com/malmans2>`_.
- Add ``.sum`` to :py:meth:`~xarray.DataArray.rolling_exp` and
@@ -33,9 +34,10 @@ New Features
- :py:func:`xarray.cov` and :py:func:`xarray.corr` now lazily check for missing
values if inputs are dask arrays (:issue:`4804`, :pull:`5284`).
By `Andrew Williams <https://github.com/AndrewWilliams3142>`_.

- Attempting to ``concat`` list of elements that are not all ``Dataset`` or all ``DataArray`` now raises an error (:issue:`5051`, :pull:`5425`).
By `Thomas Hirtz <https://github.com/thomashirtz>`_.
- allow passing a function to ``combine_attrs`` (:pull:`4896`).
By `Justus Magin <https://github.com/keewis>`_.

Breaking changes
~~~~~~~~~~~~~~~~
@@ -346,13 +348,11 @@ Bug fixes
- Ensure standard calendar dates encoded with a calendar attribute with some or
all uppercase letters can be decoded or encoded to or from
``np.datetime64[ns]`` dates with or without ``cftime`` installed
(:issue:`5093`, :pull:`5180`). By `Spencer Clark
<https://github.com/spencerkclark>`_.
- Warn on passing ``keep_attrs`` to ``resample`` and ``rolling_exp`` as they are ignored, pass ``keep_attrs``
to the applied function instead (:pull:`5265`). By `Mathias Hauser <https://github.com/mathause>`_.

(:issue:`5093`, :pull:`5180`).
By `Spencer Clark <https://github.com/spencerkclark>`_.
- Warn on passing ``keep_attrs`` to ``resample`` and ``rolling_exp`` as they are ignored, pass ``keep_attrs``
to the applied function instead (:pull:`5265`).
By `Mathias Hauser <https://github.com/mathause>`_.

Documentation
~~~~~~~~~~~~~
3 changes: 2 additions & 1 deletion xarray/__init__.py
@@ -23,7 +23,7 @@
from .core.dataarray import DataArray
from .core.dataset import Dataset
from .core.extensions import register_dataarray_accessor, register_dataset_accessor
-from .core.merge import MergeError, merge
+from .core.merge import Context, MergeError, merge
from .core.options import set_options
from .core.parallel import map_blocks
from .core.variable import Coordinate, IndexVariable, Variable, as_variable
@@ -78,6 +78,7 @@
"zeros_like",
# Classes
"CFTimeIndex",
"Context",
"Coordinate",
"DataArray",
"Dataset",
16 changes: 12 additions & 4 deletions xarray/core/combine.py
@@ -432,8 +432,9 @@ def combine_nested(
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"}, default: "drop"
String indicating how to combine attrs of the objects being merged:
"override"} or callable, default: "drop"
A callable or a string indicating how to combine attrs of the objects being
merged:
- "drop": empty attrs on returned Dataset.
- "identical": all attrs must be the same on every object.
@@ -444,6 +445,9 @@ def combine_nested(
- "override": skip comparing and copy attrs from the first dataset to
the result.
If a callable, it must expect a sequence of ``attrs`` dicts and a context object
as its only parameters.
Returns
-------
combined : xarray.Dataset
@@ -646,8 +650,9 @@ def combine_by_coords(
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"}, default: "drop"
String indicating how to combine attrs of the objects being merged:
"override"} or callable, default: "drop"
A callable or a string indicating how to combine attrs of the objects being
merged:
- "drop": empty attrs on returned Dataset.
- "identical": all attrs must be the same on every object.
@@ -658,6 +663,9 @@ def combine_by_coords(
- "override": skip comparing and copy attrs from the first dataset to
the result.
If a callable, it must expect a sequence of ``attrs`` dicts and a context object
as its only parameters.
Returns
-------
combined : xarray.Dataset
8 changes: 6 additions & 2 deletions xarray/core/concat.py
@@ -143,8 +143,9 @@ def concat(
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"}, default: "override"
String indicating how to combine attrs of the objects being merged:
"override"} or callable, default: "override"
A callable or a string indicating how to combine attrs of the objects being
merged:
- "drop": empty attrs on returned Dataset.
- "identical": all attrs must be the same on every object.
@@ -155,6 +156,9 @@ def concat(
- "override": skip comparing and copy attrs from the first dataset to
the result.
If a callable, it must expect a sequence of ``attrs`` dicts and a context object
as its only parameters.
Returns
-------
concatenated : type of objs
23 changes: 18 additions & 5 deletions xarray/core/merge.py
@@ -57,6 +57,13 @@
)


+class Context:
+    """object carrying the information of a call"""
+
+    def __init__(self, func):
+        self.func = func


def broadcast_dimension_size(variables: List[Variable]) -> Dict[Hashable, int]:
"""Extract dimension sizes from a dictionary of variables.
@@ -502,13 +509,15 @@ def assert_valid_explicit_coords(variables, dims, explicit_coords):
)


-def merge_attrs(variable_attrs, combine_attrs):
+def merge_attrs(variable_attrs, combine_attrs, context=None):
     """Combine attributes from different variables according to combine_attrs"""
     if not variable_attrs:
         # no attributes to merge
         return None

-    if combine_attrs == "drop":
+    if callable(combine_attrs):
+        return combine_attrs(variable_attrs, context=context)
+    elif combine_attrs == "drop":
         return {}
     elif combine_attrs == "override":
         return dict(variable_attrs[0])
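
The dispatch order in this hunk can be tried out without xarray. The following is a simplified stand-in, not the library's implementation: ``merge_attrs_sketch`` handles only callables, "drop" and "override", and ``tag_with_caller`` is a hypothetical callback. Passing the string "merge" to ``Context`` is also an assumption made for illustration; the commit message says the context is ``None`` for now.

```python
class Context:
    """Mirrors the class added in this commit: it only stores ``func``."""

    def __init__(self, func):
        self.func = func


def merge_attrs_sketch(variable_attrs, combine_attrs, context=None):
    """Simplified stand-in for merge_attrs (callable, "drop", "override" only)."""
    if not variable_attrs:
        # no attributes to merge
        return None

    # The callable check comes first, before any string comparison.
    if callable(combine_attrs):
        return combine_attrs(variable_attrs, context=context)
    if combine_attrs == "drop":
        return {}
    if combine_attrs == "override":
        return dict(variable_attrs[0])
    raise ValueError(f"combine_attrs={combine_attrs!r} is not handled in this sketch")


def tag_with_caller(attrs_list, context):
    """Hypothetical callback: copy the last attrs dict and record the caller."""
    result = dict(attrs_list[-1])
    # context may be None, so fall back to a placeholder value
    result["combined_by"] = getattr(context, "func", None) or "unknown"
    return result


attrs = [{"a": 1, "b": 2}, {"a": 4}]
no_context = merge_attrs_sketch(attrs, tag_with_caller)
with_context = merge_attrs_sketch(attrs, tag_with_caller, context=Context("merge"))
print(no_context)
print(with_context)
```

Since the callback is invoked with ``context=context`` as a keyword, a function whose second parameter has a different name would fail with a ``TypeError``.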
@@ -585,7 +594,7 @@ def merge_core(
join : {"outer", "inner", "left", "right"}, optional
How to combine objects with different indexes.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"}, optional
"override"} or callable, default: "override"
How to combine attributes of objects
priority_arg : int, optional
Optional argument in `objects` that takes precedence over the others.
@@ -696,8 +705,9 @@ def merge(
variable names to fill values. Use a data array's name to
refer to its values.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"}, default: "override"
String indicating how to combine attrs of the objects being merged:
"override"} or callable, default: "override"
A callable or a string indicating how to combine attrs of the objects being
merged:
- "drop": empty attrs on returned Dataset.
- "identical": all attrs must be the same on every object.
@@ -708,6 +718,9 @@ def merge(
- "override": skip comparing and copy attrs from the first dataset to
the result.
If a callable, it must expect a sequence of ``attrs`` dicts and a context object
as its only parameters.
Returns
-------
Dataset
7 changes: 7 additions & 0 deletions xarray/tests/test_combine.py
@@ -531,6 +531,9 @@ def test_auto_combine_2d_combine_attrs_kwarg(self):
}
expected_dict["override"] = expected.copy(deep=True)
expected_dict["override"].attrs = {"a": 1}
f = lambda attrs, context: attrs[0]
expected_dict[f] = expected.copy(deep=True)
expected_dict[f].attrs = f([{"a": 1}], None)

datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4), ds(5)]]

@@ -714,6 +717,10 @@ def test_combine_coords_join_exact(self):
Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1, "b": 2}),
),
("override", Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1})),
(
lambda attrs, context: attrs[1],
Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1, "b": 2}),
),
],
)
def test_combine_coords_combine_attrs(self, combine_attrs, expected):
14 changes: 14 additions & 0 deletions xarray/tests/test_concat.py
@@ -299,6 +299,13 @@ def test_concat_join_kwarg(self):
{"a": 41, "c": 43, "d": 44},
False,
),
(
lambda attrs, context: {"a": -1, "b": 0, "c": 1} if any(attrs) else {},
{"a": 41, "b": 42, "c": 43},
{"b": 2, "c": 43, "d": 44},
{"a": -1, "b": 0, "c": 1},
False,
),
],
)
def test_concat_combine_attrs_kwarg(
@@ -354,6 +361,13 @@ def test_concat_combine_attrs_kwarg(
{"a": 41, "c": 43, "d": 44},
False,
),
(
lambda attrs, context: {"a": -1, "b": 0, "c": 1} if any(attrs) else {},
{"a": 41, "b": 42, "c": 43},
{"b": 2, "c": 43, "d": 44},
{"a": -1, "b": 0, "c": 1},
False,
),
],
)
def test_concat_combine_attrs_kwarg_variables(
14 changes: 14 additions & 0 deletions xarray/tests/test_merge.py
@@ -109,6 +109,13 @@ def test_merge_arrays_attrs_default(self):
{"a": 1, "c": np.array([3]), "d": 4},
False,
),
(
lambda attrs, context: attrs[1],
{"a": 1, "b": 2, "c": 3},
{"a": 4, "b": 3, "c": 1},
{"a": 4, "b": 3, "c": 1},
False,
),
],
)
def test_merge_arrays_attrs(
@@ -161,6 +168,13 @@ def test_merge_arrays_attrs(
{"a": 1, "c": 3, "d": 4},
False,
),
(
lambda attrs, context: attrs[1],
{"a": 1, "b": 2, "c": 3},
{"a": 4, "b": 3, "c": 1},
{"a": 4, "b": 3, "c": 1},
False,
),
],
)
def test_merge_arrays_attrs_variables(
