# Patch summary (commentary only; everything before the "diff --git" header
# is ignored by patch tools).
#
# Target: xarray/core/parallel.py
#
# Hunk 1 (after _get_chunk_slicer, before map_blocks) adds the module-level
# helper _insert_in_memory_data_in_graph(graph, gname, name, variable,
# chunk_index, chunk_bounds). The helper:
#   * builds one slicer per dimension of `variable` via _get_chunk_slicer and
#     applies them with variable.isel();
#   * builds a graph key of the form
#     (f"{name}-{gname}-{dask.base.tokenize(subset)}",) + chunk_tuple, where
#     chunk_tuple is (chunk_index[name],) when `name` is itself a key of
#     chunk_index, and tuple(chunk_index.values()) otherwise;
#   * stores (tuple, [subset.dims, subset._data, subset.attrs]) under that key
#     only if the key is not already in `graph`;
#   * returns the key.
#
# Hunk 2 (inside subset_dataset_to_block) replaces the inline `else:` branch
# that did the same work with a call to the helper, passing
# input_chunk_bounds as `chunk_bounds`.
#
# NOTE(review): the removed code took its fallback from a `chunk_tuple` name
# defined outside this hunk; the helper recomputes tuple(chunk_index.values()).
# The two agree only if the enclosing `chunk_tuple` was built the same way --
# confirm against the base file.
# NOTE(review): indentation of the hunk lines below is assumed from the Python
# nesting implied by the code -- TODO confirm against base blob 08a7577200c
# before applying.
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 08a7577200c..da91d896901 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -144,6 +144,35 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping
     return slice(None)
 
 
+def _insert_in_memory_data_in_graph(
+    graph, gname, name, variable, chunk_index, chunk_bounds
+):
+    import dask
+
+    # non-dask array possibly with dimensions chunked on other variables
+    # index into variable appropriately
+    subsetter = {
+        dim: _get_chunk_slicer(dim, chunk_index, chunk_bounds) for dim in variable.dims
+    }
+    subset = variable.isel(subsetter)
+    if name in chunk_index:
+        # We are including a dimension coordinate,
+        # minimize duplication by not copying it in the graph for every chunk.
+        chunk_tuple = (chunk_index[name],)
+    else:
+        chunk_tuple = tuple(chunk_index.values())
+
+    chunk_variable_task = (
+        f"{name}-{gname}-{dask.base.tokenize(subset)}",
+    ) + chunk_tuple
+    if chunk_variable_task not in graph:
+        graph[chunk_variable_task] = (
+            tuple,
+            [subset.dims, subset._data, subset.attrs],
+        )
+    return chunk_variable_task
+
+
 def map_blocks(
     func: Callable[..., T_Xarray],
     obj: DataArray | Dataset,
@@ -450,28 +479,14 @@ def subset_dataset_to_block(
                     [variable.dims, chunk, variable.attrs],
                 )
             else:
-                # non-dask array possibly with dimensions chunked on other variables
-                # index into variable appropriately
-                subsetter = {
-                    dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
-                    for dim in variable.dims
-                }
-                subset = variable.isel(subsetter)
-                if name in chunk_index:
-                    # We are including a dimension coordinate,
-                    # minimize duplication by not copying it in the graph for every chunk.
-                    this_var_chunk_tuple = (chunk_index[name],)
-                else:
-                    this_var_chunk_tuple = chunk_tuple
-
-                chunk_variable_task = (
-                    f"{name}-{gname}-{dask.base.tokenize(subset)}",
-                ) + this_var_chunk_tuple
-                if chunk_variable_task not in graph:
-                    graph[chunk_variable_task] = (
-                        tuple,
-                        [subset.dims, subset._data, subset.attrs],
-                    )
+                chunk_variable_task = _insert_in_memory_data_in_graph(
+                    graph,
+                    gname,
+                    name,
+                    variable,
+                    chunk_index,
+                    input_chunk_bounds,
+                )
 
             # this task creates dict mapping variable name to above tuple
             if name in dataset._coord_names: