From 030e5c1f15f7295c4a27ea531eec5c2842e18169 Mon Sep 17 00:00:00 2001 From: Tom Nicholas Date: Thu, 18 May 2023 11:34:30 -0600 Subject: [PATCH] Generalize handling of chunked array types (#7019) * generalise chunk methods to allow cubed * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fic typing typo * fixed circular import * fix some mypy errors * added cubed to mypy ignore list * simplify __array_ufunc__ check * Revert "simplify __array_ufunc__ check" as I pushed to wrong branch This reverts commit cdcb3fb7d87e7420210ed4a3219f4ee90f56fe21. * update cubed array type * fix missed conflict * sketch for ChunkManager adapter class * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove erroneous docstring about usage of map_blocks Co-authored-by: Deepak Cherian * apply_ufunc -> apply_gufunc Co-authored-by: Deepak Cherian * chunk -> from_array Co-authored-by: Deepak Cherian * remove staticmethods * attempt to type methods of ABC * from_array * attempt to specify types * method for checking array type * Update pyproject.toml * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed import errors * generalize .chunk method kwargs * used dask functions in dask chunkmanager * define signatures for apply_gufunc, blockwise, map_blocks * prototype function to detect which parallel backend to use * add cubed.apply_gufunc * ruffify * add rechunk and compute methods for cubed * xr.apply_ufunc now dispatches to chunkmanager.apply_gufunc * CubedManager.chunks * attempt to keep dask and cubed imports lazy * generalize idxmax * move unify_chunks import to ChunkManager * generalize Dataset.load() * check explicitly for chunks attribute instead of hard-coding cubed * better function names * add cubed version of unify_chunks * recognize wrapped duck dask arrays (e.g. pint wrapping dask) * add some tests for fetching ChunkManagers * add from_array_kwargs to open_dataset * add from_array_kwargs to open_zarr * pipe constructors through chunkmanager * generalize map_blocks inside coding * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed full_like * add from_array_kwargs to open_zarr * don't import dask.tokenize * fix bugs with passing from_array_kwargs down * generalise reductions by adding to chunkmanager * moved nanfirst/nanlast to duck_array_ops from dask_array_ops * generalize interp * generalized chunk_hint function inside indexing * DaskIndexingAdapter->ChunkedIndexingAdapter * Revert "DaskIndexingAdapter->ChunkedIndexingAdapter" This reverts commit 4ca044b3d74be6ec05214fda9c0470b0a007a96f. 
* pass cubed-related kwargs down through to_zarr by adding .store to ChunkManager * fix typing_extensions on py3.9 * fix ImportError with cubed array type * give up trying to import TypeAlias in CI * fix import of T_Chunks * fix no_implicit_optional warnings * don't define CubedManager if cubed can't be imported * fix local mypy errors * don't explicitly pass enforce_ndim into dask.array.map_blocks * fix drop_axis default * use indexing adapter on cubed arrays too * use array API-compatible version of astype function * whatsnew * document new kwargs * add chunkmanager entrypoint * move CubedManager to a separate package * guess chunkmanager based on whats available * fix bug with tokenizing * adapt tests to emulate existence of entrypoint * use fixture to setup/teardown dummy entrypoint * refactor to make DaskManager unavailable if dask not installed * typing * move whatsnew to latest xarray version * remove superfluous lines from whatsnew * fix bug where zarr backend attempted to use dask when not installed * Remove rogue print statement Co-authored-by: Deepak Cherian * Clarify what's new Co-authored-by: Deepak Cherian * use monkeypatch to mock registering of dummy chunkmanager * more tests for guessing chunkmanager correctly * raise TypeError if no chunkmanager found for array types * Correct is_chunked_array check Co-authored-by: Deepak Cherian * vendor dask.array.core.normalize_chunks * add default implementation of rechunk in ABC * remove cubed-specific type check in daskmanager * nanfirst->chunked_nanfirst * revert adding cubed to NON_NUMPY_SUPPORTED_ARRAY_TYPES * licensing to vendor functions from dask * fix bug * ignore mypy error * separate chunk_manager kwarg from from_array_kwargs dict * rename kwarg to chunked_array_type * refactor from_array_kwargs in .chunk ready for deprecation * print statements in test so I can comment on them * remove print statements now I've commented on them in PR * should fix dask naming tests * make dask-specific kwargs explicit in from_array * debugging print statements * Revert "debugging print statements" This reverts commit 7dc658186194aa31e788e106c7527a8909de0a2e. * fix gnarly bug with auto-determining chunksizes caused by not referring to dask.config * hopefully fix broken docstring * Revert "make dask-specific kwargs explicit in from_array" This reverts commit 53d6094f9624a459512e9f88a6bb5c55d0e65a19. * show chunksize limit used in failing tests * move lazy indexing adapter up out of chunkmanager code * try upgrading minimum version of dask * Revert "try upgrading minimum version of dask" This reverts commit 796a577cbccf14701a0da1f5cb8860006b78bba4. 
* un-vendor dask.array.core.normalize_chunks * refactored to all passing ChunkManagerEntrypoint objects directly * Remove redundant Nones from types Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * From future import annotations Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * From functools import annotations Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * From future import annotations Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * defined type for NormalizedChunks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * standardized capitalization of ChunkManagerEntrypoint * ensure ruff doesn't remove import * ignore remaining typing errors stemming from unclear dask typing for chunks arguments * rename store_kwargs->chunkmanager_store_kwargs * missed return value * array API fixes for astype * Revert "array API fixes for astype" This reverts commit 9cd9078dfe645543a3f1fff0a776e080800e6820. * Apply suggestions from code review * Update xarray/tests/test_parallelcompat.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * overridden -> subclassed Co-authored-by: Justus Magin * from_array_kwargs is optional Co-authored-by: Justus Magin * ensured all compute calls go through chunkmanager * Raise if multiple chunkmanagers recognize array type Co-authored-by: Justus Magin * from_array_kwargs is optional Co-authored-by: Justus Magin * from_array_kwargs is optional Co-authored-by: Justus Magin * from_array_kwargs is optional Co-authored-by: Justus Magin * from_array_kwargs is optional Co-authored-by: Justus Magin * from_array_kwargs is optional Co-authored-by: Justus Magin * fixes for chunk methods * correct readme to reflect fact we aren't vendoring dask in this PR any more * update whatsnew * more docstring corrections * remove comment * Raise NotImplementedErrors in all abstract methods Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * type hints for every arg in ChunkManagerEntryPOint methods * more explicit typing + fixes for mypy errors revealed * Keyword-only arguments in full_like etc. 
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * None as default instead of {} * fix bug apparently introduced by changing default type of drop_axis kwarg to map_blocks * Removed hopefully-unnecessary mypy ignore Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * removed unnecessary mypy ignores * change default value of drop_axis kwarg in map_blocks and catch when dask version < 2022.9.1 * fix checking of dask version in map_blocks --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: Justus Magin --- doc/whats-new.rst | 5 + pyproject.toml | 1 + setup.cfg | 4 + xarray/backends/api.py | 85 +++++++-- xarray/backends/common.py | 15 +- xarray/backends/plugins.py | 2 +- xarray/backends/zarr.py | 23 ++- xarray/coding/strings.py | 14 +- xarray/coding/variables.py | 15 +- xarray/core/common.py | 255 ++++++++++++++++++++++--- xarray/core/computation.py | 45 +++-- xarray/core/dask_array_ops.py | 37 ---- xarray/core/daskmanager.py | 215 +++++++++++++++++++++ xarray/core/dataarray.py | 18 +- xarray/core/dataset.py | 106 ++++++++--- xarray/core/duck_array_ops.py | 45 ++++- xarray/core/indexing.py | 15 +- xarray/core/missing.py | 16 +- xarray/core/nanops.py | 5 +- xarray/core/parallelcompat.py | 280 ++++++++++++++++++++++++++++ xarray/core/pycompat.py | 10 +- xarray/core/rolling.py | 6 +- xarray/core/types.py | 13 ++ xarray/core/utils.py | 63 +++++++ xarray/core/variable.py | 86 ++++++--- xarray/core/weighted.py | 5 +- xarray/tests/test_dask.py | 11 +- xarray/tests/test_parallelcompat.py | 219 ++++++++++++++++++++++ xarray/tests/test_plugins.py | 1 + 29 files changed, 1406 insertions(+), 209 deletions(-) create mode 100644 xarray/core/daskmanager.py create mode 100644 xarray/core/parallelcompat.py create mode 100644 xarray/tests/test_parallelcompat.py diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3d661b30c07..968986c63dd 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -58,6 +58,11 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Experimental support for wrapping chunked array libraries other than dask. + A new ABC is defined - :py:class:`xr.core.parallelcompat.ChunkManagerEntrypoint` - which can be subclassed and then + registered by alternative chunked array implementations. (:issue:`6807`, :pull:`7019`) + By `Tom Nicholas `_. + .. 
_whats-new.2023.04.2: diff --git a/pyproject.toml b/pyproject.toml index 88b34d002d5..4d63fd564ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ module = [ "cf_units.*", "cfgrib.*", "cftime.*", + "cubed.*", "cupy.*", "fsspec.*", "h5netcdf.*", diff --git a/setup.cfg b/setup.cfg index 81b7f1c4a0e..85ac8e259e5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -132,6 +132,10 @@ xarray = static/css/* static/html/* +[options.entry_points] +xarray.chunkmanagers = + dask = xarray.core.daskmanager:DaskManager + [tool:pytest] python_files = test_*.py testpaths = xarray/tests properties diff --git a/xarray/backends/api.py b/xarray/backends/api.py index e5adedbb576..0157e0d9d66 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -6,7 +6,16 @@ from glob import glob from io import BytesIO from numbers import Number -from typing import TYPE_CHECKING, Any, Callable, Final, Literal, Union, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Final, + Literal, + Union, + cast, + overload, +) import numpy as np @@ -20,9 +29,11 @@ _nested_combine, combine_by_coords, ) +from xarray.core.daskmanager import DaskManager from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk from xarray.core.indexes import Index +from xarray.core.parallelcompat import guess_chunkmanager from xarray.core.utils import is_remote_uri if TYPE_CHECKING: @@ -38,6 +49,7 @@ CompatOptions, JoinOptions, NestedSequence, + T_Chunks, ) T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] @@ -48,7 +60,6 @@ str, # no nice typing support for custom backends None, ] - T_Chunks = Union[int, dict[Any, Any], Literal["auto"], None] T_NetcdfTypes = Literal[ "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" ] @@ -297,17 +308,27 @@ def _chunk_ds( chunks, overwrite_encoded_chunks, inline_array, + chunked_array_type, + from_array_kwargs, **extra_tokens, ): - from dask.base import tokenize + chunkmanager = guess_chunkmanager(chunked_array_type) + + # TODO refactor to move this dask-specific logic inside the DaskManager class + if isinstance(chunkmanager, DaskManager): + from dask.base import tokenize - mtime = _get_mtime(filename_or_obj) - token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens) - name_prefix = f"open_dataset-{token}" + mtime = _get_mtime(filename_or_obj) + token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens) + name_prefix = "open_dataset-" + else: + # not used + token = (None,) + name_prefix = None variables = {} for name, var in backend_ds.variables.items(): - var_chunks = _get_chunk(var, chunks) + var_chunks = _get_chunk(var, chunks, chunkmanager) variables[name] = _maybe_chunk( name, var, @@ -316,6 +337,8 @@ def _chunk_ds( name_prefix=name_prefix, token=token, inline_array=inline_array, + chunked_array_type=chunkmanager, + from_array_kwargs=from_array_kwargs.copy(), ) return backend_ds._replace(variables) @@ -328,6 +351,8 @@ def _dataset_from_backend_dataset( cache, overwrite_encoded_chunks, inline_array, + chunked_array_type, + from_array_kwargs, **extra_tokens, ): if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}: @@ -346,6 +371,8 @@ def _dataset_from_backend_dataset( chunks, overwrite_encoded_chunks, inline_array, + chunked_array_type, + from_array_kwargs, **extra_tokens, ) @@ -373,6 +400,8 @@ def open_dataset( decode_coords: Literal["coordinates", "all"] | bool | None = None, drop_variables: str | Iterable[str] | None = None, inline_array: bool = False, + 
chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, backend_kwargs: dict[str, Any] | None = None, **kwargs, ) -> Dataset: @@ -465,6 +494,15 @@ itself, and each chunk refers to that task by its key. With ``inline_array=True``, Dask will instead inline the array directly in the values of the task graph. See :py:func:`dask.array.from_array`. + chunked_array_type: str, optional + Which chunked array type to coerce this dataset's arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. backend_kwargs: dict Additional keyword arguments passed on to the engine open function, equivalent to `**kwargs`. @@ -508,6 +546,9 @@ if engine is None: engine = plugins.guess_engine(filename_or_obj) + if from_array_kwargs is None: + from_array_kwargs = {} + backend = plugins.get_backend(engine) decoders = _resolve_decoders_kwargs( @@ -536,6 +577,8 @@ cache, overwrite_encoded_chunks, inline_array, + chunked_array_type, + from_array_kwargs, drop_variables=drop_variables, **decoders, **kwargs, @@ -546,8 +589,8 @@ def open_dataarray( filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, *, - engine: T_Engine = None, - chunks: T_Chunks = None, + engine: T_Engine | None = None, + chunks: T_Chunks | None = None, cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | None = None, @@ -558,6 +601,8 @@ def open_dataarray( decode_coords: Literal["coordinates", "all"] | bool | None = None, drop_variables: str | Iterable[str] | None = None, inline_array: bool = False, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, backend_kwargs: dict[str, Any] | None = None, **kwargs, ) -> DataArray: @@ -652,6 +697,15 @@ itself, and each chunk refers to that task by its key. With ``inline_array=True``, Dask will instead inline the array directly in the values of the task graph. See :py:func:`dask.array.from_array`. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. backend_kwargs: dict Additional keyword arguments passed on to the engine open function, equivalent to `**kwargs`.
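As a minimal sketch of how the two new keywords documented above are meant to compose (the file name here is hypothetical, and any entrypoint other than the ``dask`` one shipped by this PR is assumed to come from a third-party package)::

    import xarray as xr

    # Chunked open with the default dask-backed manager; from_array_kwargs is
    # forwarded by the DaskManager to dask.array.from_array.
    ds = xr.open_dataset(
        "data.nc",  # hypothetical file
        chunks={"time": 10},
        chunked_array_type="dask",
        from_array_kwargs={"asarray": True},  # any dask.array.from_array kwarg
    )

    # A third-party chunk manager registered under another name (e.g. "cubed",
    # not provided by this PR) would be selected the same way:
    # ds = xr.open_dataset("data.nc", chunks={}, chunked_array_type="cubed")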
@@ -695,6 +749,8 @@ def open_dataarray( cache=cache, drop_variables=drop_variables, inline_array=inline_array, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, backend_kwargs=backend_kwargs, use_cftime=use_cftime, decode_timedelta=decode_timedelta, @@ -726,7 +782,7 @@ def open_mfdataset( paths: str | NestedSequence[str | os.PathLike], - chunks: T_Chunks = None, + chunks: T_Chunks | None = None, concat_dim: str | DataArray | Index @@ -736,7 +792,7 @@ | None = None, compat: CompatOptions = "no_conflicts", preprocess: Callable[[Dataset], Dataset] | None = None, - engine: T_Engine = None, + engine: T_Engine | None = None, data_vars: Literal["all", "minimal", "different"] | list[str] = "all", coords="different", combine: Literal["by_coords", "nested"] = "by_coords", @@ -1490,6 +1546,7 @@ def to_zarr( safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, ) -> backends.ZarrStore: ... @@ -1512,6 +1569,7 @@ safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, ) -> Delayed: ... @@ -1531,6 +1589,7 @@ safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, ) -> backends.ZarrStore | Delayed: """This function creates an appropriate datastore for writing a dataset to a zarr store @@ -1652,7 +1711,9 @@ def to_zarr( writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims dump_to_store(dataset, zstore, writer, encoding=encoding) - writes = writer.sync(compute=compute) + writes = writer.sync( + compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs + ) if compute: _finalize_store(writes, zstore) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index bca8b7f668a..50ac606a83e 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -11,7 +11,8 @@ from xarray.conventions import cf_encoder from xarray.core import indexing -from xarray.core.pycompat import is_duck_dask_array +from xarray.core.parallelcompat import get_chunked_array_type +from xarray.core.pycompat import is_chunked_array from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri if TYPE_CHECKING: @@ -153,7 +154,7 @@ def __init__(self, lock=None): self.lock = lock def add(self, source, target, region=None): - if is_duck_dask_array(source): + if is_chunked_array(source): self.sources.append(source) self.targets.append(target) self.regions.append(region) @@ -163,21 +164,25 @@ else: target[...]
= source - def sync(self, compute=True): + def sync(self, compute=True, chunkmanager_store_kwargs=None): if self.sources: - import dask.array as da + chunkmanager = get_chunked_array_type(*self.sources) # TODO: consider wrapping targets with dask.delayed, if this makes # for any discernible difference in performance, e.g., # targets = [dask.delayed(t) for t in self.targets] - delayed_store = da.store( + if chunkmanager_store_kwargs is None: + chunkmanager_store_kwargs = {} + + delayed_store = chunkmanager.store( self.sources, self.targets, lock=self.lock, compute=compute, flush=True, regions=self.regions, + **chunkmanager_store_kwargs, ) self.sources = [] self.targets = [] diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index 232c2300192..a62ca6c9862 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -146,7 +146,7 @@ def refresh_engines() -> None: def guess_engine( store_spec: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, -): +) -> str | type[BackendEntrypoint]: engines = list_engines() for engine, backend in engines.items(): diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 7d21c771e06..a4012a8a733 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -19,6 +19,7 @@ ) from xarray.backends.store import StoreBackendEntrypoint from xarray.core import indexing +from xarray.core.parallelcompat import guess_chunkmanager from xarray.core.pycompat import integer_types from xarray.core.utils import ( FrozenDict, @@ -716,6 +717,8 @@ def open_zarr( decode_timedelta=None, use_cftime=None, zarr_version=None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, **kwargs, ): """Load and decode a dataset from a Zarr store. @@ -800,6 +803,15 @@ The desired zarr spec version to target (currently 2 or 3). The default of None will attempt to determine the zarr version from ``store`` when possible, otherwise defaulting to 2. + chunked_array_type: str, optional + Which chunked array type to coerce this dataset's arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + With the default chunk manager (dask), additional kwargs will be passed eventually to + :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
Returns ------- @@ -817,12 +829,17 @@ def open_zarr( """ from xarray.backends.api import open_dataset + if from_array_kwargs is None: + from_array_kwargs = {} + if chunks == "auto": try: - import dask.array # noqa + guess_chunkmanager( + chunked_array_type + ) # attempt to import that parallel backend chunks = {} - except ImportError: + except ValueError: chunks = None if kwargs: @@ -851,6 +868,8 @@ def open_zarr( engine="zarr", chunks=chunks, drop_variables=drop_variables, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, backend_kwargs=backend_kwargs, decode_timedelta=decode_timedelta, use_cftime=use_cftime, diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index 61b3ab7c46c..ffe1b1a8d50 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -14,7 +14,7 @@ unpack_for_encoding, ) from xarray.core import indexing -from xarray.core.pycompat import is_duck_dask_array +from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array from xarray.core.variable import Variable @@ -134,10 +134,10 @@ def bytes_to_char(arr): if arr.dtype.kind != "S": raise ValueError("argument must have a fixed-width bytes dtype") - if is_duck_dask_array(arr): - import dask.array as da + if is_chunked_array(arr): + chunkmanager = get_chunked_array_type(arr) - return da.map_blocks( + return chunkmanager.map_blocks( _numpy_bytes_to_char, arr, dtype="S1", @@ -169,8 +169,8 @@ def char_to_bytes(arr): # can't make an S0 dtype return np.zeros(arr.shape[:-1], dtype=np.string_) - if is_duck_dask_array(arr): - import dask.array as da + if is_chunked_array(arr): + chunkmanager = get_chunked_array_type(arr) if len(arr.chunks[-1]) > 1: raise ValueError( @@ -179,7 +179,7 @@ def char_to_bytes(arr): ) dtype = np.dtype("S" + str(arr.shape[-1])) - return da.map_blocks( + return chunkmanager.map_blocks( _numpy_char_to_bytes, arr, dtype=dtype, diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 5c6e51c2215..8ba7dcbb0e2 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -10,7 +10,8 @@ import pandas as pd from xarray.core import dtypes, duck_array_ops, indexing -from xarray.core.pycompat import is_duck_dask_array +from xarray.core.parallelcompat import get_chunked_array_type +from xarray.core.pycompat import is_chunked_array from xarray.core.variable import Variable if TYPE_CHECKING: @@ -57,7 +58,7 @@ class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin): """ def __init__(self, array, func: Callable, dtype: np.typing.DTypeLike): - assert not is_duck_dask_array(array) + assert not is_chunked_array(array) self.array = indexing.as_indexable(array) self.func = func self._dtype = dtype @@ -158,10 +159,10 @@ def lazy_elemwise_func(array, func: Callable, dtype: np.typing.DTypeLike): ------- Either a dask.array.Array or _ElementwiseFunctionArray. 
""" - if is_duck_dask_array(array): - import dask.array as da + if is_chunked_array(array): + chunkmanager = get_chunked_array_type(array) - return da.map_blocks(func, array, dtype=dtype) + return chunkmanager.map_blocks(func, array, dtype=dtype) else: return _ElementwiseFunctionArray(array, func, dtype) @@ -330,7 +331,7 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: if "scale_factor" in encoding or "add_offset" in encoding: dtype = _choose_float_dtype(data.dtype, "add_offset" in encoding) - data = data.astype(dtype=dtype, copy=True) + data = duck_array_ops.astype(data, dtype=dtype, copy=True) if "add_offset" in encoding: data -= pop_to(encoding, attrs, "add_offset", name=name) if "scale_factor" in encoding: @@ -377,7 +378,7 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: if "_FillValue" in attrs: new_fill = signed_dtype.type(attrs["_FillValue"]) attrs["_FillValue"] = new_fill - data = duck_array_ops.around(data).astype(signed_dtype) + data = duck_array_ops.astype(duck_array_ops.around(data), signed_dtype) return Variable(dims, data, attrs, encoding, fastpath=True) else: diff --git a/xarray/core/common.py b/xarray/core/common.py index f6abcba1ff0..397d6de226a 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -13,8 +13,9 @@ from xarray.core import dtypes, duck_array_ops, formatting, formatting_html, ops from xarray.core.indexing import BasicIndexer, ExplicitlyIndexed from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.parallelcompat import get_chunked_array_type, guess_chunkmanager from xarray.core.pdcompat import _convert_base_to_offset -from xarray.core.pycompat import is_duck_dask_array +from xarray.core.pycompat import is_chunked_array from xarray.core.utils import ( Frozen, either_dict_or_kwargs, @@ -46,6 +47,7 @@ DTypeLikeSave, ScalarOrArray, SideOptions, + T_Chunks, T_DataWithCoords, T_Variable, ) @@ -159,7 +161,7 @@ def __int__(self: Any) -> int: def __complex__(self: Any) -> complex: return complex(self.values) - def __array__(self: Any, dtype: DTypeLike = None) -> np.ndarray: + def __array__(self: Any, dtype: DTypeLike | None = None) -> np.ndarray: return np.asarray(self.values, dtype=dtype) def __repr__(self) -> str: @@ -1396,28 +1398,52 @@ def __getitem__(self, value): @overload def full_like( - other: DataArray, fill_value: Any, dtype: DTypeLikeSave = None + other: DataArray, + fill_value: Any, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> DataArray: ... @overload def full_like( - other: Dataset, fill_value: Any, dtype: DTypeMaybeMapping = None + other: Dataset, + fill_value: Any, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset: ... @overload def full_like( - other: Variable, fill_value: Any, dtype: DTypeLikeSave = None + other: Variable, + fill_value: Any, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Variable: ... 
@overload def full_like( - other: Dataset | DataArray, fill_value: Any, dtype: DTypeMaybeMapping = None + other: Dataset | DataArray, + fill_value: Any, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = {}, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray: ... @@ -1426,7 +1452,11 @@ def full_like( other: Dataset | DataArray | Variable, fill_value: Any, - dtype: DTypeMaybeMapping = None, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray | Variable: ... @@ -1434,9 +1464,16 @@ def full_like( other: Dataset | DataArray | Variable, fill_value: Any, - dtype: DTypeMaybeMapping = None, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray | Variable: - """Return a new object with the same shape and type as a given object. + """ + Return a new object with the same shape and type as a given object. + + Returned object will be chunked if the given object is chunked, or if chunks or chunked_array_type are specified. Parameters ---------- other : DataArray, Dataset or Variable The reference object in input fill_value : scalar or dict-like Value to fill the new object with before returning it. If other is a Dataset, may also be a dict-like mapping data variables to fill values. dtype : dtype or dict-like of dtype, optional dtype of the new array. If a dict-like, maps dtypes to variables. If omitted, it defaults to other.dtype. + chunks : int, "auto", tuple of int or mapping of Hashable to int, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or + ``{"x": 5, "y": 5}``. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
Returns ------- @@ -1562,7 +1611,12 @@ def full_like( data_vars = { k: _full_like_variable( - v.variable, fill_value.get(k, dtypes.NA), dtype_.get(k, None) + v.variable, + fill_value.get(k, dtypes.NA), + dtype_.get(k, None), + chunks, + chunked_array_type, + from_array_kwargs, ) for k, v in other.data_vars.items() } @@ -1571,7 +1625,14 @@ def full_like( if isinstance(dtype, Mapping): raise ValueError("'dtype' cannot be dict-like when passing a DataArray") return DataArray( - _full_like_variable(other.variable, fill_value, dtype), + _full_like_variable( + other.variable, + fill_value, + dtype, + chunks, + chunked_array_type, + from_array_kwargs, + ), dims=other.dims, coords=other.coords, attrs=other.attrs, @@ -1580,13 +1641,20 @@ def full_like( elif isinstance(other, Variable): if isinstance(dtype, Mapping): raise ValueError("'dtype' cannot be dict-like when passing a Variable") - return _full_like_variable(other, fill_value, dtype) + return _full_like_variable( + other, fill_value, dtype, chunks, chunked_array_type, from_array_kwargs + ) else: raise TypeError("Expected DataArray, Dataset, or Variable") def _full_like_variable( - other: Variable, fill_value: Any, dtype: DTypeLike = None + other: Variable, + fill_value: Any, + dtype: DTypeLike | None = None, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Variable: """Inner function of full_like, where other must be a variable""" from xarray.core.variable import Variable @@ -1594,13 +1662,28 @@ def _full_like_variable( if fill_value is dtypes.NA: fill_value = dtypes.get_fill_value(dtype if dtype is not None else other.dtype) - if is_duck_dask_array(other.data): - import dask.array + if ( + is_chunked_array(other.data) + or chunked_array_type is not None + or chunks is not None + ): + if chunked_array_type is None: + chunkmanager = get_chunked_array_type(other.data) + else: + chunkmanager = guess_chunkmanager(chunked_array_type) if dtype is None: dtype = other.dtype - data = dask.array.full( - other.shape, fill_value, dtype=dtype, chunks=other.data.chunks + + if from_array_kwargs is None: + from_array_kwargs = {} + + data = chunkmanager.array_api.full( + other.shape, + fill_value, + dtype=dtype, + chunks=chunks if chunks else other.data.chunks, + **from_array_kwargs, ) else: data = np.full_like(other.data, fill_value, dtype=dtype) @@ -1609,36 +1692,72 @@ def _full_like_variable( @overload -def zeros_like(other: DataArray, dtype: DTypeLikeSave = None) -> DataArray: +def zeros_like( + other: DataArray, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> DataArray: ... @overload -def zeros_like(other: Dataset, dtype: DTypeMaybeMapping = None) -> Dataset: +def zeros_like( + other: Dataset, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset: ... @overload -def zeros_like(other: Variable, dtype: DTypeLikeSave = None) -> Variable: +def zeros_like( + other: Variable, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Variable: ... 
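To illustrate the effect of the new ``chunks`` keyword on the ``full_like`` family (``zeros_like`` below and ``ones_like`` after it simply forward it to ``full_like``), a minimal sketch assuming dask is installed as the default chunk manager::

    import numpy as np
    import xarray as xr

    template = xr.DataArray(np.zeros((4, 6)), dims=("x", "y"))

    # The template is in memory, but passing chunks routes creation through the
    # chunk manager's array namespace (chunkmanager.array_api.full), so the
    # result is chunked even though the input was not.
    filled = xr.full_like(template, fill_value=np.nan, chunks=2)
    print(type(filled.data))  # dask.array.Array under the default DaskManager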
@overload def zeros_like( other: Dataset | DataArray, dtype: DTypeMaybeMapping | None = None, *, chunks: T_Chunks = None, chunked_array_type: str | None = None, from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray: ... @overload def zeros_like( other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping | None = None, *, chunks: T_Chunks = None, chunked_array_type: str | None = None, from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray | Variable: ... def zeros_like( other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping | None = None, *, chunks: T_Chunks = None, chunked_array_type: str | None = None, from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray | Variable: """Return a new object of zeros with the same shape and type as a given dataarray or dataset. @@ -1649,6 +1768,18 @@ The reference object. The output will have the same dimensions and coordinates as this object. dtype : dtype, optional dtype of the new array. If omitted, it defaults to other.dtype. + chunks : int, "auto", tuple of int or mapping of Hashable to int, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or + ``{"x": 5, "y": 5}``. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. Returns ------- @@ -1692,40 +1823,83 @@ full_like """ - return full_like(other, 0, dtype) + return full_like( + other, + 0, + dtype, + chunks=chunks, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + ) @overload -def ones_like(other: DataArray, dtype: DTypeLikeSave = None) -> DataArray: +def ones_like( + other: DataArray, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> DataArray: ... @overload -def ones_like(other: Dataset, dtype: DTypeMaybeMapping = None) -> Dataset: +def ones_like( + other: Dataset, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset: ... @overload -def ones_like(other: Variable, dtype: DTypeLikeSave = None) -> Variable: +def ones_like( + other: Variable, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Variable: ...
@overload def ones_like( - other: Dataset | DataArray, dtype: DTypeMaybeMapping = None + other: Dataset | DataArray, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray: ... @overload def ones_like( - other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None + other: Dataset | DataArray | Variable, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray | Variable: ... def ones_like( - other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None + other: Dataset | DataArray | Variable, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray | Variable: """Return a new object of ones with the same shape and type as a given dataarray or dataset. @@ -1736,6 +1910,18 @@ def ones_like( The reference object. The output will have the same dimensions and coordinates as this object. dtype : dtype, optional dtype of the new array. If omitted, it defaults to other.dtype. + chunks : int, "auto", tuple of int or mapping of Hashable to int, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or + ``{"x": 5, "y": 5}``. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. 
Returns ------- @@ -1771,7 +1957,14 @@ def ones_like( full_like """ - return full_like(other, 1, dtype) + return full_like( + other, + 1, + dtype, + chunks=chunks, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + ) def get_chunksizes( diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 356f1029192..685307fc8c3 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -20,7 +20,8 @@ from xarray.core.indexes import Index, filter_indexes_from_coords from xarray.core.merge import merge_attrs, merge_coordinates_without_align from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.pycompat import is_duck_dask_array +from xarray.core.parallelcompat import get_chunked_array_type +from xarray.core.pycompat import is_chunked_array, is_duck_dask_array from xarray.core.types import Dims, T_DataArray from xarray.core.utils import is_dict_like, is_scalar from xarray.core.variable import Variable @@ -675,16 +676,18 @@ def apply_variable_ufunc( for arg, core_dims in zip(args, signature.input_core_dims) ] - if any(is_duck_dask_array(array) for array in input_data): + if any(is_chunked_array(array) for array in input_data): if dask == "forbidden": raise ValueError( - "apply_ufunc encountered a dask array on an " - "argument, but handling for dask arrays has not " + "apply_ufunc encountered a chunked array on an " + "argument, but handling for chunked arrays has not " "been enabled. Either set the ``dask`` argument " "or load your data into memory first with " "``.load()`` or ``.compute()``" ) elif dask == "parallelized": + chunkmanager = get_chunked_array_type(*input_data) + numpy_func = func if dask_gufunc_kwargs is None: @@ -697,7 +700,7 @@ def apply_variable_ufunc( for n, (data, core_dims) in enumerate( zip(input_data, signature.input_core_dims) ): - if is_duck_dask_array(data): + if is_chunked_array(data): # core dimensions cannot span multiple chunks for axis, dim in enumerate(core_dims, start=-len(core_dims)): if len(data.chunks[axis]) != 1: @@ -705,7 +708,7 @@ def apply_variable_ufunc( f"dimension {dim} on {n}th function argument to " "apply_ufunc with dask='parallelized' consists of " "multiple chunks, but is also a core dimension. To " - "fix, either rechunk into a single dask array chunk along " + "fix, either rechunk into a single array chunk along " f"this dimension, i.e., ``.chunk(dict({dim}=-1))``, or " "pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` " "but beware that this may significantly increase memory usage." 
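A sketch of the calling pattern that the core-dimension check above enforces, before ``apply_ufunc`` dispatches to ``chunkmanager.apply_gufunc`` (shown in the next hunk); this assumes dask is installed::

    import numpy as np
    import xarray as xr

    def demean(arr):
        # operates on the core dimension, which apply_gufunc moves to the last axis
        return arr - arr.mean(axis=-1, keepdims=True)

    da = xr.DataArray(np.random.rand(8, 100), dims=("y", "x")).chunk({"x": 25})

    # "x" is a core dimension, so it must form a single chunk; rechunk as the
    # error message suggests before applying the function in parallel.
    result = xr.apply_ufunc(
        demean,
        da.chunk({"x": -1}),
        input_core_dims=[["x"]],
        output_core_dims=[["x"]],
        dask="parallelized",
        output_dtypes=[da.dtype],
    )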
@@ -732,9 +735,7 @@ def apply_variable_ufunc( ) def func(*arrays): - import dask.array as da - - res = da.apply_gufunc( + res = chunkmanager.apply_gufunc( numpy_func, signature.to_gufunc_string(exclude_dims), *arrays, @@ -749,8 +750,7 @@ def func(*arrays): pass else: raise ValueError( - "unknown setting for dask array handling in " - "apply_ufunc: {}".format(dask) + "unknown setting for chunked array handling in " f"apply_ufunc: {dask}" ) else: if vectorize: @@ -812,7 +812,7 @@ def func(*arrays): def apply_array_ufunc(func, *args, dask="forbidden"): """Apply a ndarray level function over ndarray objects.""" - if any(is_duck_dask_array(arg) for arg in args): + if any(is_chunked_array(arg) for arg in args): if dask == "forbidden": raise ValueError( "apply_ufunc encountered a dask array on an " @@ -2013,7 +2013,7 @@ def to_floatable(x: DataArray) -> DataArray: ) elif x.dtype.kind == "m": # timedeltas - return x.astype(float) + return duck_array_ops.astype(x, dtype=float) return x if isinstance(data, Dataset): @@ -2061,12 +2061,11 @@ def _calc_idxminmax( # This will run argmin or argmax. indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna) - # Handle dask arrays. - if is_duck_dask_array(array.data): - import dask.array - + # Handle chunked arrays (e.g. dask). + if is_chunked_array(array.data): + chunkmanager = get_chunked_array_type(array.data) chunks = dict(zip(array.dims, array.chunks)) - dask_coord = dask.array.from_array(array[dim].data, chunks=chunks[dim]) + dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape)) # we need to attach back the dim name res.name = dim @@ -2153,16 +2152,14 @@ def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, .. 
if not unify_chunks_args: return objects - # Run dask.array.core.unify_chunks - from dask.array.core import unify_chunks - - _, dask_data = unify_chunks(*unify_chunks_args) - dask_data_iter = iter(dask_data) + chunkmanager = get_chunked_array_type(*[arg for arg in unify_chunks_args]) + _, chunked_data = chunkmanager.unify_chunks(*unify_chunks_args) + chunked_data_iter = iter(chunked_data) out: list[Dataset | DataArray] = [] for obj, ds in zip(objects, datasets): for k, v in ds._variables.items(): if v.chunks is not None: - ds._variables[k] = v.copy(data=next(dask_data_iter)) + ds._variables[k] = v.copy(data=next(chunked_data_iter)) out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds) return tuple(out) diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 24c5f698a27..d2d3e4a6d1c 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -1,9 +1,5 @@ from __future__ import annotations -from functools import partial - -from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] - from xarray.core import dtypes, nputils @@ -96,36 +92,3 @@ def _fill_with_last_one(a, b): axis=axis, dtype=array.dtype, ) - - -def _first_last_wrapper(array, *, axis, op, keepdims): - return op(array, axis, keepdims=keepdims) - - -def _first_or_last(darray, axis, op): - import dask.array - - # This will raise the same error message seen for numpy - axis = normalize_axis_index(axis, darray.ndim) - - wrapped_op = partial(_first_last_wrapper, op=op) - return dask.array.reduction( - darray, - chunk=wrapped_op, - aggregate=wrapped_op, - axis=axis, - dtype=darray.dtype, - keepdims=False, # match numpy version - ) - - -def nanfirst(darray, axis): - from xarray.core.duck_array_ops import nanfirst - - return _first_or_last(darray, axis, op=nanfirst) - - -def nanlast(darray, axis): - from xarray.core.duck_array_ops import nanlast - - return _first_or_last(darray, axis, op=nanlast) diff --git a/xarray/core/daskmanager.py b/xarray/core/daskmanager.py new file mode 100644 index 00000000000..56d8dc9e23a --- /dev/null +++ b/xarray/core/daskmanager.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable + +import numpy as np +from packaging.version import Version + +from xarray.core.duck_array_ops import dask_available +from xarray.core.indexing import ImplicitToExplicitIndexingAdapter +from xarray.core.parallelcompat import ChunkManagerEntrypoint, T_ChunkedArray +from xarray.core.pycompat import is_duck_dask_array + +if TYPE_CHECKING: + from xarray.core.types import DaskArray, T_Chunks, T_NormalizedChunks + + +class DaskManager(ChunkManagerEntrypoint["DaskArray"]): + array_cls: type[DaskArray] + available: bool = dask_available + + def __init__(self) -> None: + # TODO can we replace this with a class attribute instead? + + from dask.array import Array + + self.array_cls = Array + + def is_chunked_array(self, data: Any) -> bool: + return is_duck_dask_array(data) + + def chunks(self, data: DaskArray) -> T_NormalizedChunks: + return data.chunks + + def normalize_chunks( + self, + chunks: T_Chunks | T_NormalizedChunks, + shape: tuple[int, ...] 
| None = None, + limit: int | None = None, + dtype: np.dtype | None = None, + previous_chunks: T_NormalizedChunks | None = None, + ) -> T_NormalizedChunks: + """Called by open_dataset""" + from dask.array.core import normalize_chunks + + return normalize_chunks( + chunks, + shape=shape, + limit=limit, + dtype=dtype, + previous_chunks=previous_chunks, + ) + + def from_array(self, data: Any, chunks, **kwargs) -> DaskArray: + import dask.array as da + + if isinstance(data, ImplicitToExplicitIndexingAdapter): + # lazily loaded backend array classes should use NumPy array operations. + kwargs["meta"] = np.ndarray + + return da.from_array( + data, + chunks, + **kwargs, + ) + + def compute(self, *data: DaskArray, **kwargs) -> tuple[np.ndarray, ...]: + from dask.array import compute + + return compute(*data, **kwargs) + + @property + def array_api(self) -> Any: + from dask import array as da + + return da + + def reduction( + self, + arr: T_ChunkedArray, + func: Callable, + combine_func: Callable | None = None, + aggregate_func: Callable | None = None, + axis: int | Sequence[int] | None = None, + dtype: np.dtype | None = None, + keepdims: bool = False, + ) -> T_ChunkedArray: + from dask.array import reduction + + return reduction( + arr, + chunk=func, + combine=combine_func, + aggregate=aggregate_func, + axis=axis, + dtype=dtype, + keepdims=keepdims, + ) + + def apply_gufunc( + self, + func: Callable, + signature: str, + *args: Any, + axes: Sequence[tuple[int, ...]] | None = None, + axis: int | None = None, + keepdims: bool = False, + output_dtypes: Sequence[np.typing.DTypeLike] | None = None, + output_sizes: dict[str, int] | None = None, + vectorize: bool | None = None, + allow_rechunk: bool = False, + meta: tuple[np.ndarray, ...] | None = None, + **kwargs, + ): + from dask.array.gufunc import apply_gufunc + + return apply_gufunc( + func, + signature, + *args, + axes=axes, + axis=axis, + keepdims=keepdims, + output_dtypes=output_dtypes, + output_sizes=output_sizes, + vectorize=vectorize, + allow_rechunk=allow_rechunk, + meta=meta, + **kwargs, + ) + + def map_blocks( + self, + func: Callable, + *args: Any, + dtype: np.typing.DTypeLike | None = None, + chunks: tuple[int, ...] 
| None = None, + drop_axis: int | Sequence[int] | None = None, + new_axis: int | Sequence[int] | None = None, + **kwargs, + ): + import dask + from dask.array import map_blocks + + if drop_axis is None and Version(dask.__version__) < Version("2022.9.1"): + # See https://github.com/pydata/xarray/pull/7019#discussion_r1196729489 + # TODO remove once dask minimum version >= 2022.9.1 + drop_axis = [] + + # pass through name, meta, token as kwargs + return map_blocks( + func, + *args, + dtype=dtype, + chunks=chunks, + drop_axis=drop_axis, + new_axis=new_axis, + **kwargs, + ) + + def blockwise( + self, + func: Callable, + out_ind: Iterable, + *args: Any, + # can't type this as mypy assumes args are all same type, but dask blockwise args alternate types + name: str | None = None, + token=None, + dtype: np.dtype | None = None, + adjust_chunks: dict[Any, Callable] | None = None, + new_axes: dict[Any, int] | None = None, + align_arrays: bool = True, + concatenate: bool | None = None, + meta=None, + **kwargs, + ): + from dask.array import blockwise + + return blockwise( + func, + out_ind, + *args, + name=name, + token=token, + dtype=dtype, + adjust_chunks=adjust_chunks, + new_axes=new_axes, + align_arrays=align_arrays, + concatenate=concatenate, + meta=meta, + **kwargs, + ) + + def unify_chunks( + self, + *args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types + **kwargs, + ) -> tuple[dict[str, T_NormalizedChunks], list[DaskArray]]: + from dask.array.core import unify_chunks + + return unify_chunks(*args, **kwargs) + + def store( + self, + sources: DaskArray | Sequence[DaskArray], + targets: Any, + **kwargs, + ): + from dask.array import store + + return store( + sources=sources, + targets=targets, + **kwargs, + ) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 2f663c4936a..bc2450abc9d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -77,6 +77,7 @@ from xarray.backends import ZarrStore from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes from xarray.core.groupby import DataArrayGroupBy + from xarray.core.parallelcompat import ChunkManagerEntrypoint from xarray.core.resample import DataArrayResample from xarray.core.rolling import DataArrayCoarsen, DataArrayRolling from xarray.core.types import ( @@ -1264,6 +1265,8 @@ def chunk( token: str | None = None, lock: bool = False, inline_array: bool = False, + chunked_array_type: str | ChunkManagerEntrypoint | None = None, + from_array_kwargs=None, **chunks_kwargs: Any, ) -> T_DataArray: """Coerce this array's data into dask arrays with the given chunks. @@ -1285,12 +1288,21 @@ Prefix for the name of the new dask array. token : str, optional Token uniquely identifying this array. - lock : optional + lock : bool, default: False Passed on to :py:func:`dask.array.from_array`, if the array is not already a dask array. - inline_array: optional + inline_array: bool, default: False Passed on to :py:func:`dask.array.from_array`, if the array is not already a dask array. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system. + Experimental API that should not be relied upon.
+ from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. **chunks_kwargs : {dim: chunks, ...}, optional The keyword arguments form of ``chunks``. One of chunks or chunks_kwargs must be provided. @@ -1328,6 +1340,8 @@ def chunk( token=token, lock=lock, inline_array=inline_array, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, ) return self._from_temp_dataset(ds) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2336883d0b7..d2ecd65ba58 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -51,6 +51,7 @@ ) from xarray.core.computation import unify_chunks from xarray.core.coordinates import DatasetCoordinates, assert_coordinate_consistent +from xarray.core.daskmanager import DaskManager from xarray.core.duck_array_ops import datetime_to_numeric from xarray.core.indexes import ( Index, @@ -73,7 +74,16 @@ ) from xarray.core.missing import get_clean_interp_index from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.pycompat import array_type, is_duck_array, is_duck_dask_array +from xarray.core.parallelcompat import ( + get_chunked_array_type, + guess_chunkmanager, +) +from xarray.core.pycompat import ( + array_type, + is_chunked_array, + is_duck_array, + is_duck_dask_array, +) from xarray.core.types import QuantileMethods, T_Dataset from xarray.core.utils import ( Default, @@ -107,6 +117,7 @@ from xarray.core.dataarray import DataArray from xarray.core.groupby import DatasetGroupBy from xarray.core.merge import CoercibleMapping + from xarray.core.parallelcompat import ChunkManagerEntrypoint from xarray.core.resample import DatasetResample from xarray.core.rolling import DatasetCoarsen, DatasetRolling from xarray.core.types import ( @@ -202,13 +213,11 @@ def _assert_empty(args: tuple, msg: str = "%s") -> None: raise ValueError(msg % args) -def _get_chunk(var, chunks): +def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint): """ Return map from each dim to chunk sizes, accounting for backend's preferred chunks. """ - import dask.array as da - if isinstance(var, IndexVariable): return {} dims = var.dims @@ -225,7 +234,8 @@ def _get_chunk(var, chunks): chunks.get(dim, None) or preferred_chunk_sizes for dim, preferred_chunk_sizes in zip(dims, preferred_chunk_shape) ) - chunk_shape = da.core.normalize_chunks( + + chunk_shape = chunkmanager.normalize_chunks( chunk_shape, shape=shape, dtype=var.dtype, previous_chunks=preferred_chunk_shape ) @@ -242,7 +252,7 @@ def _get_chunk(var, chunks): # expresses the preferred chunks, the sequence sums to the size. preferred_stops = ( range(preferred_chunk_sizes, size, preferred_chunk_sizes) - if isinstance(preferred_chunk_sizes, Number) + if isinstance(preferred_chunk_sizes, int) else itertools.accumulate(preferred_chunk_sizes[:-1]) ) # Gather any stop indices of the specified chunks that are not a stop index @@ -253,7 +263,7 @@ def _get_chunk(var, chunks): ) if breaks: warnings.warn( - "The specified Dask chunks separate the stored chunks along " + "The specified chunks separate the stored chunks along " f'dimension "{dim}" starting at index {min(breaks)}. This could ' "degrade performance. 
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 2336883d0b7..d2ecd65ba58 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -51,6 +51,7 @@
 )
 from xarray.core.computation import unify_chunks
 from xarray.core.coordinates import DatasetCoordinates, assert_coordinate_consistent
+from xarray.core.daskmanager import DaskManager
 from xarray.core.duck_array_ops import datetime_to_numeric
 from xarray.core.indexes import (
     Index,
@@ -73,7 +74,16 @@
 )
 from xarray.core.missing import get_clean_interp_index
 from xarray.core.options import OPTIONS, _get_keep_attrs
-from xarray.core.pycompat import array_type, is_duck_array, is_duck_dask_array
+from xarray.core.parallelcompat import (
+    get_chunked_array_type,
+    guess_chunkmanager,
+)
+from xarray.core.pycompat import (
+    array_type,
+    is_chunked_array,
+    is_duck_array,
+    is_duck_dask_array,
+)
 from xarray.core.types import QuantileMethods, T_Dataset
 from xarray.core.utils import (
     Default,
@@ -107,6 +117,7 @@
     from xarray.core.dataarray import DataArray
     from xarray.core.groupby import DatasetGroupBy
     from xarray.core.merge import CoercibleMapping
+    from xarray.core.parallelcompat import ChunkManagerEntrypoint
     from xarray.core.resample import DatasetResample
     from xarray.core.rolling import DatasetCoarsen, DatasetRolling
     from xarray.core.types import (
@@ -202,13 +213,11 @@ def _assert_empty(args: tuple, msg: str = "%s") -> None:
         raise ValueError(msg % args)


-def _get_chunk(var, chunks):
+def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint):
     """
     Return map from each dim to chunk sizes, accounting for backend's preferred chunks.
     """
-    import dask.array as da
-
     if isinstance(var, IndexVariable):
         return {}
     dims = var.dims
@@ -225,7 +234,8 @@
         chunks.get(dim, None) or preferred_chunk_sizes
         for dim, preferred_chunk_sizes in zip(dims, preferred_chunk_shape)
     )
-    chunk_shape = da.core.normalize_chunks(
+
+    chunk_shape = chunkmanager.normalize_chunks(
         chunk_shape, shape=shape, dtype=var.dtype, previous_chunks=preferred_chunk_shape
     )

@@ -242,7 +252,7 @@
         # expresses the preferred chunks, the sequence sums to the size.
         preferred_stops = (
             range(preferred_chunk_sizes, size, preferred_chunk_sizes)
-            if isinstance(preferred_chunk_sizes, Number)
+            if isinstance(preferred_chunk_sizes, int)
             else itertools.accumulate(preferred_chunk_sizes[:-1])
         )
         # Gather any stop indices of the specified chunks that are not a stop index
@@ -253,7 +263,7 @@
         )
         if breaks:
             warnings.warn(
-                "The specified Dask chunks separate the stored chunks along "
+                "The specified chunks separate the stored chunks along "
                 f'dimension "{dim}" starting at index {min(breaks)}. This could '
                 "degrade performance. Instead, consider rechunking after loading."
             )
@@ -270,18 +280,37 @@ def _maybe_chunk(
     name_prefix="xarray-",
     overwrite_encoded_chunks=False,
     inline_array=False,
+    chunked_array_type: str | ChunkManagerEntrypoint | None = None,
+    from_array_kwargs=None,
 ):
-    from dask.base import tokenize
-
     if chunks is not None:
         chunks = {dim: chunks[dim] for dim in var.dims if dim in chunks}
+
     if var.ndim:
-        # when rechunking by different amounts, make sure dask names change
-        # by provinding chunks as an input to tokenize.
-        # subtle bugs result otherwise. see GH3350
-        token2 = tokenize(name, token if token else var._data, chunks)
-        name2 = f"{name_prefix}{name}-{token2}"
-        var = var.chunk(chunks, name=name2, lock=lock, inline_array=inline_array)
+        chunked_array_type = guess_chunkmanager(
+            chunked_array_type
+        )  # coerce string to ChunkManagerEntrypoint type
+        if isinstance(chunked_array_type, DaskManager):
+            from dask.base import tokenize
+
+            # when rechunking by different amounts, make sure dask names change
+            # by providing chunks as an input to tokenize.
+            # subtle bugs result otherwise. see GH3350
+            token2 = tokenize(name, token if token else var._data, chunks)
+            name2 = f"{name_prefix}{name}-{token2}"
+
+            from_array_kwargs = utils.consolidate_dask_from_array_kwargs(
+                from_array_kwargs,
+                name=name2,
+                lock=lock,
+                inline_array=inline_array,
+            )
+
+        var = var.chunk(
+            chunks,
+            chunked_array_type=chunked_array_type,
+            from_array_kwargs=from_array_kwargs,
+        )

         if overwrite_encoded_chunks and var.chunks is not None:
             var.encoding["chunks"] = tuple(x[0] for x in var.chunks)
@@ -743,13 +772,13 @@ def load(self: T_Dataset, **kwargs) -> T_Dataset:
         """
         # access .data to coerce everything to numpy or dask arrays
         lazy_data = {
-            k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data)
+            k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
         }
         if lazy_data:
-            import dask.array as da
+            chunkmanager = get_chunked_array_type(*lazy_data.values())

-            # evaluate all the dask arrays simultaneously
-            evaluated_data = da.compute(*lazy_data.values(), **kwargs)
+            # evaluate all the chunked arrays simultaneously
+            evaluated_data = chunkmanager.compute(*lazy_data.values(), **kwargs)

             for k, data in zip(lazy_data, evaluated_data):
                 self.variables[k].data = data
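The generalized load() resolves the manager from the array values themselves instead of assuming dask; a minimal sketch of that dispatch (assumes dask is installed):

    import dask.array as da
    from xarray.core.parallelcompat import get_chunked_array_type

    x = da.ones((4, 4), chunks=(2, 2))
    chunkmanager = get_chunked_array_type(x)  # DaskManager for dask arrays
    (loaded,) = chunkmanager.compute(x)       # equivalent to dask.compute(x) here
    print(type(loaded))                       # <class 'numpy.ndarray'>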
@@ -1575,7 +1604,7 @@ def _setitem_check(self, key, value):
                     val = np.array(val)

                 # type conversion
-                new_value[name] = val.astype(var_k.dtype, copy=False)
+                new_value[name] = duck_array_ops.astype(val, dtype=var_k.dtype, copy=False)

         # check consistency of dimension sizes and dimension coordinates
         if isinstance(value, DataArray) or isinstance(value, Dataset):
@@ -1945,6 +1974,7 @@ def to_zarr(
         safe_chunks: bool = True,
         storage_options: dict[str, str] | None = None,
         zarr_version: int | None = None,
+        chunkmanager_store_kwargs: dict[str, Any] | None = None,
     ) -> ZarrStore:
         ...

@@ -1966,6 +1996,7 @@
         safe_chunks: bool = True,
         storage_options: dict[str, str] | None = None,
         zarr_version: int | None = None,
+        chunkmanager_store_kwargs: dict[str, Any] | None = None,
     ) -> Delayed:
         ...

@@ -1984,6 +2015,7 @@
         safe_chunks: bool = True,
         storage_options: dict[str, str] | None = None,
         zarr_version: int | None = None,
+        chunkmanager_store_kwargs: dict[str, Any] | None = None,
     ) -> ZarrStore | Delayed:
         """Write dataset contents to a zarr group.
@@ -2072,6 +2104,10 @@
             The desired zarr spec version to target (currently 2 or 3). The
             default of None will attempt to determine the zarr version from
             ``store`` when possible, otherwise defaulting to 2.
+        chunkmanager_store_kwargs : dict, optional
+            Additional keyword arguments passed on to the `ChunkManager.store` method used to store
+            chunked arrays. For example, for a dask array, additional kwargs will eventually be
+            passed on to :py:func:`dask.array.store`. Experimental API that should not be relied upon.

         Returns
         -------
@@ -2117,6 +2153,7 @@
             region=region,
             safe_chunks=safe_chunks,
             zarr_version=zarr_version,
+            chunkmanager_store_kwargs=chunkmanager_store_kwargs,
         )

     def __repr__(self) -> str:
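A sketch of the new to_zarr kwarg (assumes dask and zarr are installed; the scheduler kwarg is only an example of something dask.array.store will forward on to dask's compute machinery, and the output path is hypothetical):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": (("x",), np.arange(10))}).chunk({"x": 5})
    ds.to_zarr(
        "out.zarr",
        mode="w",
        chunkmanager_store_kwargs={"scheduler": "synchronous"},
    )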
@@ -2205,6 +2242,8 @@ def chunk(
         token: str | None = None,
         lock: bool = False,
         inline_array: bool = False,
+        chunked_array_type: str | ChunkManagerEntrypoint | None = None,
+        from_array_kwargs=None,
         **chunks_kwargs: None | int | str | tuple[int, ...],
     ) -> T_Dataset:
         """Coerce all arrays in this dataset into dask arrays with the given
@@ -2232,6 +2271,15 @@
         inline_array: bool, default: False
             Passed on to :py:func:`dask.array.from_array`, if the array is not
             already as dask array.
+        chunked_array_type: str, optional
+            Which chunked array type to coerce this dataset's arrays to.
+            Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system.
+            Experimental API that should not be relied upon.
+        from_array_kwargs: dict, optional
+            Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create
+            chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg.
+            For example, with dask as the default chunked array type, this method would pass additional kwargs
+            to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
         **chunks_kwargs : {dim: chunks, ...}, optional
             The keyword arguments form of ``chunks``.
             One of chunks or chunks_kwargs must be provided
@@ -2266,8 +2314,22 @@
                 f"some chunks keys are not dimensions on this object: {bad_dims}"
             )

+        chunkmanager = guess_chunkmanager(chunked_array_type)
+        if from_array_kwargs is None:
+            from_array_kwargs = {}
+
         variables = {
-            k: _maybe_chunk(k, v, chunks, token, lock, name_prefix)
+            k: _maybe_chunk(
+                k,
+                v,
+                chunks,
+                token,
+                lock,
+                name_prefix,
+                inline_array=inline_array,
+                chunked_array_type=chunkmanager,
+                from_array_kwargs=from_array_kwargs.copy(),
+            )
             for k, v in self.variables.items()
         }
         return self._replace(variables)
@@ -2305,7 +2367,7 @@ def _validate_indexers(
             if v.dtype.kind in "US":
                 index = self._indexes[k].to_pandas_index()
                 if isinstance(index, pd.DatetimeIndex):
-                    v = v.astype("datetime64[ns]")
+                    v = duck_array_ops.astype(v, dtype="datetime64[ns]")
                 elif isinstance(index, CFTimeIndex):
                     v = _parse_array_of_cftime_strings(v, index.date_type)

diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index 84e66803fe8..4d7998e1475 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -9,6 +9,7 @@
 import datetime
 import inspect
 import warnings
+from functools import partial
 from importlib import import_module

 import numpy as np
@@ -29,10 +30,11 @@
     zeros_like,  # noqa
 )
 from numpy import concatenate as _concatenate
+from numpy.core.multiarray import normalize_axis_index  # type: ignore[attr-defined]
 from numpy.lib.stride_tricks import sliding_window_view  # noqa

 from xarray.core import dask_array_ops, dtypes, nputils
-from xarray.core.nputils import nanfirst, nanlast
+from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
 from xarray.core.pycompat import array_type, is_duck_dask_array
 from xarray.core.utils import is_duck_array, module_available

@@ -640,10 +642,10 @@ def first(values, axis, skipna=None):
     """Return the first non-NA elements in this array along the given axis"""
     if (skipna or skipna is None) and values.dtype.kind not in "iSU":
         # only bother for dtypes that can hold NaN
-        if is_duck_dask_array(values):
-            return dask_array_ops.nanfirst(values, axis)
+        if is_chunked_array(values):
+            return chunked_nanfirst(values, axis)
         else:
-            return nanfirst(values, axis)
+            return nputils.nanfirst(values, axis)
     return take(values, 0, axis=axis)


@@ -651,10 +653,10 @@ def last(values, axis, skipna=None):
     """Return the last non-NA elements in this array along the given axis"""
     if (skipna or skipna is None) and values.dtype.kind not in "iSU":
         # only bother for dtypes that can hold NaN
-        if is_duck_dask_array(values):
-            return dask_array_ops.nanlast(values, axis)
+        if is_chunked_array(values):
+            return chunked_nanlast(values, axis)
         else:
-            return nanlast(values, axis)
+            return nputils.nanlast(values, axis)
     return take(values, -1, axis=axis)


@@ -673,3 +675,32 @@ def push(array, n, axis):
         return dask_array_ops.push(array, n, axis)
     else:
         return push(array, n, axis)
+
+
+def _first_last_wrapper(array, *, axis, op, keepdims):
+    return op(array, axis, keepdims=keepdims)
+
+
+def _chunked_first_or_last(darray, axis, op):
+    chunkmanager = get_chunked_array_type(darray)
+
+    # This will raise the same error message seen for numpy
+    axis = normalize_axis_index(axis, darray.ndim)
+
+    wrapped_op = partial(_first_last_wrapper, op=op)
+    return chunkmanager.reduction(
+        darray,
+        func=wrapped_op,
+        aggregate_func=wrapped_op,
+        axis=axis,
+        dtype=darray.dtype,
+        keepdims=False,  # match numpy version
+    )
+
+
+def chunked_nanfirst(darray, axis):
+    return _chunked_first_or_last(darray, axis, op=nputils.nanfirst)
+
+
+def chunked_nanlast(darray, axis):
+    return _chunked_first_or_last(darray, axis, op=nputils.nanlast)
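For a dask-backed array the new code path is equivalent to the old dask_array_ops.nanfirst, but routed through chunkmanager.reduction; a quick sketch (assumes dask is installed and this patch applied):

    import dask.array as da
    import numpy as np
    from xarray.core.duck_array_ops import chunked_nanfirst

    arr = da.from_array(np.array([[np.nan, 1.0], [2.0, np.nan]]), chunks=(1, 2))
    print(chunked_nanfirst(arr, axis=0).compute())  # [2. 1.]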
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index 35a5261f248..acab9ccc60b 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -17,6 +17,7 @@
 from xarray.core import duck_array_ops
 from xarray.core.nputils import NumpyVIndexAdapter
 from xarray.core.options import OPTIONS
+from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
 from xarray.core.pycompat import (
     array_type,
     integer_types,
@@ -1142,16 +1143,15 @@ def _arrayize_vectorized_indexer(indexer, shape):
     return VectorizedIndexer(tuple(new_key))


-def _dask_array_with_chunks_hint(array, chunks):
-    """Create a dask array using the chunks hint for dimensions of size > 1."""
-    import dask.array as da
+def _chunked_array_with_chunks_hint(array, chunks, chunkmanager):
+    """Create a chunked array using the chunks hint for dimensions of size > 1."""

     if len(chunks) < array.ndim:
         raise ValueError("not enough chunks in hint")
     new_chunks = []
     for chunk, size in zip(chunks, array.shape):
         new_chunks.append(chunk if size > 1 else (1,))
-    return da.from_array(array, new_chunks)
+    return chunkmanager.from_array(array, new_chunks)


 def _logical_any(args):
@@ -1165,8 +1165,11 @@ def _masked_result_drop_slice(key, data=None):
     new_keys = []
     for k in key:
         if isinstance(k, np.ndarray):
-            if is_duck_dask_array(data):
-                new_keys.append(_dask_array_with_chunks_hint(k, chunks_hint))
+            if is_chunked_array(data):
+                chunkmanager = get_chunked_array_type(data)
+                new_keys.append(
+                    _chunked_array_with_chunks_hint(k, chunks_hint, chunkmanager)
+                )
             elif isinstance(data, array_type("sparse")):
                 import sparse
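The chunks hint keeps size-1 dimensions as a single chunk and reuses the hint elsewhere, so the mask's chunking lines up with the data being indexed. A sketch using the private helper directly (assumes dask as the chunk manager):

    import numpy as np
    from xarray.core.indexing import _chunked_array_with_chunks_hint
    from xarray.core.parallelcompat import guess_chunkmanager

    mask = np.zeros((1, 6), dtype=bool)
    chunked_mask = _chunked_array_with_chunks_hint(
        mask, chunks=((1,), (3, 3)), chunkmanager=guess_chunkmanager("dask")
    )
    print(chunked_mask.chunks)  # ((1,), (3, 3))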
diff --git a/xarray/core/missing.py b/xarray/core/missing.py
index d7f0be5fa08..ee610999411 100644
--- a/xarray/core/missing.py
+++ b/xarray/core/missing.py
@@ -15,7 +15,7 @@
 from xarray.core.computation import apply_ufunc
 from xarray.core.duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric
 from xarray.core.options import OPTIONS, _get_keep_attrs
-from xarray.core.pycompat import is_duck_dask_array
+from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
 from xarray.core.types import Interp1dOptions, InterpOptions
 from xarray.core.utils import OrderedSet, is_scalar
 from xarray.core.variable import Variable, broadcast_variables
@@ -693,8 +693,8 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
     else:
         func, kwargs = _get_interpolator_nd(method, **kwargs)

-    if is_duck_dask_array(var):
-        import dask.array as da
+    if is_chunked_array(var):
+        chunkmanager = get_chunked_array_type(var)

         ndim = var.ndim
         nconst = ndim - len(x)
@@ -716,7 +716,7 @@
             *new_x_arginds,
         )

-        _, rechunked = da.unify_chunks(*args)
+        _, rechunked = chunkmanager.unify_chunks(*args)

         args = tuple(elem for pair in zip(rechunked, args[1::2]) for elem in pair)

@@ -741,8 +741,8 @@

         meta = var._meta

-        return da.blockwise(
-            _dask_aware_interpnd,
+        return chunkmanager.blockwise(
+            _chunked_aware_interpnd,
             out_ind,
             *args,
             interp_func=func,
@@ -785,8 +785,8 @@ def _interpnd(var, x, new_x, func, kwargs):
     return rslt.reshape(rslt.shape[:-1] + new_x[0].shape)


-def _dask_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):
-    """Wrapper for `_interpnd` through `blockwise`
+def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):
+    """Wrapper for `_interpnd` through `blockwise` for chunked arrays.

     The first half arrays in `coords` are original coordinates,
     the other half are destination coordinates
diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py
index 022de845c4c..3b8ddfe032d 100644
--- a/xarray/core/nanops.py
+++ b/xarray/core/nanops.py
@@ -6,6 +6,7 @@
 from xarray.core import dtypes, nputils, utils
 from xarray.core.duck_array_ops import (
+    astype,
     count,
     fillna,
     isnull,
@@ -22,7 +23,7 @@ def _maybe_null_out(result, axis, mask, min_count=1):
     if axis is not None and getattr(result, "ndim", False):
         null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0
         dtype, fill_value = dtypes.maybe_promote(result.dtype)
-        result = where(null_mask, fill_value, result.astype(dtype))
+        result = where(null_mask, fill_value, astype(result, dtype))

     elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
         null_mask = mask.size - mask.sum()
@@ -140,7 +141,7 @@ def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs):
     value_mean = _nanmean_ddof_object(
         ddof=0, value=value, axis=axis, keepdims=True, **kwargs
     )
-    squared = (value.astype(value_mean.dtype) - value_mean) ** 2
+    squared = (astype(value, value_mean.dtype) - value_mean) ** 2
     return _nanmean_ddof_object(ddof, squared, axis=axis, keepdims=keepdims, **kwargs)
diff --git a/xarray/core/parallelcompat.py b/xarray/core/parallelcompat.py
new file mode 100644
index 00000000000..4df4ff235c6
--- /dev/null
+++ b/xarray/core/parallelcompat.py
@@ -0,0 +1,280 @@
+"""
+The code in this module is an experiment in going from N=1 to N=2 parallel computing frameworks in xarray.
+It could later be used as the basis for a public interface allowing any N frameworks to interoperate with xarray,
+but for now it is just a private experiment.
+"""
+from __future__ import annotations
+
+import functools
+import sys
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Sequence
+from importlib.metadata import EntryPoint, entry_points
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Generic,
+    TypeVar,
+)
+
+import numpy as np
+
+from xarray.core.pycompat import is_chunked_array
+
+T_ChunkedArray = TypeVar("T_ChunkedArray")
+
+if TYPE_CHECKING:
+    from xarray.core.types import T_Chunks, T_NormalizedChunks
+
+
+@functools.lru_cache(maxsize=1)
+def list_chunkmanagers() -> dict[str, ChunkManagerEntrypoint]:
+    """
+    Return a dictionary of available chunk managers and their ChunkManagerEntrypoint objects.
+
+    Notes
+    -----
+    # New selection mechanism introduced with Python 3.10. See GH6514.
+    """
+    if sys.version_info >= (3, 10):
+        entrypoints = entry_points(group="xarray.chunkmanagers")
+    else:
+        entrypoints = entry_points().get("xarray.chunkmanagers", ())
+
+    return load_chunkmanagers(entrypoints)
+
+
+def load_chunkmanagers(
+    entrypoints: Sequence[EntryPoint],
+) -> dict[str, ChunkManagerEntrypoint]:
+    """Load entrypoints and instantiate chunkmanagers only once."""
+
+    loaded_entrypoints = {
+        entrypoint.name: entrypoint.load() for entrypoint in entrypoints
+    }
+
+    available_chunkmanagers = {
+        name: chunkmanager()
+        for name, chunkmanager in loaded_entrypoints.items()
+        if chunkmanager.available
+    }
+    return available_chunkmanagers
+
+
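Chunk managers are discovered through the "xarray.chunkmanagers" entrypoint group, so a third-party library only needs to advertise its ChunkManagerEntrypoint subclass at install time. A hypothetical setup.py stanza for such a package (the package and module names here are illustrative, not part of this patch):

    from setuptools import setup

    setup(
        name="cubed-xarray",
        entry_points={
            "xarray.chunkmanagers": [
                "cubed = cubed_xarray.cubedmanager:CubedManager",
            ],
        },
    )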
"dask"), then use that. + Else use whatever is installed, defaulting to dask if there are multiple options. + """ + + chunkmanagers = list_chunkmanagers() + + if manager is None: + if len(chunkmanagers) == 1: + # use the only option available + manager = next(iter(chunkmanagers.keys())) + else: + # default to trying to use dask + manager = "dask" + + if isinstance(manager, str): + if manager not in chunkmanagers: + raise ValueError( + f"unrecognized chunk manager {manager} - must be one of: {list(chunkmanagers)}" + ) + + return chunkmanagers[manager] + elif isinstance(manager, ChunkManagerEntrypoint): + # already a valid ChunkManager so just pass through + return manager + else: + raise TypeError( + f"manager must be a string or instance of ChunkManagerEntrypoint, but received type {type(manager)}" + ) + + +def get_chunked_array_type(*args) -> ChunkManagerEntrypoint: + """ + Detects which parallel backend should be used for given set of arrays. + + Also checks that all arrays are of same chunking type (i.e. not a mix of cubed and dask). + """ + + # TODO this list is probably redundant with something inside xarray.apply_ufunc + ALLOWED_NON_CHUNKED_TYPES = {int, float, np.ndarray} + + chunked_arrays = [ + a + for a in args + if is_chunked_array(a) and type(a) not in ALLOWED_NON_CHUNKED_TYPES + ] + + # Asserts all arrays are the same type (or numpy etc.) + chunked_array_types = {type(a) for a in chunked_arrays} + if len(chunked_array_types) > 1: + raise TypeError( + f"Mixing chunked array types is not supported, but received multiple types: {chunked_array_types}" + ) + elif len(chunked_array_types) == 0: + raise TypeError("Expected a chunked array but none were found") + + # iterate over defined chunk managers, seeing if each recognises this array type + chunked_arr = chunked_arrays[0] + chunkmanagers = list_chunkmanagers() + selected = [ + chunkmanager + for chunkmanager in chunkmanagers.values() + if chunkmanager.is_chunked_array(chunked_arr) + ] + if not selected: + raise TypeError( + f"Could not find a Chunk Manager which recognises type {type(chunked_arr)}" + ) + elif len(selected) >= 2: + raise TypeError(f"Multiple ChunkManagers recognise type {type(chunked_arr)}") + else: + return selected[0] + + +class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]): + """ + Adapter between a particular parallel computing framework and xarray. + + Attributes + ---------- + array_cls + Type of the array class this parallel computing framework provides. + + Parallel frameworks need to provide an array class that supports the array API standard. + Used for type checking. + """ + + array_cls: type[T_ChunkedArray] + available: bool = True + + @abstractmethod + def __init__(self) -> None: + raise NotImplementedError() + + def is_chunked_array(self, data: Any) -> bool: + return isinstance(data, self.array_cls) + + @abstractmethod + def chunks(self, data: T_ChunkedArray) -> T_NormalizedChunks: + raise NotImplementedError() + + @abstractmethod + def normalize_chunks( + self, + chunks: T_Chunks | T_NormalizedChunks, + shape: tuple[int, ...] 
+class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
+    """
+    Adapter between a particular parallel computing framework and xarray.
+
+    Attributes
+    ----------
+    array_cls
+        Type of the array class this parallel computing framework provides.
+
+        Parallel frameworks need to provide an array class that supports the array API standard.
+        Used for type checking.
+    """
+
+    array_cls: type[T_ChunkedArray]
+    available: bool = True
+
+    @abstractmethod
+    def __init__(self) -> None:
+        raise NotImplementedError()
+
+    def is_chunked_array(self, data: Any) -> bool:
+        return isinstance(data, self.array_cls)
+
+    @abstractmethod
+    def chunks(self, data: T_ChunkedArray) -> T_NormalizedChunks:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def normalize_chunks(
+        self,
+        chunks: T_Chunks | T_NormalizedChunks,
+        shape: tuple[int, ...] | None = None,
+        limit: int | None = None,
+        dtype: np.dtype | None = None,
+        previous_chunks: T_NormalizedChunks | None = None,
+    ) -> T_NormalizedChunks:
+        """Called by open_dataset"""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def from_array(
+        self, data: np.ndarray, chunks: T_Chunks, **kwargs
+    ) -> T_ChunkedArray:
+        """Called when .chunk is called on an xarray object that is not already chunked."""
+        raise NotImplementedError()
+
+    def rechunk(
+        self,
+        data: T_ChunkedArray,
+        chunks: T_NormalizedChunks | tuple[int, ...] | T_Chunks,
+        **kwargs,
+    ) -> T_ChunkedArray:
+        """Called when .chunk is called on an xarray object that is already chunked."""
+        return data.rechunk(chunks, **kwargs)  # type: ignore[attr-defined]
+
+    @abstractmethod
+    def compute(self, *data: T_ChunkedArray, **kwargs) -> tuple[np.ndarray, ...]:
+        """Used anytime something needs to be computed, including multiple arrays at once."""
+        raise NotImplementedError()
+
+    @property
+    def array_api(self) -> Any:
+        """Return the array_api namespace following the python array API standard."""
+        raise NotImplementedError()
+
+    def reduction(
+        self,
+        arr: T_ChunkedArray,
+        func: Callable,
+        combine_func: Callable | None = None,
+        aggregate_func: Callable | None = None,
+        axis: int | Sequence[int] | None = None,
+        dtype: np.dtype | None = None,
+        keepdims: bool = False,
+    ) -> T_ChunkedArray:
+        """Used in some reductions like nanfirst, which is used by groupby.first"""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def apply_gufunc(
+        self,
+        func: Callable,
+        signature: str,
+        *args: Any,
+        axes: Sequence[tuple[int, ...]] | None = None,
+        keepdims: bool = False,
+        output_dtypes: Sequence[np.typing.DTypeLike] | None = None,
+        vectorize: bool | None = None,
+        **kwargs,
+    ):
+        """
+        Called inside xarray.apply_ufunc, so must be supplied for the vast majority of xarray computations to be supported.
+        """
+        raise NotImplementedError()
+
+    def map_blocks(
+        self,
+        func: Callable,
+        *args: Any,
+        dtype: np.typing.DTypeLike | None = None,
+        chunks: tuple[int, ...] | None = None,
+        drop_axis: int | Sequence[int] | None = None,
+        new_axis: int | Sequence[int] | None = None,
+        **kwargs,
+    ):
+        """Called in elementwise operations, but notably not called in xarray.map_blocks."""
+        raise NotImplementedError()
+
+    def blockwise(
+        self,
+        func: Callable,
+        out_ind: Iterable,
+        *args: Any,  # can't type this as mypy assumes args are all same type, but dask blockwise args alternate types
+        adjust_chunks: dict[Any, Callable] | None = None,
+        new_axes: dict[Any, int] | None = None,
+        align_arrays: bool = True,
+        **kwargs,
+    ):
+        """Called by some niche functions in xarray."""
+        raise NotImplementedError()
+
+    def unify_chunks(
+        self,
+        *args: Any,  # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types
+        **kwargs,
+    ) -> tuple[dict[str, T_NormalizedChunks], list[T_ChunkedArray]]:
+        """Called by xr.unify_chunks."""
+        raise NotImplementedError()
+
+    def store(
+        self,
+        sources: T_ChunkedArray | Sequence[T_ChunkedArray],
+        targets: Any,
+        **kwargs: dict[str, Any],
+    ):
+        """Used when writing to any backend."""
+        raise NotImplementedError()
diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py
index 4a3f3638d14..db77ef56fd1 100644
--- a/xarray/core/pycompat.py
+++ b/xarray/core/pycompat.py
@@ -12,7 +12,7 @@
 integer_types = (int, np.integer)

 if TYPE_CHECKING:
-    ModType = Literal["dask", "pint", "cupy", "sparse"]
+    ModType = Literal["dask", "pint", "cupy", "sparse", "cubed"]
     DuckArrayTypes = tuple[type[Any], ...]  # TODO: improve this? maybe Generic

@@ -30,7 +30,7 @@ class DuckArrayModule:
     available: bool

     def __init__(self, mod: ModType) -> None:
-        duck_array_module: ModuleType | None = None
+        duck_array_module: ModuleType | None
         duck_array_version: Version
         duck_array_type: DuckArrayTypes
         try:
@@ -45,6 +45,8 @@ def __init__(self, mod: ModType) -> None:
                 duck_array_type = (duck_array_module.ndarray,)
             elif mod == "sparse":
                 duck_array_type = (duck_array_module.SparseArray,)
+            elif mod == "cubed":
+                duck_array_type = (duck_array_module.Array,)
             else:
                 raise NotImplementedError

@@ -81,5 +83,9 @@ def is_duck_dask_array(x):
     return is_duck_array(x) and is_dask_collection(x)


+def is_chunked_array(x) -> bool:
+    return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks"))
+
+
 def is_0d_dask_array(x):
     return is_duck_dask_array(x) and is_scalar(x)
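is_chunked_array is deliberately duck-typed: anything that passes is_duck_array and exposes a .chunks attribute counts. A quick sketch (assumes dask is installed):

    import dask.array as da
    import numpy as np
    from xarray.core.pycompat import is_chunked_array

    print(is_chunked_array(da.ones((2, 2), chunks=1)))  # True
    print(is_chunked_array(np.ones((2, 2))))            # False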
diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py
index 7eb4e9c7687..916fabe42ac 100644
--- a/xarray/core/rolling.py
+++ b/xarray/core/rolling.py
@@ -158,9 +158,9 @@ def method(self, keep_attrs=None, **kwargs):
         return method

     def _mean(self, keep_attrs, **kwargs):
-        result = self.sum(keep_attrs=False, **kwargs) / self.count(
-            keep_attrs=False
-        ).astype(self.obj.dtype, copy=False)
+        result = self.sum(keep_attrs=False, **kwargs) / duck_array_ops.astype(
+            self.count(keep_attrs=False), dtype=self.obj.dtype, copy=False
+        )
         if keep_attrs:
             result.attrs = self.obj.attrs
         return result
diff --git a/xarray/core/types.py b/xarray/core/types.py
index 0f11b16b003..f3342071107 100644
--- a/xarray/core/types.py
+++ b/xarray/core/types.py
@@ -33,6 +33,16 @@
     except ImportError:
         DaskArray = np.ndarray  # type: ignore

+    try:
+        from cubed import Array as CubedArray
+    except ImportError:
+        CubedArray = np.ndarray
+
+    try:
+        from zarr.core import Array as ZarrArray
+    except ImportError:
+        ZarrArray = np.ndarray
+
     # TODO: Turn on when https://github.com/python/mypy/issues/11871 is fixed.
     # Can be uncommented if using pyright though.
     # import sys
@@ -105,6 +115,9 @@
 Dims = Union[str, Iterable[Hashable], "ellipsis", None]
 OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None]

+T_Chunks = Union[int, dict[Any, Any], Literal["auto"], None]
+T_NormalizedChunks = tuple[tuple[int, ...], ...]
+
 ErrorOptions = Literal["raise", "ignore"]
 ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"]
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index 1c90a2410f2..6ed0b2c4318 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -1202,3 +1202,66 @@ def emit_user_level_warning(message, category=None):
     """Emit a warning at the user level by inspecting the stack trace."""
     stacklevel = find_stack_level()
     warnings.warn(message, category=category, stacklevel=stacklevel)
+
+
+def consolidate_dask_from_array_kwargs(
+    from_array_kwargs: dict,
+    name: str | None = None,
+    lock: bool | None = None,
+    inline_array: bool | None = None,
+) -> dict:
+    """
+    Merge dask-specific kwargs with an arbitrary from_array_kwargs dict.
+
+    Temporary function, to be deleted once explicitly passing dask-specific kwargs to .chunk() is deprecated.
+    """
+
+    from_array_kwargs = _resolve_doubly_passed_kwarg(
+        from_array_kwargs,
+        kwarg_name="name",
+        passed_kwarg_value=name,
+        default=None,
+        err_msg_dict_name="from_array_kwargs",
+    )
+    from_array_kwargs = _resolve_doubly_passed_kwarg(
+        from_array_kwargs,
+        kwarg_name="lock",
+        passed_kwarg_value=lock,
+        default=False,
+        err_msg_dict_name="from_array_kwargs",
+    )
+    from_array_kwargs = _resolve_doubly_passed_kwarg(
+        from_array_kwargs,
+        kwarg_name="inline_array",
+        passed_kwarg_value=inline_array,
+        default=False,
+        err_msg_dict_name="from_array_kwargs",
+    )
+
+    return from_array_kwargs
+
+
+def _resolve_doubly_passed_kwarg(
+    kwargs_dict: dict,
+    kwarg_name: str,
+    passed_kwarg_value: str | bool | None,
+    default: bool | None,
+    err_msg_dict_name: str,
+) -> dict:
+    # if in kwargs_dict but not passed explicitly then just pass kwargs_dict through unaltered
+    if kwarg_name in kwargs_dict and passed_kwarg_value is None:
+        pass
+    # if passed explicitly but not in kwargs_dict then use that
+    elif kwarg_name not in kwargs_dict and passed_kwarg_value is not None:
+        kwargs_dict[kwarg_name] = passed_kwarg_value
+    # if in neither then use default
+    elif kwarg_name not in kwargs_dict and passed_kwarg_value is None:
+        kwargs_dict[kwarg_name] = default
+    # if in both then raise
+    else:
+        raise ValueError(
+            f"argument {kwarg_name} cannot be passed both as a keyword argument and within "
+            f"the {err_msg_dict_name} dictionary"
+        )
+
+    return kwargs_dict
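The resolution rules are easiest to see by example (a sketch; the values are hypothetical):

    from xarray.core.utils import consolidate_dask_from_array_kwargs

    # kwarg passed only explicitly -> folded into the dict, defaults filled in
    print(consolidate_dask_from_array_kwargs({}, name="a"))
    # {'name': 'a', 'lock': False, 'inline_array': False}

    # the same kwarg passed both ways -> ValueError
    consolidate_dask_from_array_kwargs({"name": "a"}, name="b")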
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index b4026ed325e..27a5399d639 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -26,10 +26,15 @@
     as_indexable,
 )
 from xarray.core.options import OPTIONS, _get_keep_attrs
+from xarray.core.parallelcompat import (
+    get_chunked_array_type,
+    guess_chunkmanager,
+)
 from xarray.core.pycompat import (
     array_type,
     integer_types,
     is_0d_dask_array,
+    is_chunked_array,
     is_duck_dask_array,
 )
 from xarray.core.utils import (
@@ -54,6 +59,7 @@
 BASIC_INDEXING_TYPES = integer_types + (slice,)

 if TYPE_CHECKING:
+    from xarray.core.parallelcompat import ChunkManagerEntrypoint
     from xarray.core.types import (
         Dims,
         ErrorOptionsWithWarn,
@@ -194,10 +200,10 @@ def _as_nanosecond_precision(data):
             nanosecond_precision_dtype = pd.DatetimeTZDtype("ns", dtype.tz)
         else:
             nanosecond_precision_dtype = "datetime64[ns]"
-        return data.astype(nanosecond_precision_dtype)
+        return duck_array_ops.astype(data, nanosecond_precision_dtype)
     elif dtype.kind == "m" and dtype != np.dtype("timedelta64[ns]"):
         utils.emit_user_level_warning(NON_NANOSECOND_WARNING.format(case="timedelta"))
-        return data.astype("timedelta64[ns]")
+        return duck_array_ops.astype(data, "timedelta64[ns]")
     else:
         return data

@@ -367,7 +373,7 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
         self.encoding = encoding

     @property
-    def dtype(self):
+    def dtype(self) -> np.dtype:
         """
         Data-type of the array’s elements.

@@ -379,7 +385,7 @@ def dtype(self):
         return self._data.dtype

     @property
-    def shape(self):
+    def shape(self) -> tuple[int, ...]:
         """
         Tuple of array dimensions.

@@ -532,8 +538,10 @@ def load(self, **kwargs):
         --------
         dask.array.compute
         """
-        if is_duck_dask_array(self._data):
-            self._data = as_compatible_data(self._data.compute(**kwargs))
+        if is_chunked_array(self._data):
+            chunkmanager = get_chunked_array_type(self._data)
+            loaded_data, *_ = chunkmanager.compute(self._data, **kwargs)
+            self._data = as_compatible_data(loaded_data)
         elif isinstance(self._data, indexing.ExplicitlyIndexed):
             self._data = self._data.get_duck_array()
         elif not is_duck_array(self._data):
@@ -1165,8 +1173,10 @@ def chunk(
             | Mapping[Any, None | int | tuple[int, ...]]
         ) = {},
         name: str | None = None,
-        lock: bool = False,
-        inline_array: bool = False,
+        lock: bool | None = None,
+        inline_array: bool | None = None,
+        chunked_array_type: str | ChunkManagerEntrypoint | None = None,
+        from_array_kwargs=None,
         **chunks_kwargs: Any,
     ) -> Variable:
         """Coerce this array's data into a dask array with the given chunks.
@@ -1187,12 +1197,21 @@
         name : str, optional
             Used to generate the name for this array in the internal dask
             graph. Does not need not be unique.
-        lock : optional
+        lock : bool, default: False
             Passed on to :py:func:`dask.array.from_array`, if the array is not
             already as dask array.
-        inline_array: optional
+        inline_array : bool, default: False
             Passed on to :py:func:`dask.array.from_array`, if the array is not
             already as dask array.
+        chunked_array_type: str, optional
+            Which chunked array type to coerce this array's data to.
+            Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system.
+            Experimental API that should not be relied upon.
+        from_array_kwargs: dict, optional
+            Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create
+            chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg.
+            For example, with dask as the default chunked array type, this method would pass additional kwargs
+            to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
         **chunks_kwargs : {dim: chunks, ...}, optional
             The keyword arguments form of ``chunks``.
             One of chunks or chunks_kwargs must be provided.
@@ -1208,7 +1227,6 @@
         xarray.unify_chunks
         dask.array.from_array
         """
-        import dask.array as da

         if chunks is None:
             warnings.warn(
@@ -1219,6 +1237,8 @@
             chunks = {}

         if isinstance(chunks, (float, str, int, tuple, list)):
+            # TODO we shouldn't assume here that other chunkmanagers can handle these types
+            # TODO should we call normalize_chunks here?
             pass  # dask.array.from_array can handle these directly
         else:
             chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")

         if utils.is_dict_like(chunks):
             chunks = {self.get_axis_num(dim): chunk for dim, chunk in chunks.items()}

+        chunkmanager = guess_chunkmanager(chunked_array_type)
+
+        if from_array_kwargs is None:
+            from_array_kwargs = {}
+
+        # TODO deprecate passing these dask-specific arguments explicitly. In future just pass everything via from_array_kwargs
+        _from_array_kwargs = utils.consolidate_dask_from_array_kwargs(
+            from_array_kwargs,
+            name=name,
+            lock=lock,
+            inline_array=inline_array,
+        )
+
         data = self._data
-        if is_duck_dask_array(data):
-            data = data.rechunk(chunks)
+        if chunkmanager.is_chunked_array(data):
+            data = chunkmanager.rechunk(data, chunks)  # type: ignore[arg-type]
         else:
             if isinstance(data, indexing.ExplicitlyIndexed):
                 # Unambiguously handle array storage backends (like NetCDF4 and h5py)
@@ -1243,17 +1276,13 @@
                     data, indexing.OuterIndexer
                 )

-                # All of our lazily loaded backend array classes should use NumPy
-                # array operations.
-                kwargs = {"meta": np.ndarray}
-            else:
-                kwargs = {}
-
             if utils.is_dict_like(chunks):
-                chunks = tuple(chunks.get(n, s) for n, s in enumerate(self.shape))
+                chunks = tuple(chunks.get(n, s) for n, s in enumerate(data.shape))

-            data = da.from_array(
-                data, chunks, name=name, lock=lock, inline_array=inline_array, **kwargs
+            data = chunkmanager.from_array(
+                data,
+                chunks,  # type: ignore[arg-type]
+                **_from_array_kwargs,
             )

         return self._replace(data=data)
@@ -1265,7 +1294,8 @@ def to_numpy(self) -> np.ndarray:

         # TODO first attempt to call .to_numpy() once some libraries implement it
         if hasattr(data, "chunks"):
-            data = data.compute()
+            chunkmanager = get_chunked_array_type(data)
+            data, *_ = chunkmanager.compute(data)
         if isinstance(data, array_type("cupy")):
             data = data.get()
         # pint has to be imported dynamically as pint imports xarray
@@ -2902,7 +2932,15 @@ def values(self, values):
             f"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate."
         )

-    def chunk(self, chunks={}, name=None, lock=False, inline_array=False):
+    def chunk(
+        self,
+        chunks={},
+        name=None,
+        lock=False,
+        inline_array=False,
+        chunked_array_type=None,
+        from_array_kwargs=None,
+    ):
         # Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk()
         return self.copy(deep=False)
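A sketch of the two code paths above: in-memory data goes through chunkmanager.from_array, while already-chunked data goes through chunkmanager.rechunk (assumes dask is installed):

    import numpy as np
    import xarray as xr

    v = xr.Variable(("x",), np.arange(6)).chunk({"x": 3})  # from_array path
    print(v.chunks)   # ((3, 3),)
    v2 = v.chunk({"x": 2})                                 # rechunk path
    print(v2.chunks)  # ((2, 2, 2),)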
diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py
index 904c6a4d980..e21091fad6b 100644
--- a/xarray/core/weighted.py
+++ b/xarray/core/weighted.py
@@ -238,7 +238,10 @@ def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:
         # (and not 2); GH4074
         if self.weights.dtype == bool:
             sum_of_weights = self._reduce(
-                mask, self.weights.astype(int), dim=dim, skipna=False
+                mask,
+                duck_array_ops.astype(self.weights, dtype=int),
+                dim=dim,
+                skipna=False,
             )
         else:
             sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False)

diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index 1171464a962..ed18718043b 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -904,13 +904,12 @@ def test_to_dask_dataframe_dim_order(self):

 @pytest.mark.parametrize("method", ["load", "compute"])
 def test_dask_kwargs_variable(method):
-    x = Variable("y", da.from_array(np.arange(3), chunks=(2,)))
-    # args should be passed on to da.Array.compute()
-    with mock.patch.object(
-        da.Array, "compute", return_value=np.arange(3)
-    ) as mock_compute:
+    chunked_array = da.from_array(np.arange(3), chunks=(2,))
+    x = Variable("y", chunked_array)
+    # args should be passed on to dask.compute() (via DaskManager.compute())
+    with mock.patch.object(da, "compute", return_value=(np.arange(3),)) as mock_compute:
         getattr(x, method)(foo="bar")
-    mock_compute.assert_called_with(foo="bar")
+    mock_compute.assert_called_with(chunked_array, foo="bar")


 @pytest.mark.parametrize("method", ["load", "compute", "persist"])
diff --git a/xarray/tests/test_parallelcompat.py b/xarray/tests/test_parallelcompat.py
new file mode 100644
index 00000000000..2c3378a2816
--- /dev/null
+++ b/xarray/tests/test_parallelcompat.py
@@ -0,0 +1,219 @@
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+import pytest
+
+from xarray.core.daskmanager import DaskManager
+from xarray.core.parallelcompat import (
+    ChunkManagerEntrypoint,
+    get_chunked_array_type,
+    guess_chunkmanager,
+    list_chunkmanagers,
+)
+from xarray.core.types import T_Chunks, T_NormalizedChunks
+from xarray.tests import has_dask, requires_dask
+
+
+class DummyChunkedArray(np.ndarray):
+    """
+    Mock-up of a chunked array class.
+
+    Adds a (non-functional) .chunks attribute by following this example in the numpy docs
+    https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
+    """
+
+    chunks: T_NormalizedChunks
+
+    def __new__(
+        cls,
+        shape,
+        dtype=float,
+        buffer=None,
+        offset=0,
+        strides=None,
+        order=None,
+        chunks=None,
+    ):
+        obj = super().__new__(cls, shape, dtype, buffer, offset, strides, order)
+        obj.chunks = chunks
+        return obj
+
+    def __array_finalize__(self, obj):
+        if obj is None:
+            return
+        self.chunks = getattr(obj, "chunks", None)
+
+    def rechunk(self, chunks, **kwargs):
+        copied = self.copy()
+        copied.chunks = chunks
+        return copied
+
+
+class DummyChunkManager(ChunkManagerEntrypoint):
+    """Mock-up of ChunkManager class for DummyChunkedArray"""
+
+    def __init__(self):
+        self.array_cls = DummyChunkedArray
+
+    def is_chunked_array(self, data: Any) -> bool:
+        return isinstance(data, DummyChunkedArray)
+
+    def chunks(self, data: DummyChunkedArray) -> T_NormalizedChunks:
+        return data.chunks
+
+    def normalize_chunks(
+        self,
+        chunks: T_Chunks | T_NormalizedChunks,
+        shape: tuple[int, ...] | None = None,
+        limit: int | None = None,
+        dtype: np.dtype | None = None,
+        previous_chunks: T_NormalizedChunks | None = None,
+    ) -> T_NormalizedChunks:
+        from dask.array.core import normalize_chunks
+
+        return normalize_chunks(chunks, shape, limit, dtype, previous_chunks)
+
+    def from_array(
+        self, data: np.ndarray, chunks: T_Chunks, **kwargs
+    ) -> DummyChunkedArray:
+        from dask import array as da
+
+        return da.from_array(data, chunks, **kwargs)
+
+    def rechunk(self, data: DummyChunkedArray, chunks, **kwargs) -> DummyChunkedArray:
+        return data.rechunk(chunks, **kwargs)
+
+    def compute(self, *data: DummyChunkedArray, **kwargs) -> tuple[np.ndarray, ...]:
+        from dask.array import compute
+
+        return compute(*data, **kwargs)
+
+    def apply_gufunc(
+        self,
+        func,
+        signature,
+        *args,
+        axes=None,
+        axis=None,
+        keepdims=False,
+        output_dtypes=None,
+        output_sizes=None,
+        vectorize=None,
+        allow_rechunk=False,
+        meta=None,
+        **kwargs,
+    ):
+        from dask.array.gufunc import apply_gufunc
+
+        return apply_gufunc(
+            func,
+            signature,
+            *args,
+            axes=axes,
+            axis=axis,
+            keepdims=keepdims,
+            output_dtypes=output_dtypes,
+            output_sizes=output_sizes,
+            vectorize=vectorize,
+            allow_rechunk=allow_rechunk,
+            meta=meta,
+            **kwargs,
+        )
+
+
+@pytest.fixture
+def register_dummy_chunkmanager(monkeypatch):
+    """
+    Mocks the registering of an additional ChunkManagerEntrypoint.
+
+    This preserves the presence of the existing DaskManager, so a test that relies on this and DaskManager both being
+    returned from list_chunkmanagers() at once would still work.
+
+    The monkeypatching changes the behavior of list_chunkmanagers when called inside xarray.core.parallelcompat,
+    but not when called from this tests file.
+    """
+    # Should include DaskManager iff dask is available to be imported
+    preregistered_chunkmanagers = list_chunkmanagers()
+
+    monkeypatch.setattr(
+        "xarray.core.parallelcompat.list_chunkmanagers",
+        lambda: {"dummy": DummyChunkManager()} | preregistered_chunkmanagers,
+    )
+    yield
+
+
+class TestGetChunkManager:
+    def test_get_chunkmanger(self, register_dummy_chunkmanager) -> None:
+        chunkmanager = guess_chunkmanager("dummy")
+        assert isinstance(chunkmanager, DummyChunkManager)
+
+    def test_fail_on_nonexistent_chunkmanager(self) -> None:
+        with pytest.raises(ValueError, match="unrecognized chunk manager foo"):
+            guess_chunkmanager("foo")
+
+    @requires_dask
+    def test_get_dask_if_installed(self) -> None:
+        chunkmanager = guess_chunkmanager(None)
+        assert isinstance(chunkmanager, DaskManager)
+
+    @pytest.mark.skipif(has_dask, reason="requires dask not to be installed")
+    def test_dont_get_dask_if_not_installed(self) -> None:
+        with pytest.raises(ValueError, match="unrecognized chunk manager dask"):
+            guess_chunkmanager("dask")
+
+    @requires_dask
+    def test_choose_dask_over_other_chunkmanagers(
+        self, register_dummy_chunkmanager
+    ) -> None:
+        chunk_manager = guess_chunkmanager(None)
+        assert isinstance(chunk_manager, DaskManager)
+
+
+class TestGetChunkedArrayType:
+    def test_detect_chunked_arrays(self, register_dummy_chunkmanager) -> None:
+        dummy_arr = DummyChunkedArray([1, 2, 3])
+
+        chunk_manager = get_chunked_array_type(dummy_arr)
+        assert isinstance(chunk_manager, DummyChunkManager)
+
+    def test_ignore_inmemory_arrays(self, register_dummy_chunkmanager) -> None:
+        dummy_arr = DummyChunkedArray([1, 2, 3])
+
+        chunk_manager = get_chunked_array_type(*[dummy_arr, 1.0, np.array([5, 6])])
+        assert isinstance(chunk_manager, DummyChunkManager)
+
chunked array"): + get_chunked_array_type(5.0) + + def test_raise_if_no_arrays_chunked(self, register_dummy_chunkmanager) -> None: + with pytest.raises(TypeError, match="Expected a chunked array "): + get_chunked_array_type(*[1.0, np.array([5, 6])]) + + def test_raise_if_no_matching_chunkmanagers(self) -> None: + dummy_arr = DummyChunkedArray([1, 2, 3]) + + with pytest.raises( + TypeError, match="Could not find a Chunk Manager which recognises" + ): + get_chunked_array_type(dummy_arr) + + @requires_dask + def test_detect_dask_if_installed(self) -> None: + import dask.array as da + + dask_arr = da.from_array([1, 2, 3], chunks=(1,)) + + chunk_manager = get_chunked_array_type(dask_arr) + assert isinstance(chunk_manager, DaskManager) + + @requires_dask + def test_raise_on_mixed_array_types(self, register_dummy_chunkmanager) -> None: + import dask.array as da + + dummy_arr = DummyChunkedArray([1, 2, 3]) + dask_arr = da.from_array([1, 2, 3], chunks=(1,)) + + with pytest.raises(TypeError, match="received multiple types"): + get_chunked_array_type(*[dask_arr, dummy_arr]) diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py index 0882bc1b570..441f16f4dca 100644 --- a/xarray/tests/test_plugins.py +++ b/xarray/tests/test_plugins.py @@ -236,6 +236,7 @@ def test_lazy_import() -> None: "sparse", "cupy", "pint", + "cubed", ] # ensure that none of the above modules has been imported before modules_backup = {}