Skip to content

assert_equal and dask #3350

Closed
Closed
@dcherian

Description

@dcherian

MCVE Code Sample

Example 1

import xarray as xr
import numpy as np

# Minimal reproduction: a 10x20 random DataArray named "a" and a Dataset built from it.
da = xr.DataArray(np.random.randn(10, 20), name="a")
ds = da.to_dataset()
# DataArray comparisons: un-chunked vs. chunked along dim_0, then chunk() with
# no arguments vs. chunked along dim_0. Both pass.
xr.testing.assert_equal(da, da.chunk({"dim_0": 2})) # works
xr.testing.assert_equal(da.chunk(), da.chunk({"dim_0": 2})) # works

# The same two comparisons on the Dataset. The first passes; the second raises
# ValueError ("operands could not be broadcast together") while dask computes
# the comparison.
xr.testing.assert_equal(ds, ds.chunk({"dim_0": 2})) # works
xr.testing.assert_equal(ds.chunk(), ds.chunk({"dim_0": 2}))  # does not work

I get

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-1-bc8216a67408> in <module>
      8 
      9 xr.testing.assert_equal(ds, ds.chunk({"dim_0": 2})) # works
---> 10 xr.testing.assert_equal(ds.chunk(), ds.chunk({"dim_0": 2}))  # does not work

~/work/python/xarray/xarray/testing.py in assert_equal(a, b)
     56         assert a.equals(b), formatting.diff_array_repr(a, b, "equals")
     57     elif isinstance(a, Dataset):
---> 58         assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals")
     59     else:
     60         raise TypeError("{} not supported by assertion comparison".format(type(a)))

~/work/python/xarray/xarray/core/dataset.py in equals(self, other)
   1322         """
   1323         try:
-> 1324             return self._all_compat(other, "equals")
   1325         except (TypeError, AttributeError):
   1326             return False

~/work/python/xarray/xarray/core/dataset.py in _all_compat(self, other, compat_str)
   1285 
   1286         return self._coord_names == other._coord_names and utils.dict_equiv(
-> 1287             self._variables, other._variables, compat=compat
   1288         )
   1289 

~/work/python/xarray/xarray/core/utils.py in dict_equiv(first, second, compat)
    335     """
    336     for k in first:
--> 337         if k not in second or not compat(first[k], second[k]):
    338             return False
    339     for k in second:

~/work/python/xarray/xarray/core/dataset.py in compat(x, y)
   1282         # require matching order for equality
   1283         def compat(x: Variable, y: Variable) -> bool:
-> 1284             return getattr(x, compat_str)(y)
   1285 
   1286         return self._coord_names == other._coord_names and utils.dict_equiv(

~/work/python/xarray/xarray/core/variable.py in equals(self, other, equiv)
   1558         try:
   1559             return self.dims == other.dims and (
-> 1560                 self._data is other._data or equiv(self.data, other.data)
   1561             )
   1562         except (TypeError, AttributeError):

~/work/python/xarray/xarray/core/duck_array_ops.py in array_equiv(arr1, arr2)
    201         warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
    202         flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
--> 203         return bool(flag_array.all())
    204 
    205 

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/array/core.py in __bool__(self)
   1380             )
   1381         else:
-> 1382             return bool(self.compute())
   1383 
   1384     __nonzero__ = __bool__  # python 2

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    173         dask.base.compute
    174         """
--> 175         (result,) = compute(self, traverse=False, **kwargs)
    176         return result
    177 

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    444     keys = [x.__dask_keys__() for x in collections]
    445     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 446     results = schedule(dsk, keys, **kwargs)
    447     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    448 

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     80         get_id=_thread_get_id,
     81         pack_exception=pack_exception,
---> 82         **kwargs
     83     )
     84 

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    489                         _execute_task(task, data)  # Re-execute locally
    490                     else:
--> 491                         raise_exception(exc, tb)
    492                 res, worker_id = loads(res_info)
    493                 state["cache"][key] = res

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/compatibility.py in reraise(exc, tb)
    128         if exc.__traceback__ is not tb:
    129             raise exc.with_traceback(tb)
--> 130         raise exc
    131 
    132     import pickle as cPickle

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    231     try:
    232         task, data = loads(task_info)
--> 233         result = _execute_task(task, data)
    234         id = get_id()
    235         result = dumps((result, id))

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/optimization.py in __call__(self, *args)
   1057         if not len(args) == len(self.inkeys):
   1058             raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
-> 1059         return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))
   1060 
   1061     def __reduce__(self):

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/core.py in get(dsk, out, cache)
    147     for key in toposort(dsk):
    148         task = dsk[key]
--> 149         result = _execute_task(task, cache)
    150         cache[key] = result
    151     result = _execute_task(out, cache)

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

ValueError: operands could not be broadcast together with shapes (0,20) (2,20) 

Example 2

The relevant xarray line in the previous traceback is `flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))`, so I tried the `isnull(...) & isnull(...)` part on its own:

# Combine the isnull() masks of two versions of the same Dataset (`ds` from Example 1).
# Un-chunked & chunked along dim_0 computes fine; chunk() with no arguments &
# chunked along dim_0 raises ValueError ("replacement data must match the
# Variable's shape") in compute().
(ds.isnull() & ds.chunk({"dim_0": 1}).isnull()).compute() # works
(ds.chunk().isnull() & ds.chunk({"dim_0": 1}).isnull()).compute() # does not work?!
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-abdfbeda355a> in <module>
      1 (ds.isnull() & ds.chunk({"dim_0": 1}).isnull()).compute() # works
----> 2 (ds.chunk().isnull() & ds.chunk({"dim_0": 1}).isnull()).compute() # does not work?!

~/work/python/xarray/xarray/core/dataset.py in compute(self, **kwargs)
    791         """
    792         new = self.copy(deep=False)
--> 793         return new.load(**kwargs)
    794 
    795     def _persist_inplace(self, **kwargs) -> "Dataset":

~/work/python/xarray/xarray/core/dataset.py in load(self, **kwargs)
    645 
    646             for k, data in zip(lazy_data, evaluated_data):
--> 647                 self.variables[k].data = data
    648 
    649         # load everything else sequentially

~/work/python/xarray/xarray/core/variable.py in data(self, data)
    331         data = as_compatible_data(data)
    332         if data.shape != self.shape:
--> 333             raise ValueError("replacement data must match the Variable's shape")
    334         self._data = data
    335 

ValueError: replacement data must match the Variable's shape

Problem Description

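I don't know what's going on here. I expect assert_equal to pass (i.e. not raise) for all of these examples, since the objects being compared hold the same values and differ only in how they are chunked.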

Our test for isnull with dask always calls load before comparing:

def test_isnull_with_dask():
    """Check that ``isnull`` on a DataArray built with ``dask=True`` matches the loaded result.

    Both operands of the final ``assert_equal`` have had ``load()`` applied
    before the comparison, so this test does not compare two dask-backed
    objects directly.
    """
    # NOTE(review): construct_dataarray is defined elsewhere; presumably
    # dask=True yields a dask-backed array containing NaNs -- confirm.
    da = construct_dataarray(2, np.float32, contains_nan=True, dask=True)
    # The result of isnull() must still wrap the dask array type.
    assert isinstance(da.isnull().data, dask_array_type)
    # isnull-then-load must equal load-then-isnull.
    assert_equal(da.isnull().load(), da.load().isnull())

Output of xr.show_versions()

xarray master & dask 2.3.0

# Paste the output here xr.show_versions() here

INSTALLED VERSIONS
------------------
commit: 6ece6a1
python: 3.7.3 | packaged by conda-forge | (default, Jul 1 2019, 21:52:21) [GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 4.15.0-64-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.4
libnetcdf: 4.6.2

xarray: 0.13.0+13.g6ece6a1c
pandas: 0.25.1
numpy: 1.17.2
scipy: 1.3.1
netCDF4: 1.5.1.2
pydap: None
h5netcdf: 0.7.4
h5py: 2.9.0
Nio: None
zarr: None
cftime: 1.0.3.4
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: 0.9.7.1
iris: None
bottleneck: 1.2.1
dask: 2.3.0
distributed: 2.3.2
matplotlib: 3.1.1
cartopy: 0.17.0
seaborn: 0.9.0
numbagg: None
setuptools: 41.2.0
pip: 19.2.3
conda: 4.7.11
pytest: 5.1.2
IPython: 7.8.0
sphinx: 2.2.0

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions