Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce graph size through writing indexes directly into graph for map_blocks #9658

Merged
merged 4 commits on Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class ExpectedDict(TypedDict):
shapes: dict[Hashable, int]
coords: set[Hashable]
data_vars: set[Hashable]
indexes: dict[Hashable, Index]


def unzip(iterable):
Expand Down Expand Up @@ -337,6 +336,7 @@ def _wrapper(
kwargs: dict,
arg_is_array: Iterable[bool],
expected: ExpectedDict,
expected_indexes: dict[Hashable, Index],
):
"""
Wrapper function that receives datasets in args; converts to dataarrays when necessary;
Expand Down Expand Up @@ -372,7 +372,7 @@ def _wrapper(

# ChainMap wants MutableMapping, but xindexes is Mapping
merged_indexes = collections.ChainMap(
expected["indexes"],
expected_indexes,
merged_coordinates.xindexes, # type: ignore[arg-type]
)
expected_index = merged_indexes.get(name, None)
Expand Down Expand Up @@ -412,6 +412,7 @@ def _wrapper(
try:
import dask
import dask.array
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph

except ImportError:
Expand Down Expand Up @@ -551,6 +552,20 @@ def _wrapper(
for isxr, arg in zip(is_xarray, npargs, strict=True)
]

# only include new or modified indexes to minimize duplication of data
indexes = {
(review note: dcherian marked this conversation as resolved — Show resolved / Hide resolved)
dim: coordinates.xindexes[dim][
_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
]
for dim in (new_indexes | modified_indexes)
}

tokenized_indexes: dict[Hashable, str] = {}
for k, v in indexes.items():
tokenized_v = tokenize(v)
graph[f"{k}-coordinate-{tokenized_v}"] = v
tokenized_indexes[k] = f"{k}-coordinate-{tokenized_v}"

# raise nice error messages in _wrapper
expected: ExpectedDict = {
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
Expand All @@ -562,17 +577,18 @@ def _wrapper(
},
"data_vars": set(template.data_vars.keys()),
"coords": set(template.coords.keys()),
# only include new or modified indexes to minimize duplication of data, and graph size.
"indexes": {
dim: coordinates.xindexes[dim][
_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
]
for dim in (new_indexes | modified_indexes)
},
}

from_wrapper = (gname,) + chunk_tuple
graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)
graph[from_wrapper] = (
_wrapper,
func,
blocked_args,
kwargs,
is_array,
expected,
(dict, [[k, v] for k, v in tokenized_indexes.items()]),
)

# mapping from variable name to dask graph key
var_key_map: dict[Hashable, str] = {}
Expand Down
8 changes: 8 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from xarray import DataArray, Dataset, Variable
from xarray.core import duck_array_ops
from xarray.core.duck_array_ops import lazy_array_equiv
from xarray.core.indexes import PandasIndex
from xarray.testing import assert_chunks_equal
from xarray.tests import (
assert_allclose,
Expand Down Expand Up @@ -1375,6 +1376,13 @@ def test_map_blocks_da_ds_with_template(obj):
actual = xr.map_blocks(func, obj, template=template)
assert_identical(actual, template)

# Check that indexes are written into the graph directly
dsk = dict(actual.__dask_graph__())
assert len({k for k in dsk if "x-coordinate" in k})
assert all(
isinstance(v, PandasIndex) for k, v in dsk.items() if "x-coordinate" in k
)

with raise_if_dask_computes():
actual = obj.map_blocks(func, template=template)
assert_identical(actual, template)
Expand Down
Loading