Description
What happened?
I'm testing my code with numpy 2.0 and the current main branches of xarray and dask, and ran into a change that I guess is expected given the way xarray does things, but I want to make sure, as it could be unexpected for many users.
Calling DataArray.where on an integer array narrower than 64 bits, with a Python integer as the new value, upcasts the array to 64-bit integers (python's int). Older versions of numpy preserve the dtype of the array. As far as I can tell the relevant xarray code hasn't changed, so this seems to be more about numpy making things more consistent.
The main problem seems to come down to:
xarray/xarray/core/duck_array_ops.py
Line 218 in d933578
This line converts my scalar int input to a numpy array. Without that array conversion, numpy works as expected. See the MCVE for the xarray-specific example, but here's the numpy equivalent:
import numpy as np
a = np.zeros((2, 2), dtype=np.uint16)
# what I'm intending to do with my xarray `data_arr.where(cond, 2)`
np.where(a != 0, a, 2).dtype
# dtype('uint16')
# equivalent to what xarray does:
np.where(a != 0, a, np.asarray(2)).dtype
# dtype('int64')
# workaround, cast my scalar to a specific numpy type
np.where(a != 0, a, np.asarray(np.uint16(2))).dtype
# dtype('uint16')
From a numpy point of view, the second where call makes sense: two arrays should be upcast to a common dtype so they can be combined. But from an xarray user's point of view, I'm passing a scalar, so I expect the same result as the first where call above.
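The same difference can be seen without np.where at all. The sketch below uses np.result_type, which reports the dtype numpy would choose when combining its arguments; I'm assuming it follows the same promotion rules as the where calls above. No expected output is shown for the 0-d array case because it depends on the installed numpy version (uint16 on 1.x, the default integer dtype on 2.x, going by the results in this report).

```python
import numpy as np

a = np.zeros((2, 2), dtype=np.uint16)

# A plain Python int does not widen the uint16 array.
scalar_dtype = np.result_type(a, 2)

# Wrapping the same value with np.asarray gives it a concrete dtype
# (the platform default integer), which then takes part in promotion.
zero_d = np.asarray(2)
array_dtype = np.result_type(a, zero_d)

print(np.__version__, scalar_dtype, array_dtype)
```

If the last value printed is wider than uint16, the scalar has lost its "just a Python int" status by the time promotion happens.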
What did you expect to happen?
See above.
Minimal Complete Verifiable Example
import xarray as xr
import numpy as np
data_arr = xr.DataArray(np.array([1, 2], dtype=np.uint16))
print(data_arr.where(data_arr == 2, 3).dtype)
# int64
MVCE confirmation
- Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
- Complete example — the example is self-contained, including all data and the text of any traceback.
- Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
- New issue — a search of GitHub Issues suggests this is not a duplicate.
- Recent environment — the issue occurs with the latest version of xarray and its dependencies.
Relevant log output
No response
Anything else we need to know?
Numpy 1.x preserves the dtype, even when the scalar is wrapped with np.asarray first:
In [1]: import numpy as np
In [2]: np.asarray(2).dtype
Out[2]: dtype('int64')
In [3]: a = np.zeros((2, 2), dtype=np.uint16)
In [4]: np.where(a != 0, a, np.asarray(2)).dtype
Out[4]: dtype('uint16')
In [5]: np.where(a != 0, a, np.asarray(np.uint16(2))).dtype
Out[5]: dtype('uint16')
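A caller-side workaround that should behave the same on numpy 1.x and 2.x is to cast the Python int to the array's own dtype before calling where, as in the uint16 example above. The helper below is only a sketch: where_keep_int_dtype is a hypothetical name, not part of numpy or xarray, and it only handles plain Python ints that fit in the array's integer dtype.

```python
import numpy as np


def where_keep_int_dtype(cond, data, other):
    """Hypothetical helper: np.where that keeps an integer array's dtype
    when `other` is a Python int representable in that dtype."""
    data = np.asarray(data)
    is_py_int = isinstance(other, int) and not isinstance(other, bool)
    if is_py_int and data.dtype.kind in "iu":
        info = np.iinfo(data.dtype)
        if info.min <= other <= info.max:
            # Give the scalar the array's dtype so no promotion is needed.
            other = data.dtype.type(other)
    # Anything else is passed through and left to numpy's own rules.
    return np.where(cond, data, other)


a = np.array([1, 2], dtype=np.uint16)
result = where_keep_int_dtype(a == 2, a, 3)
print(result.dtype)  # uint16
```

The range check is there so that a value that does not fit (for example 70000 with uint16) is not silently wrapped by the cast.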
Environment
INSTALLED VERSIONS
------------------
commit: None
python: 3.11.4 | packaged by conda-forge | (main, Jun 10 2023, 18:08:17) [GCC 12.2.0]
python-bits: 64
OS: Linux
OS-release: 6.4.6-76060406-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.2
libnetcdf: 4.9.2
xarray: 2023.10.2.dev21+gfcdc8102
pandas: 2.2.0.dev0+495.gecf449b503
numpy: 2.0.0.dev0+git20231031.42c33f3
scipy: 1.12.0.dev0+1903.18d0a2f
netCDF4: 1.6.5
pydap: None
h5netcdf: 1.2.0
h5py: 3.10.0
Nio: None
zarr: 2.16.1
cftime: 1.6.3
nc_time_axis: None
PseudoNetCDF: None
iris: None
bottleneck: 1.3.7.post0.dev7
dask: 2023.10.1+4.g91098a63
distributed: 2023.10.1+5.g76dd8003
matplotlib: 3.9.0.dev0
cartopy: None
seaborn: None
numbagg: None
fsspec: 2023.6.0
cupy: None
pint: 0.22
sparse: None
flox: None
numpy_groupies: None
setuptools: 68.0.0
pip: 23.2.1
conda: None
pytest: 7.4.0
mypy: None
IPython: 8.14.0
sphinx: 7.1.2