Skip to content

Commit

Permalink
pydata#8062 Dataset.chunk() and DataArray.chunk() now correctly set e…
Browse files Browse the repository at this point in the history
…ncoding attribute
  • Loading branch information
Metamess committed Aug 14, 2023
1 parent eceec5f commit abbbdf8
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
3 changes: 2 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2594,7 +2594,7 @@ def chunk(
already as dask array.
chunked_array_type: str, optional
Which chunked array type to coerce this datasets' arrays to.
Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system.
Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntryPoint` system.
Experimental API that should not be relied upon.
from_array_kwargs: dict, optional
Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create
Expand Down Expand Up @@ -2647,6 +2647,7 @@ def chunk(
token,
lock,
name_prefix,
overwrite_encoded_chunks=True,
inline_array=inline_array,
chunked_array_type=chunkmanager,
from_array_kwargs=from_array_kwargs.copy(),
Expand Down
3 changes: 3 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,14 +851,17 @@ def test_chunk(self) -> None:

blocked = unblocked.chunk()
assert blocked.chunks == ((3,), (4,))
assert blocked.encoding.get("chunks", None) == (3, 4)
first_dask_name = blocked.data.name

blocked = unblocked.chunk(chunks=((2, 1), (2, 2)))
assert blocked.chunks == ((2, 1), (2, 2))
assert blocked.encoding.get("chunks", None) == (2, 2)
assert blocked.data.name != first_dask_name

blocked = unblocked.chunk(chunks=(3, 3))
assert blocked.chunks == ((3,), (3, 1))
assert blocked.encoding.get("chunks", None) == (3, 3)
assert blocked.data.name != first_dask_name

# name doesn't change when rechunking by same amount
Expand Down
11 changes: 9 additions & 2 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,17 +1111,24 @@ def test_chunk(self) -> None:

# test kwargs form of chunks
assert data.chunk(expected_chunks).chunks == expected_chunks
# Verify the encoding attributes have been set
for da_reblocked in reblocked.values():
assert da_reblocked.encoding.get("chunks", None) == da_reblocked.shape

def get_dask_names(ds):
return {k: v.data.name for k, v in ds.items()}

orig_dask_names = get_dask_names(reblocked)

reblocked = data.chunk({"time": 5, "dim1": 5, "dim2": 5, "dim3": 5})
desired_chunks = {"time": 6, "dim1": 5, "dim2": 4, "dim3": 3}
reblocked = data.chunk(desired_chunks)
# time is not a dim in any of the data_vars, so it
# doesn't get chunked
expected_chunks = {"dim1": (5, 3), "dim2": (5, 4), "dim3": (5, 5)}
expected_chunks = {"dim1": (5, 3), "dim2": (4, 4, 1), "dim3": (3, 3, 3, 1)}
assert reblocked.chunks == expected_chunks
# Verify the encoding attributes have been set
for da_reblocked in reblocked.values():
assert da_reblocked.encoding.get("chunks", None) == tuple(desired_chunks[d] for d in da_reblocked.dims)

# make sure dask names change when rechunking by different amounts
# regression test for GH3350
Expand Down

0 comments on commit abbbdf8

Please sign in to comment.