Skip to content

Commit 9c578d6

Browse files
committed
Fix dataset attrs
1 parent b86623c commit 9c578d6

File tree

4 files changed

+22
-19
lines changed

4 files changed

+22
-19
lines changed

xarray/core/dataset.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4899,18 +4899,20 @@ def from_dict(cls, d):
48994899
return obj
49004900

49014901
@staticmethod
4902-
def _unary_op(f, keep_attrs=None):
4903-
if keep_attrs is None:
4904-
keep_attrs = _get_keep_attrs(default=True)
4905-
4902+
def _unary_op(f):
49064903
@functools.wraps(f)
49074904
def func(self, *args, **kwargs):
49084905
variables = {}
4906+
keep_attrs = kwargs.pop("keep_attrs", None)
4907+
if keep_attrs is None:
4908+
keep_attrs = _get_keep_attrs(default=True)
49094909
for k, v in self._variables.items():
49104910
if k in self._coord_names:
49114911
variables[k] = v
49124912
else:
49134913
variables[k] = f(v, *args, **kwargs)
4914+
if keep_attrs:
4915+
variables[k].attrs = v._attrs
49144916
attrs = self._attrs if keep_attrs else None
49154917
return self._replace_with_new_dims(variables, attrs=attrs)
49164918

xarray/core/options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _get_keep_attrs(default):
7171
return global_choice
7272
else:
7373
raise ValueError(
74-
"The global option keep_attrs must be one of" " True, False or 'default'."
74+
"The global option keep_attrs must be one of True, False or 'default'."
7575
)
7676

7777

xarray/core/variable.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,16 +2035,20 @@ def imag(self):
20352035
return type(self)(self.dims, self.data.imag, self._attrs)
20362036

20372037
def __array_wrap__(self, obj, context=None):
2038-
keep_attrs = _get_keep_attrs(default=False)
2039-
attrs = self._attrs if keep_attrs else {}
2040-
return Variable(self.dims, obj, attrs)
2038+
return Variable(self.dims, obj)
20412039

20422040
@staticmethod
20432041
def _unary_op(f):
20442042
@functools.wraps(f)
20452043
def func(self, *args, **kwargs):
2044+
keep_attrs = kwargs.pop("keep_attrs", None)
2045+
if keep_attrs is None:
2046+
keep_attrs = _get_keep_attrs(default=True)
20462047
with np.errstate(all="ignore"):
2047-
return self.__array_wrap__(f(self.data, *args, **kwargs))
2048+
result = self.__array_wrap__(f(self.data, *args, **kwargs))
2049+
if keep_attrs:
2050+
result.attrs = self._attrs
2051+
return result
20482052

20492053
return func
20502054

xarray/tests/test_dataset.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4415,22 +4415,19 @@ def test_fillna(self):
44154415
assert actual.a.name == "a"
44164416
assert actual.a.attrs == ds.a.attrs
44174417

4418-
def test_propagate_attrs(self):
4418+
@pytest.mark.parametrize(
4419+
"func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs]
4420+
)
4421+
def test_propagate_attrs(self, func):
44194422

44204423
da = DataArray(range(5), name="a", attrs={"attr": "da"})
44214424
ds = Dataset({"a": da}, attrs={"attr": "ds"})
44224425

44234426
# test defaults
4424-
assert ds.clip(0, 1).attrs == ds.attrs
4425-
assert (np.float64(1.0) * ds).attrs == ds.attrs
4426-
assert np.abs(ds).attrs == ds.attrs
4427-
assert abs(ds).attrs == ds.attrs
4428-
4427+
assert func(ds).attrs == ds.attrs
44294428
with set_options(keep_attrs=False):
4430-
assert ds.clip(0, 1).attrs != ds.attrs
4431-
assert (np.float64(1.0) * ds).attrs != ds.attrs
4432-
assert np.abs(ds).attrs != ds.attrs
4433-
assert abs(ds).attrs != ds.attrs
4429+
assert func(ds).attrs != ds.attrs
4430+
assert func(ds).a.attrs != ds.a.attrs
44344431

44354432
def test_where(self):
44364433
ds = Dataset({"a": ("x", range(5))})

0 commit comments

Comments
 (0)