Generalize handling of chunked array types (pydata#7019)
* generalise chunk methods to allow cubed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typing typo

* fixed circular import

* fix some mypy errors

* added cubed to mypy ignore list

* simplify __array_ufunc__ check

* Revert "simplify __array_ufunc__ check" as I pushed to wrong branch

This reverts commit cdcb3fb.

* update cubed array type

* fix missed conflict

* sketch for ChunkManager adapter class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove erroneous docstring about usage of map_blocks

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* apply_ufunc -> apply_gufunc

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* chunk -> from_array

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* remove staticmethods

* attempt to type methods of ABC

* from_array

* attempt to specify types

* method for checking array type

* Update pyproject.toml

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed import errors

* generalize .chunk method kwargs

* used dask functions in dask chunkmanager

* define signatures for apply_gufunc, blockwise, map_blocks

* prototype function to detect which parallel backend to use

* add cubed.apply_gufunc

* ruffify

* add rechunk and compute methods for cubed

* xr.apply_ufunc now dispatches to chunkmanager.apply_gufunc

* CubedManager.chunks

* attempt to keep dask and cubed imports lazy

* generalize idxmax

* move unify_chunks import to ChunkManager

* generalize Dataset.load()

* check explicitly for chunks attribute instead of hard-coding cubed

* better function names

* add cubed version of unify_chunks

* recognize wrapped duck dask arrays (e.g. pint wrapping dask)

* add some tests for fetching ChunkManagers

* add from_array_kwargs to open_dataset

* add from_array_kwargs to open_zarr

* pipe constructors through chunkmanager

* generalize map_blocks inside coding

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed full_like

* add from_array_kwargs to open_zarr

* don't import dask.tokenize

* fix bugs with passing from_array_kwargs down

* generalise reductions by adding to chunkmanager

* moved nanfirst/nanlast to duck_array_ops from dask_array_ops

* generalize interp

* generalized chunk_hint function inside indexing

* DaskIndexingAdapter->ChunkedIndexingAdapter

* Revert "DaskIndexingAdapter->ChunkedIndexingAdapter"

This reverts commit 4ca044b.

* pass cubed-related kwargs down through to_zarr by adding .store to ChunkManager

* fix typing_extensions on py3.9

* fix ImportError with cubed array type

* give up trying to import TypeAlias in CI

* fix import of T_Chunks

* fix no_implicit_optional warnings

* don't define CubedManager if cubed can't be imported

* fix local mypy errors

* don't explicitly pass enforce_ndim into dask.array.map_blocks

* fix drop_axis default

* use indexing adapter on cubed arrays too

* use array API-compatible version of astype function

* whatsnew

* document new kwargs

* add chunkmanager entrypoint

* move CubedManager to a separate package

* guess chunkmanager based on what's available

* fix bug with tokenizing

* adapt tests to emulate existence of entrypoint

* use fixture to setup/teardown dummy entrypoint

* refactor to make DaskManager unavailable if dask not installed

* typing

* move whatsnew to latest xarray version

* remove superfluous lines from whatsnew

* fix bug where zarr backend attempted to use dask when not installed

* Remove rogue print statement

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* Clarify what's new

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* use monkeypatch to mock registering of dummy chunkmanager

* more tests for guessing chunkmanager correctly

* raise TypeError if no chunkmanager found for array types

* Correct is_chunked_array check

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* vendor dask.array.core.normalize_chunks

* add default implementation of rechunk in ABC

* remove cubed-specific type check in daskmanager

* nanfirst->chunked_nanfirst

* revert adding cubed to NON_NUMPY_SUPPORTED_ARRAY_TYPES

* licensing to vendor functions from dask

* fix bug

* ignore mypy error

* separate chunk_manager kwarg from from_array_kwargs dict

* rename kwarg to chunked_array_type

* refactor from_array_kwargs in .chunk ready for deprecation

* print statements in test so I can comment on them

* remove print statements now I've commented on them in PR

* should fix dask naming tests

* make dask-specific kwargs explicit in from_array

* debugging print statements

* Revert "debugging print statements"

This reverts commit 7dc6581.

* fix gnarly bug with auto-determining chunksizes caused by not referring to dask.config

* hopefully fix broken docstring

* Revert "make dask-specific kwargs explicit in from_array"

This reverts commit 53d6094.

* show chunksize limit used in failing tests

* move lazy indexing adapter up out of chunkmanager code

* try upgrading minimum version of dask

* Revert "try upgrading minimum version of dask"

This reverts commit 796a577.

* un-vendor dask.array.core.normalize_chunks

* refactored to allow passing ChunkManagerEntrypoint objects directly

* Remove redundant Nones from types

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* From future import annotations

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* From functools import annotations

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* From future import annotations

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* defined type for NormalizedChunks

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* standardized capitalization of ChunkManagerEntrypoint

* ensure ruff doesn't remove import

* ignore remaining typing errors stemming from unclear dask typing for chunks arguments

* rename store_kwargs->chunkmanager_store_kwargs

* missed return value

* array API fixes for astype

* Revert "array API fixes for astype"

This reverts commit 9cd9078.

* Apply suggestions from code review

* Update xarray/tests/test_parallelcompat.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* overridden -> subclassed

Co-authored-by: Justus Magin <keewis@users.noreply.github.com>

* from_array_kwargs is optional

Co-authored-by: Justus Magin <keewis@users.noreply.github.com>

* ensured all compute calls go through chunkmanager

* Raise if multiple chunkmanagers recognize array type

Co-authored-by: Justus Magin <keewis@users.noreply.github.com>

* from_array_kwargs is optional

Co-authored-by: Justus Magin <keewis@users.noreply.github.com>

* from_array_kwargs is optional

Co-authored-by: Justus Magin <keewis@users.noreply.github.com>

* from_array_kwargs is optional

Co-authored-by: Justus Magin <keewis@users.noreply.github.com>

* from_array_kwargs is optional

Co-authored-by: Justus Magin <keewis@users.noreply.github.com>

* from_array_kwargs is optional

Co-authored-by: Justus Magin <keewis@users.noreply.github.com>

* fixes for chunk methods

* correct readme to reflect the fact that we aren't vendoring dask in this PR any more

* update whatsnew

* more docstring corrections

* remove comment

* Raise NotImplementedErrors in all abstract methods

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* type hints for every arg in ChunkManagerEntrypoint methods

* more explicit typing + fixes for mypy errors revealed

* Keyword-only arguments in full_like etc.

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* None as default instead of {}

* fix bug apparently introduced by changing default type of drop_axis kwarg to map_blocks

* Removed hopefully-unnecessary mypy ignore

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* removed unnecessary mypy ignores

* change default value of drop_axis kwarg in map_blocks and catch when dask version < 2022.9.1

* fix checking of dask version in map_blocks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
Co-authored-by: Justus Magin <keewis@users.noreply.github.com>
5 people authored and dstansby committed Jun 28, 2023
1 parent 3b20610 commit 030e5c1
Showing 29 changed files with 1,406 additions and 209 deletions.
5 changes: 5 additions & 0 deletions doc/whats-new.rst
@@ -58,6 +58,11 @@ Documentation
Internal Changes
~~~~~~~~~~~~~~~~

- Experimental support for wrapping chunked array libraries other than dask.
A new ABC is defined - :py:class:`xr.core.parallelcompat.ChunkManagerEntrypoint` - which can be subclassed and then
registered by alternative chunked array implementations. (:issue:`6807`, :pull:`7019`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2023.04.2:

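For orientation, here is a minimal sketch of what subclassing the new ABC could look like, pieced together from the method names in the commit log above (from_array, rechunk, compute, apply_gufunc, ...). The real signatures live in xarray/core/parallelcompat.py and may differ; my_chunked_lib and MyArray are hypothetical names:

    from typing import Any

    from xarray.core.parallelcompat import ChunkManagerEntrypoint


    class MyChunkManager(ChunkManagerEntrypoint):
        """Sketch of a manager for a hypothetical chunked array library."""

        def __init__(self) -> None:
            from my_chunked_lib import MyArray  # hypothetical third-party type

            # array_cls is what xarray uses to recognize this chunked type
            self.array_cls = MyArray

        def chunks(self, data) -> tuple[tuple[int, ...], ...]:
            return data.chunks

        def from_array(self, data, chunks, **kwargs) -> Any:
            import my_chunked_lib

            return my_chunked_lib.from_array(data, chunks, **kwargs)

        def compute(self, *data, **kwargs) -> tuple[Any, ...]:
            import my_chunked_lib

            return my_chunked_lib.compute(*data, **kwargs)

        def apply_gufunc(self, func, signature, *args, **kwargs) -> Any:
            import my_chunked_lib

            return my_chunked_lib.apply_gufunc(func, signature, *args, **kwargs)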
1 change: 1 addition & 0 deletions pyproject.toml
@@ -39,6 +39,7 @@ module = [
"cf_units.*",
"cfgrib.*",
"cftime.*",
"cubed.*",
"cupy.*",
"fsspec.*",
"h5netcdf.*",
4 changes: 4 additions & 0 deletions setup.cfg
@@ -132,6 +132,10 @@ xarray =
static/css/*
static/html/*

[options.entry_points]
xarray.chunkmanagers =
dask = xarray.core.daskmanager:DaskManager

[tool:pytest]
python_files = test_*.py
testpaths = xarray/tests properties
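The xarray.chunkmanagers entry-point group declared here is how managers become discoverable, and is also what the dummy-entrypoint test fixtures mentioned in the commit log emulate. As a sketch, discovery on the consuming side plausibly reduces to standard importlib.metadata machinery (the group= keyword is the Python 3.10+ API; on 3.8/3.9 you would filter the dict that entry_points() returns):

    from importlib.metadata import entry_points

    # Enumerate chunk managers registered under the group declared in setup.cfg.
    for ep in entry_points(group="xarray.chunkmanagers"):
        print(ep.name, "->", ep.value)
    # With only the built-in registration, this prints:
    # dask -> xarray.core.daskmanager:DaskManager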
85 changes: 73 additions & 12 deletions xarray/backends/api.py
@@ -6,7 +6,16 @@
from glob import glob
from io import BytesIO
from numbers import Number
from typing import TYPE_CHECKING, Any, Callable, Final, Literal, Union, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
Final,
Literal,
Union,
cast,
overload,
)

import numpy as np

@@ -20,9 +29,11 @@
_nested_combine,
combine_by_coords,
)
from xarray.core.daskmanager import DaskManager
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk
from xarray.core.indexes import Index
from xarray.core.parallelcompat import guess_chunkmanager
from xarray.core.utils import is_remote_uri

if TYPE_CHECKING:
@@ -38,6 +49,7 @@
CompatOptions,
JoinOptions,
NestedSequence,
T_Chunks,
)

T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"]
@@ -48,7 +60,6 @@
str, # no nice typing support for custom backends
None,
]
T_Chunks = Union[int, dict[Any, Any], Literal["auto"], None]
T_NetcdfTypes = Literal[
"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC"
]
@@ -297,17 +308,27 @@ def _chunk_ds(
chunks,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
**extra_tokens,
):
from dask.base import tokenize
chunkmanager = guess_chunkmanager(chunked_array_type)

# TODO refactor to move this dask-specific logic inside the DaskManager class
if isinstance(chunkmanager, DaskManager):
from dask.base import tokenize

mtime = _get_mtime(filename_or_obj)
token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
name_prefix = f"open_dataset-{token}"
mtime = _get_mtime(filename_or_obj)
token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
name_prefix = "open_dataset-"
else:
# not used
token = (None,)
name_prefix = None

variables = {}
for name, var in backend_ds.variables.items():
var_chunks = _get_chunk(var, chunks)
var_chunks = _get_chunk(var, chunks, chunkmanager)
variables[name] = _maybe_chunk(
name,
var,
@@ -316,6 +337,8 @@ def _chunk_ds(
name_prefix=name_prefix,
token=token,
inline_array=inline_array,
chunked_array_type=chunkmanager,
from_array_kwargs=from_array_kwargs.copy(),
)
return backend_ds._replace(variables)

@@ -328,6 +351,8 @@ def _dataset_from_backend_dataset(
cache,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
**extra_tokens,
):
if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}:
@@ -346,6 +371,8 @@ def _dataset_from_backend_dataset(
chunks,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
**extra_tokens,
)

@@ -373,6 +400,8 @@ def open_dataset(
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
backend_kwargs: dict[str, Any] | None = None,
**kwargs,
) -> Dataset:
@@ -465,6 +494,15 @@
itself, and each chunk refers to that task by its key. With
``inline_array=True``, Dask will instead inline the array directly
in the values of the task graph. See :py:func:`dask.array.from_array`.
chunked_array_type: str, optional
Which chunked array type to coerce this dataset's arrays to.
Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system.
Experimental API that should not be relied upon.
from_array_kwargs: dict
Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create
chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg.
For example if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed
to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
backend_kwargs: dict
Additional keyword arguments passed on to the engine open function,
equivalent to `**kwargs`.
@@ -508,6 +546,9 @@
if engine is None:
engine = plugins.guess_engine(filename_or_obj)

if from_array_kwargs is None:
from_array_kwargs = {}

backend = plugins.get_backend(engine)

decoders = _resolve_decoders_kwargs(
@@ -536,6 +577,8 @@
cache,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
drop_variables=drop_variables,
**decoders,
**kwargs,
@@ -546,8 +589,8 @@ def open_dataarray(
def open_dataarray(
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
engine: T_Engine = None,
chunks: T_Chunks = None,
engine: T_Engine | None = None,
chunks: T_Chunks | None = None,
cache: bool | None = None,
decode_cf: bool | None = None,
mask_and_scale: bool | None = None,
@@ -558,6 +601,8 @@ def open_dataarray(
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
backend_kwargs: dict[str, Any] | None = None,
**kwargs,
) -> DataArray:
@@ -652,6 +697,15 @@ def open_dataarray(
itself, and each chunk refers to that task by its key. With
``inline_array=True``, Dask will instead inline the array directly
in the values of the task graph. See :py:func:`dask.array.from_array`.
chunked_array_type: str, optional
Which chunked array type to coerce the underlying data array to.
Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system.
Experimental API that should not be relied upon.
from_array_kwargs: dict
Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create
chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg.
For example if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed
to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
backend_kwargs: dict
Additional keyword arguments passed on to the engine open function,
equivalent to `**kwargs`.
@@ -695,6 +749,8 @@ def open_dataarray(
cache=cache,
drop_variables=drop_variables,
inline_array=inline_array,
chunked_array_type=chunked_array_type,
from_array_kwargs=from_array_kwargs,
backend_kwargs=backend_kwargs,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
@@ -726,7 +782,7 @@ def open_dataarray(

def open_mfdataset(
paths: str | NestedSequence[str | os.PathLike],
chunks: T_Chunks = None,
chunks: T_Chunks | None = None,
concat_dim: str
| DataArray
| Index
Expand All @@ -736,7 +792,7 @@ def open_mfdataset(
| None = None,
compat: CompatOptions = "no_conflicts",
preprocess: Callable[[Dataset], Dataset] | None = None,
engine: T_Engine = None,
engine: T_Engine | None = None,
data_vars: Literal["all", "minimal", "different"] | list[str] = "all",
coords="different",
combine: Literal["by_coords", "nested"] = "by_coords",
@@ -1490,6 +1546,7 @@ def to_zarr(
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
chunkmanager_store_kwargs: dict[str, Any] | None = None,
) -> backends.ZarrStore:
...

@@ -1512,6 +1569,7 @@ def to_zarr(
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
chunkmanager_store_kwargs: dict[str, Any] | None = None,
) -> Delayed:
...

@@ -1531,6 +1589,7 @@ def to_zarr(
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
chunkmanager_store_kwargs: dict[str, Any] | None = None,
) -> backends.ZarrStore | Delayed:
"""This function creates an appropriate datastore for writing a dataset to
a zarr store
@@ -1652,7 +1711,9 @@ def to_zarr(
writer = ArrayWriter()
# TODO: figure out how to properly handle unlimited_dims
dump_to_store(dataset, zstore, writer, encoding=encoding)
writes = writer.sync(compute=compute)
writes = writer.sync(
compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs
)

if compute:
_finalize_store(writes, zstore)
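As a usage sketch of the new open_dataset keywords added in this file (the filename is illustrative, and lock is an existing dask.array.from_array option chosen purely as an example of something to forward):

    import xarray as xr

    # Open lazily, naming the chunk manager explicitly and forwarding a kwarg
    # through from_array_kwargs to dask.array.from_array.
    ds = xr.open_dataset(
        "example.nc",               # illustrative path
        chunks={},                  # request chunked (lazy) arrays
        chunked_array_type="dask",  # entry-point name of the manager
        from_array_kwargs={"lock": False},
    )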
15 changes: 10 additions & 5 deletions xarray/backends/common.py
@@ -11,7 +11,8 @@

from xarray.conventions import cf_encoder
from xarray.core import indexing
from xarray.core.pycompat import is_duck_dask_array
from xarray.core.parallelcompat import get_chunked_array_type
from xarray.core.pycompat import is_chunked_array
from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri

if TYPE_CHECKING:
@@ -153,7 +154,7 @@ def __init__(self, lock=None):
self.lock = lock

def add(self, source, target, region=None):
if is_duck_dask_array(source):
if is_chunked_array(source):
self.sources.append(source)
self.targets.append(target)
self.regions.append(region)
@@ -163,21 +164,25 @@ def add(self, source, target, region=None):
else:
target[...] = source

def sync(self, compute=True):
def sync(self, compute=True, chunkmanager_store_kwargs=None):
if self.sources:
import dask.array as da
chunkmanager = get_chunked_array_type(*self.sources)

# TODO: consider wrapping targets with dask.delayed, if this makes
# for any discernible difference in performance, e.g.,
# targets = [dask.delayed(t) for t in self.targets]

delayed_store = da.store(
if chunkmanager_store_kwargs is None:
chunkmanager_store_kwargs = {}

delayed_store = chunkmanager.store(
self.sources,
self.targets,
lock=self.lock,
compute=compute,
flush=True,
regions=self.regions,
**chunkmanager_store_kwargs,
)
self.sources = []
self.targets = []
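Correspondingly, store-time keyword arguments are now routed through ArrayWriter.sync to chunkmanager.store (dask.array.store for the DaskManager). A hedged example, assuming the new chunkmanager_store_kwargs keyword is forwarded from Dataset.to_zarr as the api.py changes above suggest; return_stored is an existing dask.array.store option used only for illustration:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": ("x", np.arange(10))}).chunk({"x": 5})

    # Kwargs in chunkmanager_store_kwargs are passed on to chunkmanager.store.
    ds.to_zarr("out.zarr", chunkmanager_store_kwargs={"return_stored": False})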
2 changes: 1 addition & 1 deletion xarray/backends/plugins.py
@@ -146,7 +146,7 @@ def refresh_engines() -> None:

def guess_engine(
store_spec: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
):
) -> str | type[BackendEntrypoint]:
engines = list_engines()

for engine, backend in engines.items():
23 changes: 21 additions & 2 deletions xarray/backends/zarr.py
@@ -19,6 +19,7 @@
)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
from xarray.core.parallelcompat import guess_chunkmanager
from xarray.core.pycompat import integer_types
from xarray.core.utils import (
FrozenDict,
@@ -716,6 +717,8 @@ def open_zarr(
decode_timedelta=None,
use_cftime=None,
zarr_version=None,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
**kwargs,
):
"""Load and decode a dataset from a Zarr store.
@@ -800,6 +803,15 @@
The desired zarr spec version to target (currently 2 or 3). The default
of None will attempt to determine the zarr version from ``store`` when
possible, otherwise defaulting to 2.
chunked_array_type: str, optional
Which chunked array type to coerce this dataset's arrays to.
Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system.
Experimental API that should not be relied upon.
from_array_kwargs: dict, optional
Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create
chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg.
Defaults to None, meaning additional kwargs will be passed eventually to
:py:func:`dask.array.from_array`. Experimental API that should not be relied upon.

Returns
-------
Expand All @@ -817,12 +829,17 @@ def open_zarr(
"""
from xarray.backends.api import open_dataset

if from_array_kwargs is None:
from_array_kwargs = {}

if chunks == "auto":
try:
import dask.array # noqa
guess_chunkmanager(
chunked_array_type
) # attempt to import that parallel backend

chunks = {}
except ImportError:
except ValueError:
chunks = None

if kwargs:
@@ -851,6 +868,8 @@
engine="zarr",
chunks=chunks,
drop_variables=drop_variables,
chunked_array_type=chunked_array_type,
from_array_kwargs=from_array_kwargs,
backend_kwargs=backend_kwargs,
decode_timedelta=decode_timedelta,
use_cftime=use_cftime,
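The chunks == "auto" fallback above now keys off guess_chunkmanager raising ValueError rather than a failed dask import. A small sketch of the behavior this relies on (the bogus manager name is illustrative):

    from xarray.core.parallelcompat import guess_chunkmanager

    cm = guess_chunkmanager(None)    # default manager: "dask" when installed
    cm = guess_chunkmanager("dask")  # look up explicitly by entry-point name

    try:
        guess_chunkmanager("not-a-real-manager")
    except ValueError:
        # Unrecognized names raise ValueError, which open_zarr catches in
        # order to fall back to chunks=None.
        pass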