Skip to content

Commit

Permalink
Fix map_blocks HLG layering
Browse files Browse the repository at this point in the history
This fixes an issue with the HighLevelGraph noted in
pydata#3584, and exposed by a recent
change in Dask to do more HLG fusion.
  • Loading branch information
TomAugspurger committed Dec 5, 2019
1 parent 87a25b6 commit 12292e6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
8 changes: 7 additions & 1 deletion xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
except ImportError:
pass

import collections
import itertools
import operator
from typing import (
Expand Down Expand Up @@ -222,6 +223,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
indexes.update({k: template.indexes[k] for k in new_indexes})

graph: Dict[Any, Any] = {}
new_layers = collections.defaultdict(dict)
gname = "{}-{}".format(
dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs)
)
Expand Down Expand Up @@ -310,10 +312,14 @@ def _wrapper(func, obj, to_array, args, kwargs):
# unchunked dimensions in the input have one chunk in the result
key += (0,)

graph[key] = (operator.getitem, from_wrapper, name)
new_layers[gname_l][key] = (operator.getitem, from_wrapper, name)

graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])

for gname_l, layer in new_layers.items():
graph.dependencies[gname_l] = {gname}
graph.layers[gname_l] = layer

result = Dataset(coords=indexes, attrs=template.attrs)
for name, gname_l in var_key_map.items():
dims = template[name].dims
Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,13 @@ def func(obj):
assert_identical(expected.compute(), actual.compute())


def test_map_blocks_hlg_layers():
ds = xr.Dataset({'x': (('y',), dask.array.ones(10, chunks=(5,)))})
mapped = ds.map_blocks(lambda x: x)

xr.testing.assert_equal(mapped, ds) # does not work


def test_make_meta(map_ds):
from ..core.parallel import make_meta

Expand Down

0 comments on commit 12292e6

Please sign in to comment.