Skip to content

Commit d65569b

Browse files
committed
support dask arrays in DataArray coordinates
1 parent 9df0af7 commit d65569b

File tree

3 files changed

+35
-14
lines changed

3 files changed

+35
-14
lines changed

xarray/core/dataarray.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -577,33 +577,33 @@ def reset_coords(self, names=None, drop=False, inplace=False):
577577
return dataset
578578

579579
def __dask_graph__(self):
580-
return self._variable.__dask_graph__()
580+
return self._to_temp_dataset().__dask_graph__()
581581

582582
def __dask_keys__(self):
583-
return self._variable.__dask_keys__()
583+
return self._to_temp_dataset().__dask_keys__()
584584

585585
@property
586586
def __dask_optimize__(self):
587-
return self._variable.__dask_optimize__
587+
return self._to_temp_dataset().__dask_optimize__
588588

589589
@property
590590
def __dask_scheduler__(self):
591-
return self._variable.__dask_scheduler__
591+
return self._to_temp_dataset().__dask_optimize__
592592

593593
def __dask_postcompute__(self):
594-
variable_func, variable_args = self._variable.__dask_postcompute__()
595-
return self._dask_finalize, (variable_func, variable_args,
596-
self._coords, self._name)
594+
func, args = self._to_temp_dataset().__dask_postcompute__()
595+
return self._dask_finalize, (func, args, self.name)
597596

598597
def __dask_postpersist__(self):
599-
variable_func, variable_args = self._variable.__dask_postpersist__()
600-
return self._dask_finalize, (variable_func, variable_args,
601-
self._coords, self._name)
598+
func, args = self._to_temp_dataset().__dask_postpersist__()
599+
return self._dask_finalize, (func, args, self.name)
602600

603601
@staticmethod
604-
def _dask_finalize(results, variable_func, variable_args, coords, name):
605-
var = variable_func(results, *variable_args)
606-
return DataArray(var, coords=coords, name=name)
602+
def _dask_finalize(results, func, args, name):
603+
ds = func(results, *args)
604+
variable = ds._variables.pop(_THIS_ARRAY)
605+
coords = ds._variables
606+
return DataArray(variable, coords, name=name, fastpath=True)
607607

608608
def load(self, **kwargs):
609609
"""Manually trigger loading of this array's data from disk or a

xarray/core/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def __dask_postpersist__(self):
538538
@staticmethod
539539
def _dask_postcompute(results, info, *args):
540540
variables = OrderedDict()
541-
results2 = results[::-1]
541+
results2 = list(results[::-1])
542542
for is_dask, k, v in info:
543543
if is_dask:
544544
func, args2 = v

xarray/tests/test_dask.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,3 +771,24 @@ def test_persist_DataArray(persist):
771771

772772
assert len(z.data.dask) == n
773773
assert len(zz.data.dask) == zz.data.npartitions
774+
775+
776+
@pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
777+
reason='Need dask 0.16 for new interface')
778+
def test_dataarray_with_dask_coords():
779+
import toolz
780+
x = xr.Variable('x', da.arange(8, chunks=(4,)))
781+
y = xr.Variable('y', da.arange(8, chunks=(4,)) * 2)
782+
data = da.random.random((8, 8), chunks=(4, 4)) + 1
783+
array = xr.DataArray(data, dims=['x', 'y'])
784+
array.coords['xx'] = x
785+
array.coords['yy'] = y
786+
787+
assert dict(array.__dask_graph__()) == toolz.merge(data.__dask_graph__(),
788+
x.__dask_graph__(),
789+
y.__dask_graph__())
790+
791+
(array2,) = dask.compute(array)
792+
assert not dask.is_dask_collection(array2)
793+
794+
assert all(isinstance(v._variable.data, np.ndarray) for v in array2.coords.values())

0 commit comments

Comments
 (0)