
argmax() causes dask to compute #3237

Closed
@ulijh


Problem Description

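While digging into #2511, I found that da.argmax() triggers a computation of the underlying dask array inside nanargmax(a, axis=None), at this line:

if mask.any():

The if statement converts the lazy result of mask.any() into a Python bool, which calls dask's Array.__bool__ and therefore compute() (see the traceback below). I don't think this should happen: da.max() and da.data.argmax() do not compute, and an eager check here makes argmax on dask-backed arrays no longer lazy.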
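The mechanism can be reproduced with plain dask.array, without xarray: building mask.any() only extends the graph, while converting it to a bool runs it. This is a minimal sketch; CountingScheduler is a hypothetical helper modeled on the Scheduler class in the MCVE, and the chunk layout mirrors .chunk(dict(dim_0=1)).

```python
import dask
import dask.array as darr
import numpy as np


class CountingScheduler:
    """Hypothetical helper: counts how often dask actually executes a graph."""

    def __init__(self):
        self.calls = 0

    def __call__(self, dsk, keys, **kwargs):
        self.calls += 1
        # Delegate the real work to dask's synchronous scheduler
        return dask.get(dsk, keys, **kwargs)


counter = CountingScheduler()
x = darr.from_array(np.random.rand(2, 3, 4), chunks=(1, 3, 4))
mask = darr.isnan(x).all(axis=-1)

with dask.config.set(scheduler=counter):
    lazy_flag = mask.any()        # still a 0-d dask array: nothing runs
    calls_after_any = counter.calls
    flag = bool(lazy_flag)        # Array.__bool__ calls compute(): the graph runs
    calls_after_bool = counter.calls

print(calls_after_any, calls_after_bool)  # 0 1
```

This is exactly what the if statement in nanargmax does implicitly.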

MCVE Code Sample

import numpy as np
import dask
import xarray as xr


class Scheduler:
    """ From: https://stackoverflow.com/questions/53289286/ """

    def __init__(self, max_computes=0):
        self.max_computes = max_computes
        self.total_computes = 0

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError(
                "Too many dask computations were scheduled: {}".format(
                    self.total_computes
                )
            )
        return dask.get(dsk, keys, **kwargs)


scheduler = Scheduler()

with dask.config.set(scheduler=scheduler):
    da = xr.DataArray(
        np.random.rand(2*3*4).reshape((2, 3, 4)),
    ).chunk(dict(dim_0=1))

    dim = da.dims[-1]

    dask_idcs = da.data.argmax(axis=-1)  # Computes 0 times
    print("Total number of computes: %d" % scheduler.total_computes)

    da.max(dim)  # Computes 0 times
    print("Total number of computes: %d" % scheduler.total_computes)

    da.argmax(dim)  # Does compute
    print("Total number of computes: %d" % scheduler.total_computes)
Total number of computes: 0
Total number of computes: 0
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-f95c8753dbe6> in <module>
     12     print("Total number of computes: %d" % scheduler.total_computes)
     13
---> 14     da.argmax(dim)  # Does compute
     15     print("Total number of computes: %d" % scheduler.total_computes)
     16

~/src/xarray/xarray/core/common.py in wrapped_func(self, dim, axis, skipna, **kwargs)                                                    
     42             def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs):                                                  
     43                 return self.reduce(                         
---> 44                     func, dim, axis, skipna=skipna, allow_lazy=True, **kwargs                                                    
     45                 )                                           
     46                                                             

~/src/xarray/xarray/core/dataarray.py in reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)                                   
   2120         """                                                 
   2121                                                             
-> 2122         var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)                                              
   2123         return self._replace_maybe_drop_dims(var)           
   2124                                                             

~/src/xarray/xarray/core/variable.py in reduce(self, func, dim, axis, keep_attrs, keepdims, allow_lazy, **kwargs)                        
   1456         input_data = self.data if allow_lazy else self.values                                                                    
   1457         if axis is not None:                                
-> 1458             data = func(input_data, axis=axis, **kwargs)    
   1459         else:                                               
   1460             data = func(input_data, **kwargs)               

~/src/xarray/xarray/core/duck_array_ops.py in f(values, axis, skipna, **kwargs)                                                          
    279                                                             
    280         try:                                                
--> 281             return func(values, axis=axis, **kwargs)        
    282         except AttributeError:                              
    283             if isinstance(values, dask_array_type):         

~/src/xarray/xarray/core/nanops.py in nanargmax(a, axis)            
    118     if mask is not None:                                    
    119         mask = mask.all(axis=axis)                          
--> 120         if mask.any():                                      
    121             raise ValueError("All-NaN slice encountered")   
    122     return res                                              

/usr/lib/python3.7/site-packages/dask/array/core.py in __bool__(self)                                                                    
   1370             )                                               
   1371         else:                                               
-> 1372             return bool(self.compute())                     
   1373                                                             
   1374     __nonzero__ = __bool__  # python 2                      

/usr/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)                                                                 
    173         dask.base.compute                                   
    174         """                                                 
--> 175         (result,) = compute(self, traverse=False, **kwargs) 
    176         return result                                       
    177                                                             

/usr/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)                                                                
    444     keys = [x.__dask_keys__() for x in collections]         
    445     postcomputes = [x.__dask_postcompute__() for x in collections]                                                               
--> 446     results = schedule(dsk, keys, **kwargs)                 
    447     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])                                                        
    448        

Expected Output

None of the methods should actually compute:

Total number of computes: 0
Total number of computes: 0
Total number of computes: 0
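One way to meet this expectation would be to attach the all-NaN check to the task graph instead of evaluating it eagerly, so it only fires when the result is computed. The sketch below shows the idea on plain dask arrays, assuming float input where NaN marks missing values; lazy_nanargmax and _raise_if_all_nan are hypothetical names, not xarray API, and this is not necessarily how the issue was resolved upstream.

```python
import dask.array as darr
import numpy as np


def _raise_if_all_nan(res, all_nan):
    # Runs once per block at compute time, on NumPy data, not at graph build time.
    if np.any(all_nan):
        raise ValueError("All-NaN slice encountered")
    return res


def lazy_nanargmax(a, axis):
    """Hypothetical lazy nanargmax for a float dask array; axis must be an int."""
    mask = darr.isnan(a)
    # Assumption of this sketch: NaN -> -inf so argmax never picks a NaN slot
    filled = darr.where(mask, -np.inf, a)
    res = darr.argmax(filled, axis=axis)
    all_nan = mask.all(axis=axis)
    # Both arrays are reduced over the same axis, so their blocks line up
    return darr.map_blocks(_raise_if_all_nan, res, all_nan, dtype=res.dtype)


values = np.random.rand(2, 3, 4)
x = darr.from_array(values, chunks=(1, 3, 4))
idx = lazy_nanargmax(x, axis=-1)  # only builds a graph; nothing is computed yet
print(np.array_equal(idx.compute(), np.nanargmax(values, axis=-1)))  # True
```

The trade-off of this design is that an all-NaN slice no longer raises when argmax is called; the ValueError surfaces later, when the result is computed.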

Output of xr.show_versions()

INSTALLED VERSIONS
------------------
commit: None
python: 3.7.4 (default, Jul 16 2019, 07:12:58) [GCC 9.1.0]
python-bits: 64
OS: Linux
OS-release: 5.2.9-arch1-1-ARCH
machine: x86_64
processor:
byteorder: little
LC_ALL: None
LANG: de_DE.utf8
LOCALE: de_DE.UTF-8
libhdf5: 1.10.5
libnetcdf: 4.7.0

xarray: 0.12.3+63.g131f6022
pandas: 0.25.0
numpy: 1.17.0
scipy: 1.3.1
netCDF4: 1.5.1.2
pydap: None
h5netcdf: 0.7.4
h5py: 2.9.0
Nio: None
zarr: None
cftime: 1.0.3.4
nc_time_axis: None
PseudoNetCDF: None
rasterio: 1.0.25
cfgrib: None
iris: None
bottleneck: 1.2.1
dask: 2.1.0
distributed: 1.27.1
matplotlib: 3.1.1
cartopy: 0.17.0
seaborn: 0.9.0
numbagg: None
setuptools: 41.0.1
pip: 19.0.3
conda: None
pytest: 5.0.1
IPython: 7.6.1
sphinx: 2.2.0
