Commit: typing
dcherian committed Nov 4, 2024
1 parent 492ffee commit bbed44f
Showing 2 changed files with 90 additions and 36 deletions.
83 changes: 65 additions & 18 deletions icechunk-python/python/icechunk/dask.py
@@ -1,58 +1,103 @@
import itertools
from typing import Any, Callable, Sequence

from toolz import partition_all

from collections.abc import MutableMapping
from icechunk import IcechunkStore

from typing import (
Any,
Callable,
Literal,
Iterable,
Mapping,
Sequence,
Sized,
TYPE_CHECKING,
overload,
TypeAlias,
)
from dask.typing import Graph
from dask import config
from dask.array.core import Array
from dask.base import compute_as_if_collection, tokenize
from dask.core import flatten
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph

SimpleGraph: TypeAlias = Mapping[tuple[str, int], tuple[Any, ...]]


# tree-reduce all changesets, regardless of array
def partial_reduce(
aggregate: Callable,
keys: Sequence[tuple[Any, ...]],
aggregate: Callable[..., Any],
keys: Iterable[tuple[Any, ...]],
*,
layer_name: str,
split_every: int,
):
) -> SimpleGraph:
"""
Creates a new dask graph layer that aggregates `split_every` keys together.
"""
from toolz import partition_all

return {
(layer_name, i): (aggregate, *keys_batch)
for i, keys_batch in enumerate(partition_all(split_every, keys))
}


@overload
def stateful_store_reduce(
stored_arrays: Sequence[Array],
*,
chunk: Callable,
aggregate: Callable,
chunk: Callable[..., Any],
aggregate: Callable[..., Any],
prefix: str | None = None,
split_every: int | None = None,
compute: Literal[False] = False,
**kwargs: Any,
) -> Delayed: ...


@overload
def stateful_store_reduce(
stored_arrays: Sequence[Array],
*,
chunk: Callable[..., Any],
aggregate: Callable[..., Any],
compute: Literal[True] = True,
prefix: str | None = None,
split_every: int | None = None,
**kwargs: Any,
) -> IcechunkStore: ...


def stateful_store_reduce(
stored_arrays: Sequence[Array],
*,
chunk: Callable[..., Any],
aggregate: Callable[..., Any],
compute: bool = True,
**kwargs,
):
prefix: str | None = None,
split_every: int | None = None,
**kwargs: Any,
) -> IcechunkStore | Delayed:
split_every = split_every or config.get("split_every", 8)

layers = {}
dependencies = {}
layers: MutableMapping[str, SimpleGraph] = {}
dependencies: MutableMapping[str, set[str]] = {}

array_names = tuple(a.name for a in stored_arrays)
all_array_keys = list(
itertools.chain(*[flatten(array.__dask_keys__()) for array in stored_arrays])
# flatten is untyped.
itertools.chain(*[flatten(array.__dask_keys__()) for array in stored_arrays]) # type: ignore[no-untyped-call]
)
token = tokenize(array_names, chunk, aggregate, split_every)

# Each write task returns one Zarr array,
# now extract the changeset (as bytes) from each of those Zarr arrays
map_layer_name = f"{prefix}-blockwise-{token}"
map_dsk = {(map_layer_name, i): (chunk, key) for i, key in enumerate(all_array_keys)}
map_dsk: SimpleGraph = {
(map_layer_name, i): (chunk, key) for i, key in enumerate(all_array_keys)
}
layers[map_layer_name] = map_dsk
dependencies[map_layer_name] = set(array_names)
latest_layer = map_layer_name
@@ -85,14 +130,16 @@ def stateful_store_reduce(
dependencies[latest_layer] = {previous_layer}

store_dsk = HighLevelGraph.merge(
HighLevelGraph(layers, dependencies),
HighLevelGraph(layers, dependencies), # type: ignore[arg-type]
*[array.__dask_graph__() for array in stored_arrays],
)
if compute:
# copied from dask.array.store
merged_store, *_ = compute_as_if_collection(
merged_store, *_ = compute_as_if_collection( # type: ignore[no-untyped-call]
Array, store_dsk, list(layers[latest_layer].keys()), **kwargs
)
if TYPE_CHECKING:
assert isinstance(merged_store, IcechunkStore)
return merged_store

else:
@@ -101,4 +148,4 @@ def stateful_store_reduce(
HighLevelGraph({key: {key: (latest_layer, 0)}}, {key: {latest_layer}}),
store_dsk,
)
return Delayed(key, store_dsk)
return Delayed(key, store_dsk) # type: ignore[no-untyped-call]
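
For orientation while reading the diff above: partial_reduce and stateful_store_reduce together build a multi-level tree reduction over the per-chunk changesets, batching split_every keys into each aggregate task. Below is a minimal sketch of that reduction in plain Python; the combine callable and the integer payloads are illustrative stand-ins for merge_stores and the actual changesets, and the committed code builds dask graph layers rather than looping eagerly.

from toolz import partition_all

def tree_reduce(items, combine, split_every=8):
    # Repeatedly merge `split_every` items at a time until one remains,
    # mirroring the layered graph that stateful_store_reduce constructs.
    while len(items) > 1:
        items = [combine(*batch) for batch in partition_all(split_every, items)]
    return items[0]

# 20 changesets with split_every=4 reduce as 20 -> 5 -> 2 -> 1,
# i.e. three rounds of partial_reduce in the graph-based version.
print(tree_reduce(list(range(20)), lambda *xs: sum(xs), split_every=4))  # 190
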
43 changes: 25 additions & 18 deletions icechunk-python/python/icechunk/xarray.py
@@ -1,18 +1,21 @@
#!/usr/bin/env python3
from collections.abc import Hashable, Mapping, MutableMapping
from typing import Any, Literal, Union
from typing import Any, Literal, cast


import numpy as np
import zarr

from icechunk import IcechunkStore
from .dask import stateful_store_reduce

from xarray import Dataset
from xarray.core.types import ZarrWriteModes # TODO: delete this
from xarray.backends.zarr import ZarrStore
from xarray.backends.common import ArrayWriter
from dataclasses import dataclass, field

ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]

# TODO: import-time check on Xarray version
#
try:
@@ -24,7 +27,7 @@


def extract_stores(zarray: zarr.Array) -> IcechunkStore:
return zarray.store
return cast(IcechunkStore, zarray.store)


def merge_stores(*stores: IcechunkStore) -> IcechunkStore:
@@ -45,13 +48,13 @@ def is_chunked_array(x: Any) -> bool:

class LazyArrayWriter(ArrayWriter):
def __init__(self) -> None:
super().__init__()
super().__init__() # type: ignore[no-untyped-call]

self.eager_sources = []
self.eager_targets = []
self.eager_regions = []
self.eager_sources: list[np.ndarray[Any, Any]] = []
self.eager_targets: list[zarr.Array] = []
self.eager_regions: list[tuple[slice, ...]] = []

def add(self, source, target, region=None):
def add(self, source: Any, target: Any, region: Any = None) -> Any:
if is_chunked_array(source):
self.sources.append(source)
self.targets.append(target)
@@ -96,7 +99,7 @@ def write_metadata(
*,
group: str | None = None,
mode: ZarrWriteModes | None = None,
encoding: Mapping | None = None,
encoding: Mapping[Any, Any] | None = None,
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
write_empty_chunks: bool | None = None,
@@ -192,12 +195,14 @@ def write_metadata(
# validate Dataset keys, DataArray names
_validate_dataset_names(self.dataset)

self.mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region)
concrete_mode: ZarrWriteModes = _choose_default_mode(
mode=mode, append_dim=append_dim, region=region
)

self.xarray_store = ZarrStore.open_group(
store=self.store,
group=group,
mode=mode,
mode=concrete_mode,
zarr_format=3,
append_dim=append_dim,
write_region=region,
@@ -217,11 +222,11 @@ def write_metadata(

# This writes the metadata (zarr.json) for all arrays
self.writer = LazyArrayWriter()
dump_to_store(dataset, self.xarray_store, self.writer, encoding=encoding)
dump_to_store(dataset, self.xarray_store, self.writer, encoding=encoding) # type: ignore[no-untyped-call]

self._initialized = True

def write_eager(self):
def write_eager(self) -> None:
"""
Write in-memory variables to store.
@@ -235,7 +240,7 @@ def write_eager(self):

def write_lazy(
self,
chunkmanager_store_kwargs: MutableMapping | None = None,
chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
split_every: int | None = None,
) -> None:
"""
@@ -255,7 +260,7 @@ def write_lazy(
# each of those zarr.Array.store contains the changesets we need
stored_arrays = self.writer.sync(
compute=False, chunkmanager_store_kwargs=chunkmanager_store_kwargs
)
) # type: ignore[no-untyped-call]

# Now we tree-reduce all changesets
merged_store = stateful_store_reduce(
@@ -280,10 +285,10 @@ def to_icechunk(
safe_chunks: bool = True,
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
encoding: Mapping | None = None,
chunkmanager_store_kwargs: MutableMapping | None = None,
encoding: Mapping[Any, Any] | None = None,
chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
split_every: int | None = None,
**kwargs,
**kwargs: Any,
) -> None:
"""
Write an Xarray Dataset to a group of an icechunk store.
@@ -358,6 +363,8 @@ def to_icechunk(
Additional keyword arguments passed on to the `ChunkManager.store` method used to store
chunked arrays. For example for a dask array additional kwargs will be passed eventually to
`dask.array.store()`. Experimental API that should not be relied upon.
split_every: int, optional
Number of tasks to merge at every level of the tree reduction.
Returns
-------
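
A hedged usage sketch of the new split_every keyword on to_icechunk follows. The dataset construction, the helper function, and the positional (dataset, store) call shape are assumptions for illustration, since the diff truncates the full signature; only the split_every parameter itself comes from this commit.

import dask.array as da
import xarray as xr

from icechunk import IcechunkStore
from icechunk.xarray import to_icechunk

def write_example(store: IcechunkStore) -> None:
    # A small dask-backed Dataset: chunked variables go through write_lazy,
    # so their changesets are tree-reduced split_every tasks at a time,
    # while any in-memory variables would be written by write_eager.
    ds = xr.Dataset({"air": (("time", "x"), da.zeros((10, 4), chunks=(5, 4)))})
    # `store` is an already-open IcechunkStore; creating one is out of scope
    # for this diff, so it is left as a parameter here.
    to_icechunk(ds, store, split_every=4)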
