diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index becf1554453..428ec63eda0 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -17,6 +17,7 @@ from xarray.core import duck_array_ops from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS +from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array from xarray.core.pycompat import array_type, integer_types, is_duck_dask_array from xarray.core.types import T_Xarray from xarray.core.utils import ( @@ -1075,16 +1076,15 @@ def _arrayize_vectorized_indexer(indexer, shape): return VectorizedIndexer(tuple(new_key)) -def _dask_array_with_chunks_hint(array, chunks): - """Create a dask array using the chunks hint for dimensions of size > 1.""" - import dask.array as da +def _chunked_array_with_chunks_hint(array, chunks, chunkmanager): + """Create a chunked array using the chunks hint for dimensions of size > 1.""" if len(chunks) < array.ndim: raise ValueError("not enough chunks in hint") new_chunks = [] for chunk, size in zip(chunks, array.shape): new_chunks.append(chunk if size > 1 else (1,)) - return da.from_array(array, new_chunks) + return chunkmanager.from_array(array, new_chunks) def _logical_any(args): @@ -1098,8 +1098,11 @@ def _masked_result_drop_slice(key, data=None): new_keys = [] for k in key: if isinstance(k, np.ndarray): - if is_duck_dask_array(data): - new_keys.append(_dask_array_with_chunks_hint(k, chunks_hint)) + if is_chunked_array(data): + chunkmanager = get_chunked_array_type(data) + new_keys.append( + _chunked_array_with_chunks_hint(k, chunks_hint, chunkmanager) + ) elif isinstance(data, array_type("sparse")): import sparse diff --git a/xarray/core/parallelcompat.py b/xarray/core/parallelcompat.py index 8ce8b9d5e8b..1aa8c9e15ec 100644 --- a/xarray/core/parallelcompat.py +++ b/xarray/core/parallelcompat.py @@ -9,7 +9,7 @@ import numpy as np -from xarray.core import indexing, utils +from xarray.core import utils from xarray.core.pycompat import DuckArrayModule, is_chunked_array, is_duck_dask_array from xarray.core.types import T_Chunks @@ -197,6 +197,8 @@ def chunks(self, data: T_DaskArray) -> T_Chunks: def from_array(self, data: np.ndarray, chunks, **kwargs) -> T_DaskArray: import dask.array as da + from xarray.core import indexing + # dask-specific kwargs name = kwargs.pop("name", None) lock = kwargs.pop("lock", False)