Skip to content

Commit

Permalink
Merge pull request #27 from tomwhite/rechunk-dask-array
Browse files: browse the repository at this point in the history
Allow rechunk to accept a Dask array
  • Loading branch information
rabernat authored Jul 21, 2020
2 parents fdddf0f + 93f9b59 commit 23a3492
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ v0.1 - Unreleased
-----------------

- Documentation update and tutorial.
- Allow rechunk to accept a Dask array.


v0.0.1 - 2020-07-15
Expand Down
37 changes: 27 additions & 10 deletions rechunker/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ def rechunk(
source, target_chunks, max_mem, target_store, temp_store=None,
):
"""
Rechunk a Zarr Array or Group
Rechunk a Zarr Array or Group, or a Dask Array
Parameters
----------
source : zarr.Array or zarr.Group
source : zarr.Array, zarr.Group, or dask.array.Array
Named dimensions in the Arrays will be parsed according to the
Xarray :ref:`xarray:zarr_encoding`.
target_chunks : tuple, dict, or None
Expand Down Expand Up @@ -260,7 +260,7 @@ def rechunk(

return rechunked

elif isinstance(source, zarr.core.Array):
elif isinstance(source, zarr.core.Array) or isinstance(source, dask.array.Array):
return _rechunk_array(
source,
target_chunks,
Expand All @@ -271,7 +271,7 @@ def rechunk(
)

else:
raise ValueError("Source must be a Zarr Array or Group.")
raise ValueError("Source must be a Zarr Array or Group, or a Dask Array.")


def _rechunk_array(
Expand All @@ -287,7 +287,11 @@ def _rechunk_array(
):

shape = source_array.shape
source_chunks = source_array.chunks
source_chunks = (
source_array.chunksize
if isinstance(source_array, dask.array.Array)
else source_array.chunks
)
dtype = source_array.dtype
itemsize = dtype.itemsize

Expand All @@ -307,13 +311,23 @@ def _rechunk_array(

max_mem = dask.utils.parse_bytes(max_mem)

# don't consolidate reads for Dask arrays
consolidate_reads = isinstance(source_array, zarr.core.Array)
read_chunks, int_chunks, write_chunks = rechunking_plan(
shape, source_chunks, target_chunks, itemsize, max_mem
shape,
source_chunks,
target_chunks,
itemsize,
max_mem,
consolidate_reads=consolidate_reads,
)

source_read = dsa.from_zarr(
source_array, chunks=read_chunks, storage_options=source_storage_options
)
if isinstance(source_array, dask.array.Array):
source_read = source_array
else:
source_read = dsa.from_zarr(
source_array, chunks=read_chunks, storage_options=source_storage_options
)

# create target
shape = tuple(int(x) for x in shape) # ensure python ints for serialization
Expand All @@ -324,7 +338,10 @@ def _rechunk_array(
target_array = _zarr_empty(
shape, target_store_or_group, target_chunks, dtype, name=name
)
target_array.attrs.update(source_array.attrs)
try:
target_array.attrs.update(source_array.attrs)
except AttributeError:
pass

if read_chunks == write_chunks:
target_store_delayed = dsa.store(
Expand Down
33 changes: 33 additions & 0 deletions tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,39 @@ def test_rechunk_array(
assert dsa.equal(a_tar, 1).all().compute()


@pytest.mark.parametrize("shape", [(8000, 8000)])
@pytest.mark.parametrize("source_chunks", [(200, 8000), (800, 8000)])
@pytest.mark.parametrize("dtype", ["f4"])
@pytest.mark.parametrize("max_mem", [25600000])
@pytest.mark.parametrize(
    "target_chunks",
    [
        (200, 8000),
        (800, 8000),
        (8000, 200),
        (400, 8000),
    ],
)
def test_rechunk_dask_array(
    tmp_path, shape, source_chunks, dtype, target_chunks, max_mem
):
    """Check that ``api.rechunk`` accepts a Dask array as its source.

    A Dask array of ones is built with the parametrized shape, chunking
    and dtype, then rechunked into a store under ``tmp_path``.  The test
    verifies the type of the returned plan, the chunk shape of the target
    store, the type of the executed result, and that the data written to
    the target is all ones.
    """
    # Store locations are handed to the API as plain strings.
    target_store = str(tmp_path / "target.zarr")
    temp_store = str(tmp_path / "temp.zarr")

    # Source lives only as a Dask graph; nothing is written for it.
    ones = dsa.ones(shape, chunks=source_chunks, dtype=dtype)

    plan = api.rechunk(
        ones, target_chunks, max_mem, target_store, temp_store=temp_store
    )
    assert isinstance(plan, api.Rechunked)

    # The target is opened, and its chunk shape checked, before the plan
    # is executed.
    opened_target = zarr.open(target_store)
    assert opened_target.chunks == tuple(target_chunks)

    outcome = plan.execute()
    assert isinstance(outcome, zarr.Array)

    # Read the target back through Dask and confirm every element is 1.
    written = dsa.from_zarr(opened_target)
    all_ones = dsa.equal(written, 1).all().compute()
    assert all_ones


def test_rechunk_group(tmp_path):
store_source = str(tmp_path / "source.zarr")
group = zarr.group(store_source)
Expand Down

0 comments on commit 23a3492

Please sign in to comment.