
map_blocks output inference problems #3575

@rabernat

Description


I am excited about using map_blocks to overcome a long-standing challenge related to calculating climatologies / anomalies with dask arrays. However, I hit what feels like a bug. I don't love how the new map_blocks function infers the structure of its output. The docstring says:

The function will be first run on mocked-up data, that looks like ‘obj’ but has sizes 0, to determine properties of the returned object such as dtype, variable names, new dimensions and new indexes (if any).

The problem is that many functions will simply raise an error on size-0 data, as in the example below.
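To make the failure mode concrete without xarray, here is a minimal NumPy sketch of zero-size inference. It is not xarray's implementation: `infer_template_zero_size` is a hypothetical helper that, as the docstring describes, calls the function on an input with every dimension shrunk to size 0 and reports the result's dtype and number of dimensions. Elementwise arithmetic passes, but a reduction with no identity (here `max`) fails on the mock data, much as `groupby` does in the example below.

```python
import numpy as np


def infer_template_zero_size(func, arr):
    """Hypothetical helper: mimic zero-size inference on a NumPy array.

    Calls ``func`` on a version of ``arr`` with every axis sliced to
    length 0 and returns the result's (dtype, ndim).
    """
    mock = arr[tuple(slice(0, 0) for _ in range(arr.ndim))]
    try:
        result = func(mock)
    except Exception as e:
        # Mirrors the "Cannot infer object returned..." error in the traceback
        raise RuntimeError("Cannot infer output from zero-size mock data") from e
    return result.dtype, result.ndim


data = np.arange(12.0).reshape(3, 4)

# Elementwise arithmetic works on the zero-size mock
ok = infer_template_zero_size(lambda a: a * 2, data)

# A reduction with no identity cannot run on zero-size data
failure = None
try:
    infer_template_zero_size(lambda a: a - a.max(), data)
except RuntimeError as err:
    failure = err  # err.__cause__ is NumPy's ValueError
```

The appeal of zero-size data is that inference is cheap and never loads real values; the cost is that it only works for functions that are well-defined on empty input.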

MCVE Code Sample

import xarray as xr

ds = xr.tutorial.load_dataset('rasm').chunk({'y': 20})


def calculate_anomaly(ds):
    # Uncomment this guard to work around map_blocks' inference step,
    # which calls the function on zero-size mock data:
    # if len(ds['time']) == 0:
    #     return ds
    gb = ds.groupby("time.month")
    clim = gb.mean(dim='time')  # monthly climatology
    return gb - clim  # anomaly relative to the climatology


xr.map_blocks(calculate_anomaly, ds)

Raises

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/dataset.py in _construct_dataarray(self, name)
   1145         try:
-> 1146             variable = self._variables[name]
   1147         except KeyError:

KeyError: 'time.month'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/parallel.py in infer_template(func, obj, *args, **kwargs)
     77     try:
---> 78         template = func(*meta_args, **kwargs)
     79     except Exception as e:

<ipython-input-40-d7b2b2978c29> in calculate_anomaly(ds)
      5     #    return ds
----> 6     gb = ds.groupby("time.month")
      7     clim = gb.mean(dim='T')

/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/common.py in groupby(self, group, squeeze, restore_coord_dims)
    656         return self._groupby_cls(
--> 657             self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims
    658         )

/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/groupby.py in __init__(self, obj, group, squeeze, grouper, bins, restore_coord_dims, cut_kwargs)
    298                 )
--> 299             group = obj[group]
    300             if len(group) == 0:

/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/dataset.py in __getitem__(self, key)
   1235         if hashable(key):
-> 1236             return self._construct_dataarray(key)
   1237         else:

/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/dataset.py in _construct_dataarray(self, name)
   1148             _, name, variable = _get_virtual_variable(
-> 1149                 self._variables, name, self._level_coords, self.dims
   1150             )

/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/dataset.py in _get_virtual_variable(variables, key, level_vars, dim_sizes)
    157         else:
--> 158             data = getattr(ref_var, var_name).data
    159         virtual_var = Variable(ref_var.dims, data)

AttributeError: 'IndexVariable' object has no attribute 'month'

The above exception was the direct cause of the following exception:

Exception                                 Traceback (most recent call last)
<ipython-input-40-d7b2b2978c29> in <module>
      8     return gb - clim
      9 
---> 10 xr.map_blocks(calculate_anomaly, ds)

/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/parallel.py in map_blocks(func, obj, args, kwargs)
    203     input_chunks = dataset.chunks
    204 
--> 205     template: Union[DataArray, Dataset] = infer_template(func, obj, *args, **kwargs)
    206     if isinstance(template, DataArray):
    207         result_is_array = True

/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/parallel.py in infer_template(func, obj, *args, **kwargs)
     80         raise Exception(
     81             "Cannot infer object returned from running user provided function."
---> 82         ) from e
     83 
     84     if not isinstance(template, (Dataset, DataArray)):

Exception: Cannot infer object returned from running user provided function.
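Until the inference step changes, the commented-out guard in `calculate_anomaly` can be generalized. The sketch below uses a hypothetical `skip_on_empty` decorator (not part of xarray or dask) that returns its input unchanged when the input has no elements; it is shown on NumPy arrays so it runs without the tutorial dataset. It assumes the wrapped function returns something structured like its input; if not, the inferred template will be wrong.

```python
import functools

import numpy as np


def skip_on_empty(func):
    """Hypothetical decorator: skip ``func`` on zero-size input.

    Generalizes the ``if len(ds['time']) == 0: return ds`` guard; only
    correct when ``func`` returns something shaped like its input.
    """
    @functools.wraps(func)
    def wrapper(arr, *args, **kwargs):
        if arr.size == 0:
            # Zero-size mock data: hand it back untouched
            return arr
        return func(arr, *args, **kwargs)

    return wrapper


@skip_on_empty
def offset_from_max(arr):
    # Without the decorator, arr.max() raises ValueError on zero-size input
    return arr - arr.max()


empty_result = offset_from_max(np.zeros((0, 3)))     # returned as-is
real_result = offset_from_max(np.array([1.0, 3.0]))  # computes [1.0, 3.0] - 3.0
```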

Problem Description

We should try to imitate what dask does in map_blocks: https://docs.dask.org/en/latest/array-api.html#dask.array.map_blocks

Specifically:

  • We should allow the user to override the checks by explicitly specifying output dtype and shape
  • Maybe the check should run on small, rather than zero-size, test data
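Both proposals can be sketched together in the same NumPy setting. The hypothetical `infer_template` below (the name echoes the function in the traceback, but this is not its code) returns a user-supplied `(dtype, ndim)` override without calling the function at all, and otherwise runs the function on a small slice of the input, `sample_size` elements along each axis, instead of on zero-size data. The override format and `sample_size` are assumptions for illustration; dask's own `map_blocks` takes parameters such as `dtype` and `chunks` for this purpose.

```python
import numpy as np


def infer_template(func, arr, template=None, sample_size=1):
    """Hypothetical inference helper sketching the two proposals.

    * If ``template`` (a ``(dtype, ndim)`` pair) is given, trust it and
      never call ``func``.
    * Otherwise call ``func`` on a small slice of ``arr``: the first
      ``sample_size`` elements along each axis, not zero-size data.
    """
    if template is not None:
        return template
    sample = arr[tuple(slice(0, sample_size) for _ in range(arr.ndim))]
    result = func(sample)
    return result.dtype, result.ndim


data = np.arange(12.0).reshape(3, 4)

# The reduction that fails on zero-size data runs fine on a 1x1 sample
inferred = infer_template(lambda a: a - a.max(), data)

# An explicit override skips running the function entirely
overridden = infer_template(lambda a: 1 / 0, data, template=(np.dtype("float32"), 2))
```

A small sample still executes the function on real values, so it costs some computation and can still fail for functions that need more than `sample_size` elements along some axis; the explicit override is the escape hatch for those cases.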

Output of xr.show_versions()


INSTALLED VERSIONS

commit: None
python: 3.7.3 | packaged by conda-forge | (default, Jul 1 2019, 21:52:21)
[GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 4.14.138+
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: en_US.UTF-8
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.5
libnetcdf: 4.6.2

xarray: 0.14.0
pandas: 0.25.3
numpy: 1.17.3
scipy: 1.3.2
netCDF4: 1.5.1.2
pydap: installed
h5netcdf: 0.7.4
h5py: 2.10.0
Nio: None
zarr: 2.3.2
cftime: 1.0.4.2
nc_time_axis: 1.2.0
PseudoNetCDF: None
rasterio: 1.0.25
cfgrib: None
iris: 2.2.0
bottleneck: 1.3.0
dask: 2.7.0
distributed: 2.7.0
matplotlib: 3.1.2
cartopy: 0.17.0
seaborn: 0.9.0
numbagg: None
setuptools: 41.6.0.post20191101
pip: 19.3.1
conda: None
pytest: 5.3.1
IPython: 7.9.0
sphinx: None
