Closed
Description
Problem Description
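While digging into #2511, I found that `da.argmax()` triggers a computation on a dask-backed array inside `nanargmax(a, axis=None)`. This happens at line 120 of `xarray/core/nanops.py` in commit 131f602. The culprit is the `if mask.any():` check: evaluating a dask array in a boolean context calls `bool()` on it, and that forces `compute()` (see the traceback below). I don't think this should happen, since neither `da.max()` nor `da.data.argmax()` computes anything. Eager evaluation here defeats the purpose of working lazily.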
MCVE Code Sample
In [1]: import numpy as np
...: import dask
...: import xarray as xr
In [2]: class Scheduler:
...: """ From: https://stackoverflow.com/questions/53289286/ """
...:
...: def __init__(self, max_computes=0):
...: self.max_computes = max_computes
...: self.total_computes = 0
...:
...: def __call__(self, dsk, keys, **kwargs):
...: self.total_computes += 1
...: if self.total_computes > self.max_computes:
...: raise RuntimeError(
...: "Too many dask computations were scheduled: {}".format(
...: self.total_computes
...: )
...: )
...: return dask.get(dsk, keys, **kwargs)
...:
In [3]: scheduler = Scheduler()
In [4]: with dask.config.set(scheduler=scheduler):
...: da = xr.DataArray(
...: np.random.rand(2*3*4).reshape((2, 3, 4)),
...: ).chunk(dict(dim_0=1))
...:
...: dim = da.dims[-1]
...:
...: dask_idcs = da.data.argmax(axis=-1) # Computes 0 times
...: print("Total number of computes: %d" % scheduler.total_computes)
...:
...: da.max(dim) # Computes 0 times
...: print("Total number of computes: %d" % scheduler.total_computes)
...:
...: da.argmax(dim) # Does compute
...: print("Total number of computes: %d" % scheduler.total_computes)
...:
Total number of computes: 0
Total number of computes: 0
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-4-f95c8753dbe6> in <module>
12 print("Total number of computes: %d" % scheduler.total_computes)
13
---> 14 da.argmax(dim) # Does compute
15 print("Total number of computes: %d" % scheduler.total_computes)
16
~/src/xarray/xarray/core/common.py in wrapped_func(self, dim, axis, skipna, **kwargs)
42 def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs):
43 return self.reduce(
---> 44 func, dim, axis, skipna=skipna, allow_lazy=True, **kwargs
45 )
46
~/src/xarray/xarray/core/dataarray.py in reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
2120 """
2121
-> 2122 var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
2123 return self._replace_maybe_drop_dims(var)
2124
~/src/xarray/xarray/core/variable.py in reduce(self, func, dim, axis, keep_attrs, keepdims, allow_lazy, **kwargs)
1456 input_data = self.data if allow_lazy else self.values
1457 if axis is not None:
-> 1458 data = func(input_data, axis=axis, **kwargs)
1459 else:
1460 data = func(input_data, **kwargs)
~/src/xarray/xarray/core/duck_array_ops.py in f(values, axis, skipna, **kwargs)
279
280 try:
--> 281 return func(values, axis=axis, **kwargs)
282 except AttributeError:
283 if isinstance(values, dask_array_type):
~/src/xarray/xarray/core/nanops.py in nanargmax(a, axis)
118 if mask is not None:
119 mask = mask.all(axis=axis)
--> 120 if mask.any():
121 raise ValueError("All-NaN slice encountered")
122 return res
/usr/lib/python3.7/site-packages/dask/array/core.py in __bool__(self)
1370 )
1371 else:
-> 1372 return bool(self.compute())
1373
1374 __nonzero__ = __bool__ # python 2
/usr/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
173 dask.base.compute
174 """
--> 175 (result,) = compute(self, traverse=False, **kwargs)
176 return result
177
/usr/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
444 keys = [x.__dask_keys__() for x in collections]
445 postcomputes = [x.__dask_postcompute__() for x in collections]
--> 446 results = schedule(dsk, keys, **kwargs)
447 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
448
Expected Output
None of the methods should actually compute:
Total number of computes: 0
Total number of computes: 0
Total number of computes: 0
Output of `xr.show_versions()`
INSTALLED VERSIONS
------------------
commit: None
python: 3.7.4 (default, Jul 16 2019, 07:12:58)
[GCC 9.1.0]
python-bits: 64
OS: Linux
OS-release: 5.2.9-arch1-1-ARCH
machine: x86_64
processor:
byteorder: little
LC_ALL: None
LANG: de_DE.utf8
LOCALE: de_DE.UTF-8
libhdf5: 1.10.5
libnetcdf: 4.7.0
xarray: 0.12.3+63.g131f6022
pandas: 0.25.0
numpy: 1.17.0
scipy: 1.3.1
netCDF4: 1.5.1.2
pydap: None
h5netcdf: 0.7.4
h5py: 2.9.0
Nio: None
zarr: None
cftime: 1.0.3.4
nc_time_axis: None
PseudoNetCDF: None
rasterio: 1.0.25
cfgrib: None
iris: None
bottleneck: 1.2.1
dask: 2.1.0
distributed: 1.27.1
matplotlib: 3.1.1
cartopy: 0.17.0
seaborn: 0.9.0
numbagg: None
setuptools: 41.0.1
pip: 19.0.3
conda: None
pytest: 5.0.1
IPython: 7.6.1
sphinx: 2.2.0