as_shared_dtype coerces scalars into numpy regardless of other array types #4231
Description
Related to #4212
When trying to get the "Calculating Seasonal Averages from Timeseries of Monthly Means" example from the documentation to work with cupy, I'm hitting an unexpected Unsupported type <class 'numpy.ndarray'> error when calling ds_unweighted = ds.groupby('time.season').mean('time'). I dug through this with @quasiben and it seems to be related to the as_shared_dtype function.
What happened:
Running the MCVE below results in TypeError: Unsupported type <class 'numpy.ndarray'>. Somewhere in the stack there is a call to _replace_nan(a, 0), which replaces the NaN values in the cupy array with 0. This ends up as a call to xarray.core.duck_array_ops.where with the "is nan" mask, 0, and the cupy array as arguments. However, where calls as_shared_dtype on the 0 and the cupy array before handing them to _where, and that converts the 0 to a scalar (zero-dimensional) numpy array. The numpy array is then passed to cupy's where function, which raises the exception.
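To illustrate the coercion step, here is a minimal numpy-only sketch. as_shared_dtype_sketch is a hypothetical stand-in written for this issue, not xarray's actual source, and a plain numpy array plays the role of the cupy array so that it runs without a GPU.

```python
import numpy as np


def as_shared_dtype_sketch(scalars_or_arrays):
    """Hypothetical, simplified stand-in for as_shared_dtype (an assumption,
    not xarray's source): coerce every input to an array, then cast all of
    them to a common dtype."""
    arrays = [
        # Anything that is not already a duck array goes through np.asarray,
        # so a Python scalar becomes a zero-dimensional numpy.ndarray.
        x if hasattr(x, "__array_function__") else np.asarray(x)
        for x in scalars_or_arrays
    ]
    out_type = np.result_type(*arrays)
    return [x.astype(out_type, copy=False) for x in arrays]


data = np.array([1.0, np.nan, 3.0])  # stand-in for the cupy array
fill, arr = as_shared_dtype_sketch([0, data])

# The Python int 0 has become a zero-dimensional numpy array.
print(type(fill), fill.shape, fill.dtype)
```

In the failing case, this same kind of zero-dimensional numpy.ndarray is what reaches cupy's where, as the traceback shows.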
What you expected to happen:
The cupy.where function can take either a Python int/float or a cupy array, but not a numpy scalar.
Therefore a few things could be done here:
- Xarray could leave the int/float as it is instead of converting it to a numpy array
- It could convert it to a cupy array
- Cupy could be modified to accept a numpy scalar.
We threw together a quick fix for option 2, which I'll put in a draft PR, but I'm happy to discuss the alternatives.
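Roughly, option 2 could look like the sketch below. This is not the code from the draft PR: the function name as_shared_dtype_option2 and the optional cupy import are assumptions for illustration, and dask and other duck arrays are ignored. When cupy is not installed it falls back to numpy, so the sketch also runs on a CPU-only machine.

```python
import numpy as np

try:
    import cupy as cp
except ImportError:
    cp = None  # CPU-only environment: only the numpy path is exercised


def as_shared_dtype_option2(scalars_or_arrays):
    """Hypothetical sketch of option 2 (not the draft PR's code): if any
    input is a cupy array, coerce scalars with cupy instead of numpy."""
    xp = np
    if cp is not None and any(isinstance(x, cp.ndarray) for x in scalars_or_arrays):
        xp = cp
    # Scalars are converted with the same array library as the other inputs.
    arrays = [xp.asarray(x) for x in scalars_or_arrays]
    out_type = xp.result_type(*arrays)
    return [x.astype(out_type, copy=False) for x in arrays]


# numpy path shown here; with a cupy array in the list, xp would be cupy
# and the 0 would be coerced with cp.asarray instead.
fill, arr = as_shared_dtype_option2([0, np.array([1.0, 2.0])])
print(type(fill).__module__, fill.dtype)
```

The point of this approach is that the scalar follows whatever array type it is combined with, so cupy's where would only ever see cupy arrays.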
Minimal Complete Verifiable Example:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import cupy as cp
# Load data
ds = xr.tutorial.open_dataset("rasm").load()
# Move data to GPU
ds.Tair.data = cp.asarray(ds.Tair.data)
ds_unweighted = ds.groupby("time.season").mean("time")
# Calculate the weights by grouping by 'time.season'.
month_length = ds.time.dt.days_in_month
weights = (
month_length.groupby("time.season") / month_length.groupby("time.season").sum()
)
# Test that the sum of the weights for each season is 1.0
np.testing.assert_allclose(weights.groupby("time.season").sum().values, np.ones(4))
# Move weights to GPU
weights.data = cp.asarray(weights.data)
# Calculate the weighted average
ds_weighted = ds * weights
ds_weighted = ds_weighted.groupby("time.season")
ds_weighted = ds_weighted.sum(dim="time")
Traceback
Traceback (most recent call last):
File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/__main__.py", line 45, in <module>
cli.main()
File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 267, in run_file
runpy.run_path(options.target, run_name=compat.force_str("__main__"))
File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/jacob/Projects/pydata/xarray/test_seasonal_averages.py", line 32, in <module>
ds_weighted = ds_weighted.sum(dim="time")
File "/home/jacob/Projects/pydata/xarray/xarray/core/common.py", line 84, in wrapped_func
func, dim, skipna=skipna, numeric_only=numeric_only, **kwargs
File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 994, in reduce
return self.map(reduce_dataset)
File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 923, in map
return self._combine(applied)
File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 943, in _combine
applied_example, applied = peek_at(applied)
File "/home/jacob/Projects/pydata/xarray/xarray/core/utils.py", line 183, in peek_at
peek = next(gen)
File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 922, in <genexpr>
applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 990, in reduce_dataset
return ds.reduce(func, dim, keep_attrs, **kwargs)
File "/home/jacob/Projects/pydata/xarray/xarray/core/dataset.py", line 4313, in reduce
**kwargs,
File "/home/jacob/Projects/pydata/xarray/xarray/core/variable.py", line 1591, in reduce
data = func(input_data, axis=axis, **kwargs)
File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 324, in f
return func(values, axis=axis, **kwargs)
File "/home/jacob/Projects/pydata/xarray/xarray/core/nanops.py", line 111, in nansum
a, mask = _replace_nan(a, 0)
File "/home/jacob/Projects/pydata/xarray/xarray/core/nanops.py", line 21, in _replace_nan
return where_method(val, mask, a), mask
File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 274, in where_method
return where(cond, data, other)
File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 268, in where
return _where(condition, *as_shared_dtype([x, y]))
File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 56, in f
return wrapped(*args, **kwargs)
File "<__array_function__ internals>", line 6, in where
File "cupy/core/core.pyx", line 1343, in cupy.core.core.ndarray.__array_function__
File "/home/jacob/miniconda3/envs/dask/lib/python3.7/site-packages/cupy/sorting/search.py", line 211, in where
return _where_ufunc(condition.astype('?'), x, y)
File "cupy/core/_kernel.pyx", line 906, in cupy.core._kernel.ufunc.__call__
File "cupy/core/_kernel.pyx", line 90, in cupy.core._kernel._preprocess_args
TypeError: Unsupported type <class 'numpy.ndarray'>
Anything else we need to know?:
Environment:
Output of xr.show_versions()
INSTALLED VERSIONS
commit: 52043bc
python: 3.7.6 | packaged by conda-forge | (default, Jun 1 2020, 18:57:50)
[GCC 7.5.0]
python-bits: 64
OS: Linux
OS-release: 5.3.0-62-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_GB.UTF-8
LOCALE: en_GB.UTF-8
libhdf5: None
libnetcdf: None
xarray: 0.15.1
pandas: 0.25.3
numpy: 1.18.5
scipy: 1.5.0
netCDF4: None
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: None
cftime: 1.2.0
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: 0.9.8.3
iris: None
bottleneck: None
dask: 2.20.0
distributed: 2.20.0
matplotlib: 3.2.2
cartopy: 0.17.0
seaborn: 0.10.1
numbagg: None
pint: None
setuptools: 49.1.0.post20200704
pip: 20.1.1
conda: None
pytest: 5.4.3
IPython: 7.16.1
sphinx: None