-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix coordinate attr handling in xr.where(..., keep_attrs=True)
#7229
Conversation
xarray/core/computation.py
Outdated
@@ -1860,7 +1860,7 @@ def where(cond, x, y, keep_attrs=None): | |||
if keep_attrs is True: | |||
# keep the attributes of x, the second parameter, by default to | |||
# be consistent with the `where` method of `DataArray` and `Dataset` | |||
keep_attrs = lambda attrs, context: getattr(x, "attrs", {}) | |||
keep_attrs = lambda attrs, context: attrs[1] if len(attrs) > 1 else {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
xarray/tests/test_computation.py
Outdated
# x as a scalar, takes attrs from y | ||
actual = xr.where(cond, 0, y, keep_attrs=True) | ||
expected = xr.DataArray([0, 0], coords={"x": [0, 1]}, attrs={"attr": "y_da"}) | ||
expected["x"].attrs = {"attr": "y_coord"} | ||
assert_identical(expected, actual) | ||
|
||
# y as a scalar, takes attrs from x | ||
actual = xr.where(cond, x, 0, keep_attrs=True) | ||
expected = xr.DataArray([1, 0], coords={"x": [0, 1]}, attrs={"attr": "x_da"}) | ||
expected["x"].attrs = {"attr": "x_coord"} | ||
assert_identical(expected, actual) | ||
|
||
# x and y as a scalar, takes coord attrs only from cond | ||
actual = xr.where(cond, 1, 0, keep_attrs=True) | ||
assert actual.attrs == {} | ||
expected = xr.DataArray([1, 0], coords={"x": [0, 1]}) | ||
expected["x"].attrs = {"attr": "cond_coord"} | ||
assert_identical(expected, actual) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This gives basically the same result as prior to #6461, where the doc statement that we only take attrs from x
is misleading. I actually like the way it works now, and found it hard to implement something that only pulled from x
and didn't break other uses of apply_variable_ufunc
. If we're ok maintaining this behavior, I tweaked the docstring slightly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry for the late reply, @slevang.
After some investigation, I'm not sure what the best way to handle this would be...
apply_ufunc
handles keep_attrs
, but will ignore the attrs of any not-xarray
objects. This means that for *args
of [data_array, scalar, array_like]
it would pass [{...}]
to merge_attrs
, but for [scalar, data_array, array_like]
the result would the same. In both cases, getitem(attrs, 1, {})
would return {}
, which does not make sense for the second example.
As such, I think we have two options to solve this reliably:
- change
apply_ufunc
to assume every argument should be considered to haveattrs
- wrap any non-
xarray
object passed toxr.where
inxr.Variable
I can't really tell whether option 1 is possible at all (and in any case would be a decent amount of work), so I'd probably choose option 2. For scalars this should be pretty easy to implement:
xr.Variable((), scalar)
However, I don't think passing bare arrays to xr.where
makes sense? Those are not allowed for DataArray.where
, Dataset.where
, and Variable.where
, so I guess that's just a mistake in the docstring? If so, we should definitely fix that...
I considered the As far as passing bare arrays, despite what the docstrings say it seems like you can actually do this with >>> x = xr.DataArray([1, 1], coords={"x": [0, 1]})
>>> cond = xr.DataArray([True, False], coords={"x": [0, 1]})
>>> x.where(cond, other=np.array([0, 2]))
<xarray.DataArray (x: 2)>
array([1, 2])
Coordinates:
* x (x) int64 0 1 Which I don't think makes sense, but is mostly a separate issue. You do get a broadcast error if After poking around I agree that this isn't easy to totally fix. I sort of started to go down the route of I'm just keen to get this merged in some form because the regression of #6461 is pretty bad. For example: ds = xr.tutorial.load_dataset('air_temperature')
xr.where(ds.air>10, ds.air, 10, keep_attrs=True).to_netcdf('foo.nc')
# completely fails because the time attrs have been overwritten by ds.air attrs
ValueError: failed to prevent overwriting existing key units in attrs on variable 'time'. This is probably an encoding field used by xarray to describe how a variable is serialized. To proceed, remove this key from the variable's attributes manually. I hit exactly this issue on some existing scripts so this is preventing me from upgrading beyond |
That particular example might be rare, but In any case, I will try to push this forward so this can be fixed as soon as possible. |
The latest commit should do what we want, consistently taking attrs of The only way it deviates from this (spelled out in the tests) is to pull coord attrs from x, then y, then cond if any of these are scalars. I think this makes sense because if I pass |
assert actual.attrs == {} | ||
expected = xr.DataArray([1, 0], coords={"x": [0, 1]}) | ||
expected["x"].attrs = {"attr": "cond_coord"} | ||
assert_identical(expected, actual) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This one seems confusing but I don't have a strong opinion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now it doesn't take the cond
coord attrs. I still think this would be a convenient (albeit confusing) feature, because I happen to have a bunch of code like xr.where(x>10, 10, x)
and would rather keep the attrs. But this is obviously a better use case for DataArray.where
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or DataArray.clip(max=10)
:) https://docs.xarray.dev/en/stable/generated/xarray.DataArray.clip.html
actual = xr.where(True, x, y, keep_attrs=True) | ||
expected = xr.DataArray([1, 1], coords={"x": [0, 1]}, attrs={"attr": "x_da"}) | ||
expected["x"].attrs = {"attr": "x_coord"} | ||
assert_identical(expected, actual) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a Dataset test too with Dataset.attrs={"foo": "bar'}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As written [DataArray, Dataset, Dataset]
takes the Dataset attrs of y. Not sure how to fix this without going deep into the chain of apply_ufunc calls. I'm starting to think rebuilding all the attrs after apply_ufunc might be easiest to get consistent behavior for now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This case should be covered now. Looks like we can pass one DataArray and one Dataset so I've included a test on that, but there are a lot of permutations.
actual = xr.where(True, x, y, keep_attrs=True) | ||
expected = xr.DataArray([1, 1], coords={"x": [0, 1]}, attrs={"attr": "x_da"}) | ||
expected["x"].attrs = {"attr": "x_coord"} | ||
assert_identical(expected, actual) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This case should be covered now. Looks like we can pass one DataArray and one Dataset so I've included a test on that, but there are a lot of permutations.
xarray/core/computation.py
Outdated
@@ -1874,6 +1872,24 @@ def where(cond, x, y, keep_attrs=None): | |||
keep_attrs=keep_attrs, | |||
) | |||
|
|||
# make sure we have the attrs of x across Dataset, DataArray, and coords |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably a neater way to do this but it seems to work.
assert actual.attrs == {} | ||
expected = xr.DataArray([1, 0], coords={"x": [0, 1]}) | ||
expected["x"].attrs = {"attr": "cond_coord"} | ||
assert_identical(expected, actual) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now it doesn't take the cond
coord attrs. I still think this would be a convenient (albeit confusing) feature, because I happen to have a bunch of code like xr.where(x>10, 10, x)
and would rather keep the attrs. But this is obviously a better use case for DataArray.where
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the change to explicitly constructing the attrs
instead of working around quirks of apply_ufunc
sounds good to me: when discussing this in the last meeting we did get the feeling that in the long run it would be better to think about redesigning that part of apply_ufunc
.
Yeah I think this would be worth doing eventually. Trying to index a list of attributes of unpredictable length doesn't feel very xarray-like. Any further refinements to the current approach of reconstructing attributes after |
Thanks @slevang sorry for the delay here! |
* upstream/main: (39 commits) Support the new compression argument in netCDF4 > 1.6.0 (pydata#6981) Remove setuptools-scm-git-archive, require setuptools-scm>=7 (pydata#7253) Fix mypy failures (pydata#7343) Docs: add example of writing and reading groups to netcdf (pydata#7338) Reset file pointer to 0 when reading file stream (pydata#7304) Enable mypy warn unused ignores (pydata#7335) Optimize some copying (pydata#7209) Add parse_dims func (pydata#7051) Fix coordinate attr handling in `xr.where(..., keep_attrs=True)` (pydata#7229) Remove code used to support h5py<2.10.0 (pydata#7334) [pre-commit.ci] pre-commit autoupdate (pydata#7330) Fix PR number in what’s new (pydata#7331) Enable `origin` and `offset` arguments in `resample` (pydata#7284) fix doctests: supress urllib3 warning (pydata#7326) fix flake8 config (pydata#7321) implement Zarr v3 spec support (pydata#6475) Fix polyval overloads (pydata#7315) deprecate pynio backend (pydata#7301) mypy - Remove some ignored packages and modules (pydata#7319) Switch to T_DataArray in .coords (pydata#7285) ...
xr.where(..., keep_attrs=True)
overwrites coordinate attributes #7220whats-new.rst
Reverts the
getattr
method used inxr.where(..., keep_attrs=True)
from #6461, but keeps handling for scalar inputs. Adds some test cases to ensure consistent attribute handling.