Allow for All-NaN in argmax, argmin #3884

@christian-oreilly

Description

Background: In data analysis, multidimensional datasets with missing conditions are common. For example, given a data array of power measurements from a multi-channel recording device with dimensions nb_channel X nb_subjects X [...], some channels are often missing for some subjects, in which case the array contains only NaN for that condition. Functions like Dataset.mean() perform well in this situation: they emit a "RuntimeWarning: Mean of empty slice" and set the mean of the all-NaN slice to NaN, which is what is expected for such a use case. Depending on the use case, the user can filter these warnings to either ignore them or raise them as errors. This is all fine.
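For reference, here is a minimal sketch of the warning filtering described above. It calls np.nanmean directly instead of Dataset.mean(), on the assumption that the "Mean of empty slice" warning originates in NumPy; the variable names are illustrative only.

```python
import warnings

import numpy as np

data = np.ones((3, 2))
data[:, 1] = np.nan  # second column is an all-NaN slice

# Option 1: silence the warning and keep the NaN result.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    quiet_mean = np.nanmean(data, axis=0)

# Option 2: promote the warning to an exception.
with warnings.catch_warnings():
    warnings.simplefilter("error", category=RuntimeWarning)
    try:
        np.nanmean(data, axis=0)
        raised = False
    except RuntimeWarning:
        raised = True
```

Both filters are scoped to their `with` block, so the rest of the program keeps its default warning behaviour.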

Problem: However, Dataset.argmax() offers no such option. The function raises a "ValueError: All-NaN slice encountered" exception. I think it would be better if the behaviour of Dataset.argmax() were modelled on that of Dataset.mean(), so that it raises a warning and sets the result to NaN. It seems fair to consider that the index that maximizes the value of an all-NaN slice is NaN.

The implementation of such a feature may (but does not need to) depend on numpy/numpy#12352
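Until something like this exists, the requested behaviour can be approximated on the user side. The sketch below is one possible workaround, not a proposed implementation: argmax_allow_all_nan is a hypothetical helper name, and it assumes a float DataArray and a single reduction dimension. It temporarily fills all-NaN slices with a placeholder value so that argmax does not raise, then masks those positions with NaN.

```python
import numpy as np
import xarray as xr


def argmax_allow_all_nan(da, dim):
    """Hypothetical helper: argmax along `dim`, NaN where the slice is all-NaN."""
    # True for every slice along `dim` that contains only NaN.
    all_nan = da.isnull().all(dim)
    # Fill only the all-NaN slices with a placeholder so argmax has a valid value;
    # NaNs in other slices are left alone and skipped by argmax as usual.
    filled = da.where(~all_nan, 0.0)
    idx = filled.argmax(dim=dim)
    # Mask the placeholder results; this casts the integer indices to float.
    return idx.where(~all_nan)


dat = xr.DataArray(
    np.ones((3, 2)),
    coords=dict(x=[0.1, 0.2, 0.3], y=[1, 2]),
    dims=["x", "y"],
)
dat[:, 1] = np.nan
result = argmax_allow_all_nan(dat, "x")
```

Note that the result is necessarily of float dtype, since integer arrays cannot hold NaN, so the indices need to be cast back to integers before being used for indexing.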

MCVE Code Sample

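import numpy as np
import xarray as xr

dat = xr.DataArray(np.ones((3, 2)), coords=dict(x=[0.1, 0.2, 0.3], y=[1, 2]), dims=["x", "y"])
dat[:, 1] = np.nan  # make the y=2 slice all-NaN
print(dat.mean("x"))    # works: NaN for the all-NaN slice
print(dat.argmax("x"))  # raises ValueError: All-NaN slice encountered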

Output

<xarray.DataArray (y: 2)>
array([ 1., nan])
Coordinates:
  * y        (y) int64 1 2

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-56-b53e860d1e36> in <module>
      2 dat[:, 1] = np.nan
      3 print(dat.mean("x"))
----> 4 print(dat.argmax("x"))

~/anaconda3/lib/python3.7/site-packages/xarray/core/common.py in wrapped_func(self, dim, axis, skipna, **kwargs)
     44 
     45             def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs):
---> 46                 return self.reduce(func, dim, axis, skipna=skipna, **kwargs)
     47 
     48         else:

~/anaconda3/lib/python3.7/site-packages/xarray/core/dataarray.py in reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   2233         """
   2234 
-> 2235         var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
   2236         return self._replace_maybe_drop_dims(var)
   2237 

~/anaconda3/lib/python3.7/site-packages/xarray/core/variable.py in reduce(self, func, dim, axis, keep_attrs, keepdims, allow_lazy, **kwargs)
   1533 
   1534         if axis is not None:
-> 1535             data = func(input_data, axis=axis, **kwargs)
   1536         else:
   1537             data = func(input_data, **kwargs)

~/anaconda3/lib/python3.7/site-packages/xarray/core/duck_array_ops.py in f(values, axis, skipna, **kwargs)
    305 
    306         try:
--> 307             return func(values, axis=axis, **kwargs)
    308         except AttributeError:
    309             if isinstance(values, dask_array_type):

~/anaconda3/lib/python3.7/site-packages/xarray/core/nanops.py in nanargmax(a, axis)
    105 
    106     module = dask_array if isinstance(a, dask_array_type) else nputils
--> 107     return module.nanargmax(a, axis=axis)
    108 
    109 

~/anaconda3/lib/python3.7/site-packages/xarray/core/nputils.py in f(values, axis, **kwargs)
    213             result = bn_func(values, axis=axis, **kwargs)
    214         else:
--> 215             result = getattr(npmodule, name)(values, axis=axis, **kwargs)
    216 
    217         return result

<__array_function__ internals> in nanargmax(*args, **kwargs)

~/anaconda3/lib/python3.7/site-packages/numpy/lib/nanfunctions.py in nanargmax(a, axis)
    549         mask = np.all(mask, axis=axis)
    550         if np.any(mask):
--> 551             raise ValueError("All-NaN slice encountered")
    552     return res
    553 

ValueError: All-NaN slice encountered

Expected Output

<xarray.DataArray (y: 2)>
array([ 1., nan])
Coordinates:
  * y        (y) int64 1 2
<xarray.DataArray (y: 2)>
array([0, nan])
Coordinates:
  * y        (y) int64 1 2

Versions

Output of `xr.show_versions()`

INSTALLED VERSIONS

commit: None
python: 3.7.4 (default, Aug 13 2019, 20:35:49)
[GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 5.3.0-42-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_CA.UTF-8
LOCALE: en_CA.UTF-8
libhdf5: 1.10.4
libnetcdf: 4.6.3

xarray: 0.15.0
pandas: 1.0.1
numpy: 1.18.2
scipy: 1.4.1
netCDF4: 1.5.3
pydap: None
h5netcdf: None
h5py: 2.9.0
Nio: None
zarr: None
cftime: 1.0.4.2
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: None
dask: 2.11.0
distributed: None
matplotlib: 3.1.2
cartopy: None
seaborn: 0.9.0
numbagg: None
setuptools: 45.2.0.post20200210
pip: 20.0.2
conda: 4.8.2
pytest: None
IPython: 7.12.0
sphinx: 2.3.1
