From 84ba7453b0d6f68044e00309c1f785348a13369f Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Tue, 19 Dec 2023 09:32:18 -0700
Subject: [PATCH] reorder

---
 xarray/core/parallel.py | 138 ++++++++++++++++++++--------------------
 1 file changed, 69 insertions(+), 69 deletions(-)

diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 622cb7063ea..24ceb089c89 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -147,6 +147,75 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping
     return slice(None)
 
 
+def subset_dataset_to_block(
+    graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index
+):
+    """
+    Creates a task that subsets an xarray dataset to a block determined by chunk_index.
+    Block extents are determined by input_chunk_bounds.
+    Also subtasks that subset the constituent variables of a dataset.
+    """
+    import dask
+
+    # this will become [[name1, variable1],
+    #                   [name2, variable2],
+    #                    ...]
+    # which is passed to dict and then to Dataset
+    data_vars = []
+    coords = []
+
+    chunk_tuple = tuple(chunk_index.values())
+    chunk_dims_set = set(chunk_index)
+    variable: Variable
+    for name, variable in dataset.variables.items():
+        # make a task that creates tuple of (dims, chunk)
+        if dask.is_dask_collection(variable.data):
+            # get task name for chunk
+            chunk = (
+                variable.data.name,
+                *tuple(chunk_index[dim] for dim in variable.dims),
+            )
+
+            chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple
+            graph[chunk_variable_task] = (
+                tuple,
+                [variable.dims, chunk, variable.attrs],
+            )
+        else:
+            assert name in dataset.dims or variable.ndim == 0
+
+            # non-dask array possibly with dimensions chunked on other variables
+            # index into variable appropriately
+            subsetter = {
+                dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
+                for dim in variable.dims
+            }
+            if set(variable.dims) < chunk_dims_set:
+                this_var_chunk_tuple = tuple(chunk_index[dim] for dim in variable.dims)
+            else:
+                this_var_chunk_tuple = chunk_tuple
+
+            chunk_variable_task = (
+                f"{name}-{gname}-{dask.base.tokenize(subsetter)}",
+            ) + this_var_chunk_tuple
+            # We are including a dimension coordinate,
+            # minimize duplication by not copying it in the graph for every chunk.
+            if variable.ndim == 0 or chunk_variable_task not in graph:
+                subset = variable.isel(subsetter)
+                graph[chunk_variable_task] = (
+                    tuple,
+                    [subset.dims, subset._data, subset.attrs],
+                )
+
+        # this task creates dict mapping variable name to above tuple
+        if name in dataset._coord_names:
+            coords.append([name, chunk_variable_task])
+        else:
+            data_vars.append([name, chunk_variable_task])
+
+    return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
+
+
 def map_blocks(
     func: Callable[..., T_Xarray],
     obj: DataArray | Dataset,
@@ -451,75 +520,6 @@ def _wrapper(
         dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items()
     }
 
-    def subset_dataset_to_block(
-        graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index
-    ):
-        """
-        Creates a task that subsets an xarray dataset to a block determined by chunk_index.
-        Block extents are determined by input_chunk_bounds.
-        Also subtasks that subset the constituent variables of a dataset.
-        """
-
-        # this will become [[name1, variable1],
-        #                   [name2, variable2],
-        #                    ...]
-        # which is passed to dict and then to Dataset
-        data_vars = []
-        coords = []
-
-        chunk_tuple = tuple(chunk_index.values())
-        chunk_dims_set = set(chunk_index)
-        variable: Variable
-        for name, variable in dataset.variables.items():
-            # make a task that creates tuple of (dims, chunk)
-            if dask.is_dask_collection(variable.data):
-                # get task name for chunk
-                chunk = (
-                    variable.data.name,
-                    *tuple(chunk_index[dim] for dim in variable.dims),
-                )
-
-                chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple
-                graph[chunk_variable_task] = (
-                    tuple,
-                    [variable.dims, chunk, variable.attrs],
-                )
-            else:
-                assert name in dataset.dims or variable.ndim == 0
-
-                # non-dask array possibly with dimensions chunked on other variables
-                # index into variable appropriately
-                subsetter = {
-                    dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
-                    for dim in variable.dims
-                }
-                if set(variable.dims) < chunk_dims_set:
-                    this_var_chunk_tuple = tuple(
-                        chunk_index[dim] for dim in variable.dims
-                    )
-                else:
-                    this_var_chunk_tuple = chunk_tuple
-
-                chunk_variable_task = (
-                    f"{name}-{gname}-{dask.base.tokenize(subsetter)}",
-                ) + this_var_chunk_tuple
-                # We are including a dimension coordinate,
-                # minimize duplication by not copying it in the graph for every chunk.
-                if variable.ndim == 0 or chunk_variable_task not in graph:
-                    subset = variable.isel(subsetter)
-                    graph[chunk_variable_task] = (
-                        tuple,
-                        [subset.dims, subset._data, subset.attrs],
-                    )
-
-            # this task creates dict mapping variable name to above tuple
-            if name in dataset._coord_names:
-                coords.append([name, chunk_variable_task])
-            else:
-                data_vars.append([name, chunk_variable_task])
-
-        return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
-
     include_variables = set(template.variables) - set(coordinates.indexes)
     # iterate over all possible chunk combinations
     for chunk_tuple in itertools.product(*ichunk.values()):
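
A minimal sketch of how the now module-level helper could be exercised directly, assuming the patch above is applied. subset_dataset_to_block and _get_chunk_slicer are internal xarray API whose signatures are taken from the hunks above; the example dataset, the graph name "example-gname", and the chunk_index values are invented for illustration.

    import dask.array as da
    import numpy as np
    import xarray as xr

    # Internal import; importable at module level only once this patch is applied.
    from xarray.core.parallel import subset_dataset_to_block

    # Small example dataset: one dask-backed variable chunked along both dims,
    # plus an eager (NumPy) dimension coordinate.
    ds = xr.Dataset(
        {"a": (("x", "y"), da.ones((4, 6), chunks=(2, 3)))},
        coords={"x": np.arange(4)},
    )

    # Cumulative chunk boundaries per dimension, mirroring how map_blocks builds
    # its *_chunk_bounds mappings (np.cumsum over (0,) + chunks).
    input_chunk_bounds = {
        dim: np.cumsum((0,) + chunks) for dim, chunks in ds.chunks.items()
    }

    graph: dict = {}
    chunk_index = {"x": 0, "y": 1}  # the block in chunk-row 0, chunk-column 1

    task = subset_dataset_to_block(
        graph, "example-gname", ds, input_chunk_bounds, chunk_index
    )
    # `task` is the (Dataset, (dict, data_vars), (dict, coords), attrs) tuple the
    # function returns, and `graph` now holds one subsetting task per variable,
    # keyed by tuples such as ("a-example-gname-...", 0, 1).

Defining the helper at module level, instead of re-creating it as a closure on every map_blocks call, is what makes a direct call like this possible; presumably that is the point of the reorder.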