Skip to content

Commit

Permalink
generalized chunk_hint function inside indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
TomNicholas committed Mar 14, 2023
1 parent 8bbc141 commit 6cfe9fa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
15 changes: 9 additions & 6 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from xarray.core import duck_array_ops
from xarray.core.nputils import NumpyVIndexAdapter
from xarray.core.options import OPTIONS
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
from xarray.core.pycompat import array_type, integer_types, is_duck_dask_array
from xarray.core.types import T_Xarray
from xarray.core.utils import (
Expand Down Expand Up @@ -1075,16 +1076,15 @@ def _arrayize_vectorized_indexer(indexer, shape):
return VectorizedIndexer(tuple(new_key))


def _dask_array_with_chunks_hint(array, chunks):
"""Create a dask array using the chunks hint for dimensions of size > 1."""
import dask.array as da
def _chunked_array_with_chunks_hint(array, chunks, chunkmanager):
"""Create a chunked array using the chunks hint for dimensions of size > 1."""

if len(chunks) < array.ndim:
raise ValueError("not enough chunks in hint")
new_chunks = []
for chunk, size in zip(chunks, array.shape):
new_chunks.append(chunk if size > 1 else (1,))
return da.from_array(array, new_chunks)
return chunkmanager.from_array(array, new_chunks)


def _logical_any(args):
Expand All @@ -1098,8 +1098,11 @@ def _masked_result_drop_slice(key, data=None):
new_keys = []
for k in key:
if isinstance(k, np.ndarray):
if is_duck_dask_array(data):
new_keys.append(_dask_array_with_chunks_hint(k, chunks_hint))
if is_chunked_array(data):
chunkmanager = get_chunked_array_type(data)
new_keys.append(
_chunked_array_with_chunks_hint(k, chunks_hint, chunkmanager)
)
elif isinstance(data, array_type("sparse")):
import sparse

Expand Down
4 changes: 3 additions & 1 deletion xarray/core/parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import numpy as np

from xarray.core import indexing, utils
from xarray.core import utils
from xarray.core.pycompat import DuckArrayModule, is_chunked_array, is_duck_dask_array
from xarray.core.types import T_Chunks

Expand Down Expand Up @@ -197,6 +197,8 @@ def chunks(self, data: T_DaskArray) -> T_Chunks:
def from_array(self, data: np.ndarray, chunks, **kwargs) -> T_DaskArray:
import dask.array as da

from xarray.core import indexing

# dask-specific kwargs
name = kwargs.pop("name", None)
lock = kwargs.pop("lock", False)
Expand Down

0 comments on commit 6cfe9fa

Please sign in to comment.