Skip to content
forked from pydata/xarray

Commit

Permalink
Refactor inserting of in memory data
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Nov 3, 2023
1 parent b2e644c commit f6557f7
Showing 1 changed file with 37 additions and 22 deletions.
59 changes: 37 additions & 22 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,35 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping
return slice(None)


def _insert_in_memory_data_in_graph(
    graph, gname, name, variable, chunk_index, chunk_bounds
):
    """Add a task holding one block of an in-memory variable to ``graph``.

    ``variable`` is not backed by a dask array, but some of its dimensions
    may be chunked on other variables. It is therefore sliced down to the
    block selected by ``chunk_index`` before being stored in the graph.

    Parameters
    ----------
    graph : MutableMapping
        Task graph being built; updated in place when the key is new.
    gname : str
        Name component embedded in the generated task key.
    name : Hashable
        Name of the variable being inserted.
    variable : Variable
        The in-memory variable to subset.
    chunk_index : Mapping
        Maps dimension name to the index of the current chunk along it.
    chunk_bounds : Mapping
        Chunk boundary information forwarded to ``_get_chunk_slicer``.

    Returns
    -------
    tuple
        The graph key of the task. The task evaluates to
        ``(dims, data, attrs)`` for the subsetted variable.
    """
    import dask

    # Build one indexer per dimension of the variable, then take the block.
    indexers = {}
    for dim in variable.dims:
        indexers[dim] = _get_chunk_slicer(dim, chunk_index, chunk_bounds)
    block = variable.isel(indexers)

    # A dimension coordinate only varies along its own dimension, so key it
    # on that single chunk index; this avoids storing a duplicate copy in
    # the graph for every chunk. Anything else is keyed on all indices.
    key_suffix = (
        (chunk_index[name],) if name in chunk_index else tuple(chunk_index.values())
    )

    key = (f"{name}-{gname}-{dask.base.tokenize(block)}", *key_suffix)
    if key not in graph:
        graph[key] = (tuple, [block.dims, block._data, block.attrs])
    return key


def map_blocks(
func: Callable[..., T_Xarray],
obj: DataArray | Dataset,
Expand Down Expand Up @@ -450,28 +479,14 @@ def subset_dataset_to_block(
[variable.dims, chunk, variable.attrs],
)
else:
# non-dask array possibly with dimensions chunked on other variables
# index into variable appropriately
subsetter = {
dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
for dim in variable.dims
}
subset = variable.isel(subsetter)
if name in chunk_index:
# We are including a dimension coordinate,
# minimize duplication by not copying it in the graph for every chunk.
this_var_chunk_tuple = (chunk_index[name],)
else:
this_var_chunk_tuple = chunk_tuple

chunk_variable_task = (
f"{name}-{gname}-{dask.base.tokenize(subset)}",
) + this_var_chunk_tuple
if chunk_variable_task not in graph:
graph[chunk_variable_task] = (
tuple,
[subset.dims, subset._data, subset.attrs],
)
chunk_variable_task = _insert_in_memory_data_in_graph(
graph,
gname,
name,
variable,
chunk_index,
input_chunk_bounds,
)

# this task creates dict mapping variable name to above tuple
if name in dataset._coord_names:
Expand Down

0 comments on commit f6557f7

Please sign in to comment.