From 12292e627ad4f8ec7bfecc364fd334c3cac6d50b Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 5 Dec 2019 12:16:10 -0600 Subject: [PATCH] Fix map_blocks HLG layering This fixes an issue with the HighLevelGraph noted in https://github.com/pydata/xarray/pull/3584, and exposed by a recent change in Dask to do more HLG fusion. --- xarray/core/parallel.py | 8 +++++++- xarray/tests/test_dask.py | 7 +++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index fbb5ef94ca2..5c98db0045b 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -7,6 +7,7 @@ except ImportError: pass +import collections import itertools import operator from typing import ( @@ -222,6 +223,7 @@ def _wrapper(func, obj, to_array, args, kwargs): indexes.update({k: template.indexes[k] for k in new_indexes}) graph: Dict[Any, Any] = {} + new_layers = collections.defaultdict(dict) gname = "{}-{}".format( dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs) ) @@ -310,10 +312,14 @@ def _wrapper(func, obj, to_array, args, kwargs): # unchunked dimensions in the input have one chunk in the result key += (0,) - graph[key] = (operator.getitem, from_wrapper, name) + new_layers[gname_l][key] = (operator.getitem, from_wrapper, name) graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) + for gname_l, layer in new_layers.items(): + graph.dependencies[gname_l] = {gname} + graph.layers[gname_l] = layer + result = Dataset(coords=indexes, attrs=template.attrs) for name, gname_l in var_key_map.items(): dims = template[name].dims diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f3b10e3370c..ed6b9d5e4bd 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1189,6 +1189,13 @@ def func(obj): assert_identical(expected.compute(), actual.compute()) +def test_map_blocks_hlg_layers(): + ds = xr.Dataset({'x': (('y',), dask.array.ones(10, chunks=(5,)))}) + mapped = ds.map_blocks(lambda x: x) + + xr.testing.assert_equal(mapped, ds) # does not work + + def test_make_meta(map_ds): from ..core.parallel import make_meta