From cafcaeea897894e3a2f44a38bd33c50a48c86215 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 6 Dec 2019 22:30:18 -0600 Subject: [PATCH] Fix map_blocks HLG layering (#3598) * Fix map_blocks HLG layering This fixes an issue with the HighLevelGraph noted in https://github.com/pydata/xarray/pull/3584, and exposed by a recent change in Dask to do more HLG fusion. * update * black * update --- doc/whats-new.rst | 2 ++ xarray/core/parallel.py | 24 +++++++++++++++++++++--- xarray/tests/test_dask.py | 13 +++++++++++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 96e5eeacf95..554f0bc4695 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,8 @@ Bug fixes ~~~~~~~~~ - Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`, :pull:`3441`) By `Deepak Cherian `_. +- Fix issue with Dask-backed datasets raising a ``KeyError`` on some computations involving ``map_blocks`` (:pull:`3598`) + By `Tom Augspurger `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index fbb5ef94ca2..dd6c67338d8 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -7,12 +7,14 @@ except ImportError: pass +import collections import itertools import operator from typing import ( Any, Callable, Dict, + DefaultDict, Hashable, Mapping, Sequence, @@ -221,7 +223,12 @@ def _wrapper(func, obj, to_array, args, kwargs): indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} indexes.update({k: template.indexes[k] for k in new_indexes}) + # We're building a new HighLevelGraph hlg. We'll have one new layer + # for each variable in the dataset, which is the result of the + # func applied to the values. + graph: Dict[Any, Any] = {} + new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict) gname = "{}-{}".format( dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs) ) @@ -310,9 +317,20 @@ def _wrapper(func, obj, to_array, args, kwargs): # unchunked dimensions in the input have one chunk in the result key += (0,) - graph[key] = (operator.getitem, from_wrapper, name) + # We're adding multiple new layers to the graph: + # The first new layer is the result of the computation on + # the array. + # Then we add one layer per variable, which extracts the + # result for that variable, and depends on just the first new + # layer. + new_layers[gname_l][key] = (operator.getitem, from_wrapper, name) + + hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) - graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) + for gname_l, layer in new_layers.items(): + # This adds in the getitems for each variable in the dataset. + hlg.dependencies[gname_l] = {gname} + hlg.layers[gname_l] = layer result = Dataset(coords=indexes, attrs=template.attrs) for name, gname_l in var_key_map.items(): @@ -325,7 +343,7 @@ def _wrapper(func, obj, to_array, args, kwargs): var_chunks.append((len(indexes[dim]),)) data = dask.array.Array( - graph, name=gname_l, chunks=var_chunks, dtype=template[name].dtype + hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype ) result[name] = (dims, data, template[name].attrs) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f3b10e3370c..6122e987154 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1189,6 +1189,19 @@ def func(obj): assert_identical(expected.compute(), actual.compute()) +def test_map_blocks_hlg_layers(): + # regression test for #3599 + ds = xr.Dataset( + { + "x": (("a",), dask.array.ones(10, chunks=(5,))), + "z": (("b",), dask.array.ones(10, chunks=(5,))), + } + ) + mapped = ds.map_blocks(lambda x: x) + + xr.testing.assert_equal(mapped, ds) + + def test_make_meta(map_ds): from ..core.parallel import make_meta