Commit 20195ca

Fix optimize for chunked DataArray
Previously we generated an invalid Dask task graph, because the lines removed here dropped keys that were referenced elsewhere in the task graph.

The original implementation had a comment indicating that this was to cull: https://github.com/pydata/xarray/blame/502a988ad5b87b9f3aeec3033bf55c71272e1053/xarray/core/variable.py#L384

From spot-checking, I think we're OK here though: something like `dask.visualize(arr[[0]], optimize_graph=True)` indicates that the graph is still culled.

Closes #3698
1 parent bb4c7b4 commit 20195ca
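
For readers unfamiliar with culling, the following is a toy sketch in plain Python, not Dask's implementation. It assumes only the dict convention for task graphs (a key maps to a value or to a `(callable, *args)` tuple); the helper `cull_toy` and the key names `a` and `b` are hypothetical. It shows that culling keeps every key reachable from the requested outputs, including keys that belong to other layers.

```python
from operator import add


def cull_toy(dsk, outputs):
    """Keep only the keys reachable from `outputs` (toy stand-in for culling)."""
    keep = {}
    stack = list(outputs)
    while stack:
        key = stack.pop()
        if key in keep:
            continue
        task = dsk[key]
        keep[key] = task
        # Follow arguments that are themselves keys of the graph.
        if isinstance(task, tuple) and task and callable(task[0]):
            stack.extend(a for a in task[1:] if isinstance(a, tuple) and a in dsk)
    return keep


# Hypothetical two-layer graph: layer "b" depends chunk-by-chunk on layer "a".
graph = {
    ("a", 0): 1,
    ("a", 1): 2,
    ("b", 0): (add, ("a", 0), 10),
    ("b", 1): (add, ("a", 1), 10),
}

# Asking for the first chunk only, loosely analogous to arr[[0]].
culled = cull_toy(graph, [("b", 0)])
print(sorted(culled))  # [('a', 0), ('b', 0)]
```

The second chunk of each layer is dropped, while the `("a", 0)` dependency survives even though its name differs from the requested layer.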

3 files changed: 8 additions and 3 deletions

doc/whats-new.rst

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ Bug fixes
 - Fix `KeyError` when doing linear interpolation to an nd `DataArray`
   that contains NaNs (:pull:`4233`).
   By `Jens Svensmark <https://github.com/jenssss>`_
+- Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`)
 
 
 Documentation
 ~~~~~~~~~~~~~

xarray/core/variable.py

Lines changed: 0 additions & 3 deletions
@@ -501,9 +501,6 @@ def __dask_postpersist__(self):
 
     @staticmethod
     def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
-        if isinstance(results, dict):  # persist case
-            name = array_args[0]
-            results = {k: v for k, v in results.items() if k[0] == name}
         data = array_func(results, *array_args)
         return Variable(dims, data, attrs=attrs, encoding=encoding)
 
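
To see why a filter like the three removed lines can invalidate a graph, here is a toy sketch in plain Python. It assumes only the dict convention for task graphs (a key maps to a value or to a `(callable, *args)` tuple); the key names `ones` and `rechunk` and the helper `missing_dependencies` are hypothetical and are not part of xarray or Dask.

```python
from operator import add

# Toy graph: keys are (name, index) tuples, as for chunked arrays.
graph = {
    ("ones", 0): 1,
    ("ones", 1): 1,
    ("rechunk", 0): (add, ("ones", 0), ("ones", 1)),
}


def missing_dependencies(dsk):
    """Return keys that tasks refer to but that are absent from the graph."""
    missing = set()
    for task in dsk.values():
        if isinstance(task, tuple) and task and callable(task[0]):
            for arg in task[1:]:
                if isinstance(arg, tuple) and arg not in dsk:
                    missing.add(arg)
    return missing


# Same shape as the removed filter: keep only keys whose first element
# matches the array name.
name = "rechunk"
filtered = {k: v for k, v in graph.items() if k[0] == name}

print(missing_dependencies(graph))     # set()
print(missing_dependencies(filtered))  # both ("ones", i) keys are missing
```

In this sketch the filtered graph still contains the `rechunk` task, but the task points at `ones` keys that no longer exist, which is the kind of dangling reference the commit message describes.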

xarray/tests/test_dask.py

Lines changed: 7 additions & 0 deletions
@@ -1607,3 +1607,10 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds):
     assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da)
     assert_equal(map_da.astype(map_da.dtype), map_da)
     assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy)
+
+
+def test_optimize():
+    a = dask.array.ones((10, 5), chunks=(1, 3))
+    arr = xr.DataArray(a).chunk(5)
+    (arr2,) = dask.optimize(arr)
+    arr2.compute()
