where dtype upcast with numpy 2 #8402

Closed
@djhoese

Description

What happened?

I'm testing my code with numpy 2.0 and the current main branches of xarray and dask, and ran into a change that I guess is expected given the way xarray does things. I want to confirm that, since it could surprise many users.

Calling DataArray.where on an integer array narrower than 64 bits, with a Python int as the replacement value, now upcasts the result to 64-bit integers (the dtype numpy assigns to a Python int). Older versions of numpy preserved the array's dtype. As far as I can tell the relevant xarray code hasn't changed, so this seems to come from numpy making its type promotion more consistent.

The main problem seems to come down to:

arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]

This converts my scalar int input into a numpy array. Without that conversion, numpy behaves as I expect. See the MCVE for the xarray-specific example, but here's the numpy equivalent:

import numpy as np

a = np.zeros((2, 2), dtype=np.uint16)

# what I'm intending to do with my xarray `data_arr.where(cond, 2)`
np.where(a != 0, a, 2).dtype
# dtype('uint16')

# equivalent to what xarray does:
np.where(a != 0, a, np.asarray(2)).dtype
# dtype('int64')

# workaround, cast my scalar to a specific numpy type
np.where(a != 0, a, np.asarray(np.uint16(2))).dtype
# dtype('uint16')

From a numpy point of view, the second where call makes sense: two arrays should be upcast to a common dtype so they can be combined. But from an xarray user's point of view, I'm passing a scalar, so I expect the same result as the first where call above.
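Until this is settled on the xarray side, the third where call above can be wrapped in a small helper. This is only a sketch in plain numpy: `where_keep_dtype` is a hypothetical name, not an xarray or numpy function, and it assumes the fill value is a Python int that fits in the array's dtype.

```python
import numpy as np


def where_keep_dtype(cond, arr, fill):
    """Hypothetical helper: np.where that keeps arr's dtype for Python int fills."""
    if isinstance(fill, int):
        # Convert the Python int to arr's scalar type (e.g. np.uint16) so both
        # branches of np.where already share one dtype and nothing is upcast.
        fill = arr.dtype.type(fill)
    return np.where(cond, arr, fill)


a = np.zeros((2, 2), dtype=np.uint16)
result = where_keep_dtype(a != 0, a, 2)
print(result.dtype)  # uint16
```

Values that do not fit the target dtype (for example 70000 for uint16) are not handled here; numpy 2 rejects that conversion with an OverflowError, so the caller would need to decide what to do in that case.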

What did you expect to happen?

See above.

Minimal Complete Verifiable Example

import xarray as xr
import numpy as np

data_arr = xr.DataArray(np.array([1, 2], dtype=np.uint16))
print(data_arr.where(data_arr == 2, 3).dtype)
# int64

MCVE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

No response

Anything else we need to know?

Numpy 1.x preserves the dtype, even when the replacement value is a 0-d int64 array:

In [1]: import numpy as np

In [2]: np.asarray(2).dtype
Out[2]: dtype('int64')

In [3]: a = np.zeros((2, 2), dtype=np.uint16)

In [4]: np.where(a != 0, a, np.asarray(2)).dtype
Out[4]: dtype('uint16')

In [5]: np.where(a != 0, a, np.asarray(np.uint16(2))).dtype
Out[5]: dtype('uint16')
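The same difference can be seen without np.where by asking numpy which dtype it would promote to. The snippet below is a sketch using only `np.result_type`; the variable names are mine, and the dtypes it prints depend on the installed numpy version.

```python
import numpy as np

a = np.zeros((2, 2), dtype=np.uint16)

# A bare Python int does not force a wider dtype in either numpy 1.x or 2.x.
with_python_int = np.result_type(a, 2)

# A 0-d array: numpy 2 promotes using its dtype (the default integer),
# while numpy 1.x looks at the stored value and keeps the narrower dtype.
with_zero_d_array = np.result_type(a, np.asarray(2))

print(np.__version__, with_python_int, with_zero_d_array)
```

If the second result is wider than uint16, any code path that wraps scalars with asarray before combining them will upcast in the same way as the where call in this report.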

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.11.4 | packaged by conda-forge | (main, Jun 10 2023, 18:08:17) [GCC 12.2.0]
python-bits: 64
OS: Linux
OS-release: 6.4.6-76060406-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.2
libnetcdf: 4.9.2

xarray: 2023.10.2.dev21+gfcdc8102
pandas: 2.2.0.dev0+495.gecf449b503
numpy: 2.0.0.dev0+git20231031.42c33f3
scipy: 1.12.0.dev0+1903.18d0a2f
netCDF4: 1.6.5
pydap: None
h5netcdf: 1.2.0
h5py: 3.10.0
Nio: None
zarr: 2.16.1
cftime: 1.6.3
nc_time_axis: None
PseudoNetCDF: None
iris: None
bottleneck: 1.3.7.post0.dev7
dask: 2023.10.1+4.g91098a63
distributed: 2023.10.1+5.g76dd8003
matplotlib: 3.9.0.dev0
cartopy: None
seaborn: None
numbagg: None
fsspec: 2023.6.0
cupy: None
pint: 0.22
sparse: None
flox: None
numpy_groupies: None
setuptools: 68.0.0
pip: 23.2.1
conda: None
pytest: 7.4.0
mypy: None
IPython: 8.14.0
sphinx: 7.1.2
