as_shared_dtype coerces scalars into numpy regardless of other array types #4231



Related to #4212

While trying to get the "Calculating Seasonal Averages from Timeseries of Monthly Means" example from the documentation to work with cupy, I hit an unexpected Unsupported type <class 'numpy.ndarray'> error when calling ds_unweighted = ds.groupby('time.season').mean('time')

I dug through this with @quasiben and it seems to be related to the as_shared_dtype function.

What happened:

Running the MCVE below results in Unsupported type <class 'numpy.ndarray'>. Somewhere in the stack there is a call to _replace_nan(a, 0), which replaces the nan values in the cupy array with 0. This ends up as a call to xarray.core.duck_array_ops.where, with the "is nan" mask, 0 and the cupy array being passed.

However, where calls as_shared_dtype on the 0 and the cupy array before handing them to _where, and this converts the 0 to a scalar numpy array.

This numpy array is then passed to cupy's where function, which raises the exception.
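The coercion step can be reproduced without a GPU. The following is a minimal sketch, not xarray's actual implementation: as_shared_dtype_sketch is a hypothetical, simplified stand-in for as_shared_dtype, and a plain numpy array stands in for the cupy array so the snippet runs anywhere. It only shows that the Python scalar 0 comes out as a zero-dimensional numpy.ndarray, which is the kind of object that ends up being passed on to cupy's where.

```python
import numpy as np


def as_shared_dtype_sketch(scalars_or_arrays):
    """Hypothetical, simplified stand-in for xarray's as_shared_dtype.

    Every input is coerced with np.asarray and then cast to a common dtype.
    """
    arrays = [np.asarray(x) for x in scalars_or_arrays]
    out_type = np.result_type(*arrays)
    return [x.astype(out_type, copy=False) for x in arrays]


# A numpy array stands in for the cupy array here (assumption: no GPU available).
data = np.array([1.0, np.nan, 3.0])
fill, data_out = as_shared_dtype_sketch([0, data])

# The Python int is no longer a Python scalar: it is now a 0-d numpy array.
print(type(fill), fill.ndim, fill.dtype)
```

With a real cupy array in place of data, the array itself would stay on the GPU while the fill value would still become a numpy array, which is the mismatch described above.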

What you expected to happen:

The cupy.where function can either take a Python int/float or a cupy array, not a numpy scalar.

Therefore a few things could be done here:

  1. Xarray could not convert the int/float to a numpy array
  2. It could convert it to a cupy array
  3. Cupy could be modified to accept a numpy scalar.
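As a rough illustration of option 2, here is a sketch under stated assumptions; it is not necessarily what the draft PR does. The function name as_shared_dtype_option2 and the cupy_array_type tuple are hypothetical names introduced for this example. If any input is a cupy array, every input is coerced with cupy.asarray, otherwise numpy is used as before. cupy is imported optionally so the snippet also runs on machines without it.

```python
import numpy as np

try:
    import cupy as cp  # optional; only used when cupy arrays are present

    cupy_array_type = (cp.ndarray,)
except ImportError:
    cp = None
    cupy_array_type = ()


def as_shared_dtype_option2(scalars_or_arrays):
    """Hypothetical variant: coerce scalars to cupy when any input is a cupy array."""
    if cupy_array_type and any(
        isinstance(x, cupy_array_type) for x in scalars_or_arrays
    ):
        xp = cp
    else:
        xp = np
    arrays = [xp.asarray(x) for x in scalars_or_arrays]
    # Compute the common dtype from the dtypes so no data leaves its device.
    out_type = np.result_type(*[a.dtype for a in arrays])
    return [a.astype(out_type, copy=False) for a in arrays]


# numpy-only usage: behaviour is unchanged when no cupy array is involved.
fill, data = as_shared_dtype_option2([0, np.array([1.5, 2.5])])
print(type(fill), fill.dtype)
```

The trade-off is that xarray would need to know about cupy explicitly, whereas option 1 (leaving Python scalars alone) would avoid any library-specific branch.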

We threw together a quick fix for option 2, which I'll put in a draft PR, but we're happy to discuss the alternatives.

Minimal Complete Verifiable Example:

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

import cupy as cp

# Load data
ds = xr.tutorial.open_dataset("rasm").load()

# Move data to GPU
ds.Tair.data = cp.asarray(ds.Tair.data)

ds_unweighted = ds.groupby("time.season").mean("time")

# Calculate the weights by grouping by 'time.season'.
month_length = ds.time.dt.days_in_month
weights = (
    month_length.groupby("time.season") / month_length.groupby("time.season").sum()
)

# Test that the sum of the weights for each season is 1.0
np.testing.assert_allclose(weights.groupby("time.season").sum().values, np.ones(4))

# Move weights to GPU
weights.data = cp.asarray(weights.data)

# Calculate the weighted average
ds_weighted = ds * weights
ds_weighted = ds_weighted.groupby("time.season")
ds_weighted = ds_weighted.sum(dim="time")

Traceback (most recent call last):
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/", line 85, in _run_code
    exec(code, run_globals)
  File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/", line 45, in <module>
  File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/../debugpy/server/", line 430, in main
  File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/../debugpy/server/", line 267, in run_file
    runpy.run_path(, run_name=compat.force_str("__main__"))
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/", line 85, in _run_code
    exec(code, run_globals)
  File "/home/jacob/Projects/pydata/xarray/", line 32, in <module>
    ds_weighted = ds_weighted.sum(dim="time")
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 84, in wrapped_func
    func, dim, skipna=skipna, numeric_only=numeric_only, **kwargs
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 994, in reduce
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 923, in map
    return self._combine(applied)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 943, in _combine
    applied_example, applied = peek_at(applied)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 183, in peek_at
    peek = next(gen)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 922, in <genexpr>
    applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 990, in reduce_dataset
    return ds.reduce(func, dim, keep_attrs, **kwargs)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 4313, in reduce
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 1591, in reduce
    data = func(input_data, axis=axis, **kwargs)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 324, in f
    return func(values, axis=axis, **kwargs)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 111, in nansum
    a, mask = _replace_nan(a, 0)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 21, in _replace_nan
    return where_method(val, mask, a), mask
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 274, in where_method
    return where(cond, data, other)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 268, in where
    return _where(condition, *as_shared_dtype([x, y]))
  File "/home/jacob/Projects/pydata/xarray/xarray/core/", line 56, in f
    return wrapped(*args, **kwargs)
  File "<__array_function__ internals>", line 6, in where
  File "cupy/core/core.pyx", line 1343, in cupy.core.core.ndarray.__array_function__
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/site-packages/cupy/sorting/", line 211, in where
    return _where_ufunc(condition.astype('?'), x, y)
  File "cupy/core/_kernel.pyx", line 906, in cupy.core._kernel.ufunc.__call__
  File "cupy/core/_kernel.pyx", line 90, in cupy.core._kernel._preprocess_args
TypeError: Unsupported type <class 'numpy.ndarray'>

Anything else we need to know?:


Output of xr.show_versions()


commit: 52043bc
python: 3.7.6 | packaged by conda-forge | (default, Jun 1 2020, 18:57:50)
[GCC 7.5.0]
python-bits: 64
OS: Linux
OS-release: 5.3.0-62-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
libhdf5: None
libnetcdf: None

xarray: 0.15.1
pandas: 0.25.3
numpy: 1.18.5
scipy: 1.5.0
netCDF4: None
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: None
cftime: 1.2.0
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
iris: None
bottleneck: None
dask: 2.20.0
distributed: 2.20.0
matplotlib: 3.2.2
cartopy: 0.17.0
seaborn: 0.10.1
numbagg: None
pint: None
setuptools: 49.1.0.post20200704
pip: 20.1.1
conda: None
pytest: 5.4.3
IPython: 7.16.1
sphinx: None


