Skip to content

Commit

Permalink
Fix map_blocks HLG layering
Browse files Browse the repository at this point in the history
This fixes an issue with the HighLevelGraph noted in
pydata#3584, and exposed by a recent
change in Dask to do more HLG fusion.
  • Loading branch information
TomAugspurger committed Dec 5, 2019
1 parent 87a25b6 commit a9a5e93
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
13 changes: 10 additions & 3 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
except ImportError:
pass

import collections
import itertools
import operator
from typing import (
Any,
Callable,
Dict,
DefaultDict,
Hashable,
Mapping,
Sequence,
Expand Down Expand Up @@ -222,6 +224,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
indexes.update({k: template.indexes[k] for k in new_indexes})

graph: Dict[Any, Any] = {}
new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict)
gname = "{}-{}".format(
dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs)
)
Expand Down Expand Up @@ -310,9 +313,13 @@ def _wrapper(func, obj, to_array, args, kwargs):
# unchunked dimensions in the input have one chunk in the result
key += (0,)

graph[key] = (operator.getitem, from_wrapper, name)
new_layers[gname_l][key] = (operator.getitem, from_wrapper, name)

graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])

for gname_l, layer in new_layers.items():
hlg.dependencies[gname_l] = {gname}
hlg.layers[gname_l] = layer

result = Dataset(coords=indexes, attrs=template.attrs)
for name, gname_l in var_key_map.items():
Expand All @@ -325,7 +332,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
var_chunks.append((len(indexes[dim]),))

data = dask.array.Array(
graph, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
)
result[name] = (dims, data, template[name].attrs)

Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,13 @@ def func(obj):
assert_identical(expected.compute(), actual.compute())


def test_map_blocks_hlg_layers():
ds = xr.Dataset({"x": (("y",), dask.array.ones(10, chunks=(5,)))})
mapped = ds.map_blocks(lambda x: x)

xr.testing.assert_equal(mapped, ds) # does not work


def test_make_meta(map_ds):
from ..core.parallel import make_meta

Expand Down

0 comments on commit a9a5e93

Please sign in to comment.