Description
Background: In data analysis, multidimensional datasets commonly have missing conditions. For example, in an array of power measurements from a multi-channel recording device with dimensions nb_channel X nb_subjects X [...], some channels are often missing for some subjects, in which case the array contains only NaN for that condition. Functions like Dataset.mean() handle this situation well: they emit a "RuntimeWarning: Mean of empty slice" and set the mean of the all-NaN slice to NaN, which is what is expected for such a use case. Depending on their needs, users can filter these warnings to either ignore them or raise them as errors. This is all fine.
Problem: However, Dataset.argmax() offers no such option. The function raises a "ValueError: All-NaN slice encountered" exception. I think it would be better if the behaviour of Dataset.argmax() were modelled on that of Dataset.mean(), so that it raises a warning and sets the result to NaN. It seems fair to consider that the index that maximizes the value of an all-NaN slice is NaN.
The implementation of such a feature may (but does not need to) depend on numpy/numpy#12352
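To illustrate the warning-filtering options mentioned above, here is a minimal sketch that calls NumPy's `np.nanmean` directly rather than going through xarray. It assumes the "Mean of empty slice" warning seen with Dataset.mean() originates from this NumPy function; the variable names are illustrative only.

```python
import warnings

import numpy as np

# Same shape as the example in this report: the second column is all-NaN.
a = np.ones((3, 2))
a[:, 1] = np.nan

# Default behaviour: a RuntimeWarning is emitted and the result is NaN.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    default_mean = np.nanmean(a, axis=0)

# Option 1: silence the warning and keep the NaN result.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    quiet_mean = np.nanmean(a, axis=0)

# Option 2: escalate the warning to an exception.
with warnings.catch_warnings():
    warnings.simplefilter("error", category=RuntimeWarning)
    try:
        np.nanmean(a, axis=0)
        raised = False
    except RuntimeWarning:
        raised = True
```

Either way the caller decides the policy, which is exactly the flexibility that is missing for argmax.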
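The requested semantics could be sketched at the NumPy level roughly as follows. `nanargmax_or_nan` is a hypothetical helper name, not an existing NumPy or xarray function, and returning floats (so that NaN can be represented) is an assumption about how the result would be encoded.

```python
import warnings

import numpy as np


def nanargmax_or_nan(a, axis):
    """Hypothetical helper: like np.nanargmax, but warn and return NaN
    for all-NaN slices instead of raising ValueError."""
    a = np.asarray(a, dtype=float)
    nan_mask = np.isnan(a)
    all_nan = nan_mask.all(axis=axis)
    if all_nan.any():
        warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=2)
    # Replace NaN with -inf so that argmax ignores missing values.
    filled = np.where(nan_mask, -np.inf, a)
    indices = filled.argmax(axis=axis)
    # The result is upcast to float because integers cannot hold NaN.
    return np.where(all_nan, np.nan, indices)


a = np.ones((3, 2))
a[:, 1] = np.nan
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = nanargmax_or_nan(a, axis=0)
```

One consequence of this design is that the returned indices are floats, so they would need to be cast back to integers (after dropping the NaN entries) before being used for indexing.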
MCVE Code Sample
import numpy as np
import xarray as xr

dat = xr.DataArray(np.ones((3, 2)), coords=dict(x=[0.1, 0.2, 0.3], y=[1, 2]), dims=["x", "y"])
dat[:, 1] = np.nan
print(dat.mean("x"))
print(dat.argmax("x"))
Output
<xarray.DataArray (y: 2)>
array([ 1., nan])
Coordinates:
* y (y) int64 1 2
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-56-b53e860d1e36> in <module>
2 dat[:, 1] = np.nan
3 print(dat.mean("x"))
----> 4 print(dat.argmax("x"))
~/anaconda3/lib/python3.7/site-packages/xarray/core/common.py in wrapped_func(self, dim, axis, skipna, **kwargs)
44
45 def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs):
---> 46 return self.reduce(func, dim, axis, skipna=skipna, **kwargs)
47
48 else:
~/anaconda3/lib/python3.7/site-packages/xarray/core/dataarray.py in reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
2233 """
2234
-> 2235 var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
2236 return self._replace_maybe_drop_dims(var)
2237
~/anaconda3/lib/python3.7/site-packages/xarray/core/variable.py in reduce(self, func, dim, axis, keep_attrs, keepdims, allow_lazy, **kwargs)
1533
1534 if axis is not None:
-> 1535 data = func(input_data, axis=axis, **kwargs)
1536 else:
1537 data = func(input_data, **kwargs)
~/anaconda3/lib/python3.7/site-packages/xarray/core/duck_array_ops.py in f(values, axis, skipna, **kwargs)
305
306 try:
--> 307 return func(values, axis=axis, **kwargs)
308 except AttributeError:
309 if isinstance(values, dask_array_type):
~/anaconda3/lib/python3.7/site-packages/xarray/core/nanops.py in nanargmax(a, axis)
105
106 module = dask_array if isinstance(a, dask_array_type) else nputils
--> 107 return module.nanargmax(a, axis=axis)
108
109
~/anaconda3/lib/python3.7/site-packages/xarray/core/nputils.py in f(values, axis, **kwargs)
213 result = bn_func(values, axis=axis, **kwargs)
214 else:
--> 215 result = getattr(npmodule, name)(values, axis=axis, **kwargs)
216
217 return result
<__array_function__ internals> in nanargmax(*args, **kwargs)
~/anaconda3/lib/python3.7/site-packages/numpy/lib/nanfunctions.py in nanargmax(a, axis)
549 mask = np.all(mask, axis=axis)
550 if np.any(mask):
--> 551 raise ValueError("All-NaN slice encountered")
552 return res
553
ValueError: All-NaN slice encountered
Expected Output
<xarray.DataArray (y: 2)>
array([ 1., nan])
Coordinates:
* y (y) int64 1 2
<xarray.DataArray (y: 2)>
array([0, nan])
Coordinates:
* y (y) int64 1 2
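Until something like this is supported, a possible user-side workaround with existing xarray methods is sketched below. It is only a sketch, not a proposed implementation: it emits no warning, and it assumes the data contains no genuine -inf values that would need to be distinguished from missing ones.

```python
import numpy as np
import xarray as xr

dat = xr.DataArray(
    np.ones((3, 2)),
    coords=dict(x=[0.1, 0.2, 0.3], y=[1, 2]),
    dims=["x", "y"],
)
dat[:, 1] = np.nan

# True where a slice along "x" has at least one non-NaN value.
has_data = dat.notnull().any(dim="x")

# Fill NaN with -inf so that argmax does not raise, then mask the
# all-NaN slices. where() upcasts the integer indices to float.
idx = dat.fillna(-np.inf).argmax(dim="x").where(has_data)
print(idx)
```

The masked result carries NaN for y=2, matching the second expected output above apart from the float dtype of the indices.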
Versions
Output of `xr.show_versions()`
INSTALLED VERSIONS
commit: None
python: 3.7.4 (default, Aug 13 2019, 20:35:49)
[GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 5.3.0-42-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_CA.UTF-8
LOCALE: en_CA.UTF-8
libhdf5: 1.10.4
libnetcdf: 4.6.3
xarray: 0.15.0
pandas: 1.0.1
numpy: 1.18.2
scipy: 1.4.1
netCDF4: 1.5.3
pydap: None
h5netcdf: None
h5py: 2.9.0
Nio: None
zarr: None
cftime: 1.0.4.2
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: None
dask: 2.11.0
distributed: None
matplotlib: 3.1.2
cartopy: None
seaborn: 0.9.0
numbagg: None
setuptools: 45.2.0.post20200210
pip: 20.0.2
conda: 4.8.2
pytest: None
IPython: 7.12.0
sphinx: 2.3.1