From bd4650c2f886e59a76a49d477999694624d44be8 Mon Sep 17 00:00:00 2001 From: keewis Date: Wed, 5 May 2021 18:37:25 +0200 Subject: [PATCH] also apply combine_attrs to the attrs of the variables (#4902) --- doc/whats-new.rst | 8 +- xarray/core/concat.py | 4 +- xarray/core/merge.py | 16 +++- xarray/core/variable.py | 67 +++++++++++++-- xarray/tests/test_combine.py | 154 +++++++++++++++++++++++++++++++++-- xarray/tests/test_concat.py | 1 - xarray/tests/test_dataset.py | 6 +- xarray/tests/test_merge.py | 58 ++++++------- 8 files changed, 264 insertions(+), 50 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ed1d9c0f079..59857c67bf7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,7 +22,9 @@ v0.17.1 (unreleased) New Features ~~~~~~~~~~~~ - +- apply ``combine_attrs`` on data variables and coordinate variables when concatenating + and merging datasets and dataarrays (:pull:`4902`). + By `Justus Magin `_. - Add :py:meth:`Dataset.to_pandas` (:pull:`5247`) By `Giacomo Caria `_. - Add :py:meth:`DataArray.plot.surface` which wraps matplotlib's `plot_surface` to make @@ -144,6 +146,10 @@ Breaking changes ``ds.coarsen(...).mean(keep_attrs=False)`` instead of ``ds.coarsen(..., keep_attrs=False).mean()``. Further, coarsen now keeps attributes per default (:pull:`5227`). By `Mathias Hauser `_. +- switch the default of the :py:func:`merge` ``combine_attrs`` parameter to + ``"override"``. This will keep the current behavior for merging the ``attrs`` of + variables but stop dropping the ``attrs`` of the main objects (:pull:`4902`). + By `Justus Magin `_. Deprecations ~~~~~~~~~~~~ diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 7a958eb1404..9eca99918d4 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -508,7 +508,7 @@ def ensure_common_dims(vars): vars = ensure_common_dims([ds[k].variable for ds in datasets]) except KeyError: raise ValueError("%r is not present in all datasets." % k) - combined = concat_vars(vars, dim, positions) + combined = concat_vars(vars, dim, positions, combine_attrs=combine_attrs) assert isinstance(combined, Variable) result_vars[k] = combined elif k in result_vars: @@ -572,7 +572,7 @@ def _dataarray_concat( positions, fill_value=fill_value, join=join, - combine_attrs="drop", + combine_attrs=combine_attrs, ) merged_attrs = merge_attrs([da.attrs for da in arrays], combine_attrs) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index ec95563bda9..ec3c9b0f065 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -164,6 +164,7 @@ def merge_collected( grouped: Dict[Hashable, List[MergeElement]], prioritized: Mapping[Hashable, MergeElement] = None, compat: str = "minimal", + combine_attrs="override", ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]: """Merge dicts of variables, while resolving conflicts appropriately. @@ -222,11 +223,18 @@ def merge_collected( % (name, variable.attrs, other_variable.attrs) ) merged_vars[name] = variable + merged_vars[name].attrs = merge_attrs( + [var.attrs for var, _ in indexed_elements], + combine_attrs=combine_attrs, + ) merged_indexes[name] = index else: variables = [variable for variable, _ in elements_list] try: merged_vars[name] = unique_variable(name, variables, compat) + merged_vars[name].attrs = merge_attrs( + [var.attrs for var in variables], combine_attrs=combine_attrs + ) except MergeError: if compat != "minimal": # we need more than "minimal" compatibility (for which @@ -613,7 +621,9 @@ def merge_core( collected = collect_variables_and_indexes(aligned) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) - variables, out_indexes = merge_collected(collected, prioritized, compat=compat) + variables, out_indexes = merge_collected( + collected, prioritized, compat=compat, combine_attrs=combine_attrs + ) assert_unique_multiindex_level_names(variables) dims = calculate_dimensions(variables) @@ -649,7 +659,7 @@ def merge( compat: str = "no_conflicts", join: str = "outer", fill_value: object = dtypes.NA, - combine_attrs: str = "drop", + combine_attrs: str = "override", ) -> "Dataset": """Merge any number of xarray objects into a single Dataset as variables. @@ -688,7 +698,7 @@ def merge( variable names to fill values. Use a data array's name to refer to its values. combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ - "override"}, default: "drop" + "override"}, default: "override" String indicating how to combine attrs of the objects being merged: - "drop": empty attrs on returned Dataset. diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d1d6698cea7..6f828a5128c 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1771,7 +1771,14 @@ def reduce( return Variable(dims, data, attrs=attrs) @classmethod - def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): + def concat( + cls, + variables, + dim="concat_dim", + positions=None, + shortcut=False, + combine_attrs="override", + ): """Concatenate variables along a new or existing dimension. Parameters @@ -1794,6 +1801,18 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): This option is used internally to speed-up groupby operations. If `shortcut` is True, some checks of internal consistency between arrays to concatenate are skipped. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"}, default: "override" + String indicating how to combine attrs of the objects being merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. Returns ------- @@ -1801,6 +1820,8 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): Concatenated Variable formed by stacking all the supplied variables along the given dimension. """ + from .merge import merge_attrs + if not isinstance(dim, str): (dim,) = dim.dims @@ -1825,7 +1846,9 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): dims = (dim,) + first_var.dims data = duck_array_ops.stack(arrays, axis=axis) - attrs = dict(first_var.attrs) + attrs = merge_attrs( + [var.attrs for var in variables], combine_attrs=combine_attrs + ) encoding = dict(first_var.encoding) if not shortcut: for var in variables: @@ -2581,12 +2604,21 @@ def __setitem__(self, key, value): raise TypeError("%s values cannot be modified" % type(self).__name__) @classmethod - def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): + def concat( + cls, + variables, + dim="concat_dim", + positions=None, + shortcut=False, + combine_attrs="override", + ): """Specialized version of Variable.concat for IndexVariable objects. This exists because we want to avoid converting Index objects to NumPy arrays, if possible. """ + from .merge import merge_attrs + if not isinstance(dim, str): (dim,) = dim.dims @@ -2613,12 +2645,13 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): # keep as str if possible as pandas.Index uses object (converts to numpy array) data = maybe_coerce_to_str(data, variables) - attrs = dict(first_var.attrs) + attrs = merge_attrs( + [var.attrs for var in variables], combine_attrs=combine_attrs + ) if not shortcut: for var in variables: if var.dims != first_var.dims: raise ValueError("inconsistent dimensions") - utils.remove_incompatible_items(attrs, var.attrs) return cls(first_var.dims, data, attrs) @@ -2792,7 +2825,13 @@ def _broadcast_compat_data(self, other): return self_data, other_data, dims -def concat(variables, dim="concat_dim", positions=None, shortcut=False): +def concat( + variables, + dim="concat_dim", + positions=None, + shortcut=False, + combine_attrs="override", +): """Concatenate variables along a new or existing dimension. Parameters @@ -2815,6 +2854,18 @@ def concat(variables, dim="concat_dim", positions=None, shortcut=False): This option is used internally to speed-up groupby operations. If `shortcut` is True, some checks of internal consistency between arrays to concatenate are skipped. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"}, default: "override" + String indicating how to combine attrs of the objects being merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. Returns ------- @@ -2824,9 +2875,9 @@ def concat(variables, dim="concat_dim", positions=None, shortcut=False): """ variables = list(variables) if all(isinstance(v, IndexVariable) for v in variables): - return IndexVariable.concat(variables, dim, positions, shortcut) + return IndexVariable.concat(variables, dim, positions, shortcut, combine_attrs) else: - return Variable.concat(variables, dim, positions, shortcut) + return Variable.concat(variables, dim, positions, shortcut, combine_attrs) def assert_unique_multiindex_level_names(variables): diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index ebc8ab73604..3b6aaec60f2 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -5,7 +5,14 @@ import numpy as np import pytest -from xarray import DataArray, Dataset, combine_by_coords, combine_nested, concat +from xarray import ( + DataArray, + Dataset, + MergeError, + combine_by_coords, + combine_nested, + concat, +) from xarray.core import dtypes from xarray.core.combine import ( _check_shape_tile_ids, @@ -476,7 +483,8 @@ def test_concat_name_symmetry(self): assert_identical(x_first, y_first) def test_concat_one_dim_merge_another(self): - data = create_test_data() + data = create_test_data(add_attrs=False) + data1 = data.copy(deep=True) data2 = data.copy(deep=True) @@ -502,7 +510,7 @@ def test_auto_combine_2d(self): assert_equal(result, expected) def test_auto_combine_2d_combine_attrs_kwarg(self): - ds = create_test_data + ds = lambda x: create_test_data(x, add_attrs=False) partway1 = concat([ds(0), ds(3)], dim="dim1") partway2 = concat([ds(1), ds(4)], dim="dim1") @@ -675,8 +683,8 @@ def test_combine_by_coords(self): with pytest.raises(ValueError, match=r"Every dimension needs a coordinate"): combine_by_coords(objs) - def test_empty_input(self): - assert_identical(Dataset(), combine_by_coords([])) + def test_empty_input(self): + assert_identical(Dataset(), combine_by_coords([])) @pytest.mark.parametrize( "join, expected", @@ -754,6 +762,142 @@ def test_combine_nested_combine_attrs_drop_conflicts(self): ) assert_identical(expected, actual) + @pytest.mark.parametrize( + "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 1, "b": 2, "c": 3}, + {"b": 1, "c": 3, "d": 4}, + {"a": 1, "c": 3, "d": 4}, + False, + ), + ], + ) + def test_combine_nested_combine_attrs_variables( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_exception + ): + """check that combine_attrs is used on data variables and coords""" + data1 = Dataset( + { + "a": ("x", [1, 2], attrs1), + "b": ("x", [3, -1], attrs1), + "x": ("x", [0, 1], attrs1), + } + ) + data2 = Dataset( + { + "a": ("x", [2, 3], attrs2), + "b": ("x", [-2, 1], attrs2), + "x": ("x", [2, 3], attrs2), + } + ) + + if expect_exception: + with pytest.raises(MergeError, match="combine_attrs"): + combine_by_coords([data1, data2], combine_attrs=combine_attrs) + else: + actual = combine_by_coords([data1, data2], combine_attrs=combine_attrs) + expected = Dataset( + { + "a": ("x", [1, 2, 2, 3], expected_attrs), + "b": ("x", [3, -1, -2, 1], expected_attrs), + }, + {"x": ("x", [0, 1, 2, 3], expected_attrs)}, + ) + + assert_identical(actual, expected) + + @pytest.mark.parametrize( + "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 1, "b": 2, "c": 3}, + {"b": 1, "c": 3, "d": 4}, + {"a": 1, "c": 3, "d": 4}, + False, + ), + ], + ) + def test_combine_by_coords_combine_attrs_variables( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_exception + ): + """check that combine_attrs is used on data variables and coords""" + data1 = Dataset( + {"x": ("a", [0], attrs1), "y": ("a", [0], attrs1), "a": ("a", [0], attrs1)} + ) + data2 = Dataset( + {"x": ("a", [1], attrs2), "y": ("a", [1], attrs2), "a": ("a", [1], attrs2)} + ) + + if expect_exception: + with pytest.raises(MergeError, match="combine_attrs"): + combine_by_coords([data1, data2], combine_attrs=combine_attrs) + else: + actual = combine_by_coords([data1, data2], combine_attrs=combine_attrs) + expected = Dataset( + { + "x": ("a", [0, 1], expected_attrs), + "y": ("a", [0, 1], expected_attrs), + "a": ("a", [0, 1], expected_attrs), + } + ) + + assert_identical(actual, expected) + def test_infer_order_from_coords(self): data = create_test_data() objs = [data.isel(dim2=slice(4, 9)), data.isel(dim2=slice(4))] diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 66cf55d13f6..9cfc134e4fe 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -318,7 +318,6 @@ def test_concat_combine_attrs_kwarg( assert_identical(actual, expected) - @pytest.mark.skip(reason="not implemented, yet (see #4827)") @pytest.mark.parametrize( "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", [ diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 1e526a3787f..ef8db0374d5 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -61,7 +61,7 @@ ] -def create_test_data(seed=None): +def create_test_data(seed=None, add_attrs=True): rs = np.random.RandomState(seed) _vars = { "var1": ["dim1", "dim2"], @@ -76,7 +76,9 @@ def create_test_data(seed=None): obj["dim3"] = ("dim3", list("abcdefghij")) for v, dims in sorted(_vars.items()): data = rs.normal(size=tuple(_dims[d] for d in dims)) - obj[v] = (dims, data, {"foo": "variable"}) + obj[v] = (dims, data) + if add_attrs: + obj[v].attrs = {"foo": "variable"} obj.coords["numbers"] = ( "dim3", np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"), diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index bee5c951cf9..680c2a3a679 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -29,13 +29,14 @@ def test_broadcast_dimension_size(self): class TestMergeFunction: def test_merge_arrays(self): - data = create_test_data() + data = create_test_data(add_attrs=False) + actual = xr.merge([data.var1, data.var2]) expected = data[["var1", "var2"]] assert_identical(actual, expected) def test_merge_datasets(self): - data = create_test_data() + data = create_test_data(add_attrs=False) actual = xr.merge([data[["var1"]], data[["var2"]]]) expected = data[["var1", "var2"]] @@ -52,14 +53,17 @@ def test_merge_dataarray_unnamed(self): def test_merge_arrays_attrs_default(self): var1_attrs = {"a": 1, "b": 2} var2_attrs = {"a": 1, "c": 3} - expected_attrs = {} + expected_attrs = {"a": 1, "b": 2} + + data = create_test_data(add_attrs=False) + expected = data[["var1", "var2"]].copy() + expected.var1.attrs = var1_attrs + expected.var2.attrs = var2_attrs + expected.attrs = expected_attrs - data = create_test_data() data.var1.attrs = var1_attrs data.var2.attrs = var2_attrs actual = xr.merge([data.var1, data.var2]) - expected = data[["var1", "var2"]] - expected.attrs = expected_attrs assert_identical(actual, expected) @pytest.mark.parametrize( @@ -110,19 +114,17 @@ def test_merge_arrays_attrs_default(self): def test_merge_arrays_attrs( self, combine_attrs, var1_attrs, var2_attrs, expected_attrs, expect_exception ): - data = create_test_data() - data.var1.attrs = var1_attrs - data.var2.attrs = var2_attrs + data1 = xr.Dataset(attrs=var1_attrs) + data2 = xr.Dataset(attrs=var2_attrs) if expect_exception: - with pytest.raises(MergeError, match=r"combine_attrs"): - actual = xr.merge([data.var1, data.var2], combine_attrs=combine_attrs) + with pytest.raises(MergeError, match="combine_attrs"): + actual = xr.merge([data1, data2], combine_attrs=combine_attrs) else: - actual = xr.merge([data.var1, data.var2], combine_attrs=combine_attrs) - expected = data[["var1", "var2"]] - expected.attrs = expected_attrs + actual = xr.merge([data1, data2], combine_attrs=combine_attrs) + expected = xr.Dataset(attrs=expected_attrs) + assert_identical(actual, expected) - @pytest.mark.skip(reason="not implemented, yet (see #4827)") @pytest.mark.parametrize( "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", [ @@ -165,22 +167,22 @@ def test_merge_arrays_attrs_variables( self, combine_attrs, attrs1, attrs2, expected_attrs, expect_exception ): """check that combine_attrs is used on data variables and coords""" - data = create_test_data() - data1 = data.copy() - data1.var1.attrs = attrs1 - data1.dim1.attrs = attrs1 - data2 = data.copy() - data2.var1.attrs = attrs2 - data2.dim1.attrs = attrs2 + data1 = xr.Dataset( + {"var1": ("dim1", [], attrs1)}, coords={"dim1": ("dim1", [], attrs1)} + ) + data2 = xr.Dataset( + {"var1": ("dim1", [], attrs2)}, coords={"dim1": ("dim1", [], attrs2)} + ) if expect_exception: - with pytest.raises(MergeError, match=r"combine_attrs"): + with pytest.raises(MergeError, match="combine_attrs"): actual = xr.merge([data1, data2], combine_attrs=combine_attrs) else: actual = xr.merge([data1, data2], combine_attrs=combine_attrs) - expected = data.copy() - expected.var1.attrs = expected_attrs - expected.dim1.attrs = expected_attrs + expected = xr.Dataset( + {"var1": ("dim1", [], expected_attrs)}, + coords={"dim1": ("dim1", [], expected_attrs)}, + ) assert_identical(actual, expected) @@ -252,7 +254,7 @@ def test_merge_no_conflicts_single_var(self): xr.merge([ds1, ds3], compat="no_conflicts") def test_merge_no_conflicts_multi_var(self): - data = create_test_data() + data = create_test_data(add_attrs=False) data1 = data.copy(deep=True) data2 = data.copy(deep=True) @@ -271,7 +273,7 @@ def test_merge_no_conflicts_multi_var(self): def test_merge_no_conflicts_preserve_attrs(self): data = xr.Dataset({"x": ([], 0, {"foo": "bar"})}) - actual = xr.merge([data, data]) + actual = xr.merge([data, data], combine_attrs="no_conflicts") assert_identical(data, actual) def test_merge_no_conflicts_broadcast(self):