Skip to content

Commit 0ab6331

Browse files
committed
Fixes #4276
Pass 0d dask arrays through for indexing.
1 parent d5e7646 commit 0ab6331

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

xarray/core/indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def __init__(self, key):
366366
raise TypeError(
367367
f"invalid indexer array, does not have integer dtype: {k!r}"
368368
)
369-
if k.ndim != 1:
369+
if k.ndim > 1:
370370
raise TypeError(
371371
f"invalid indexer array for {type(self).__name__}; must have "
372372
f"exactly 1 dimension: {k!r}"

xarray/core/variable.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -601,11 +601,12 @@ def _broadcast_indexes(self, key):
601601
key = self._item_key_to_tuple(key) # key is a tuple
602602
# key is a tuple of full size
603603
key = indexing.expanded_indexer(key, self.ndim)
604-
# Convert a scalar Variable to an integer
604+
# Convert a scalar Variable to a 0d-array
605605
key = tuple(
606-
k.data.item() if isinstance(k, Variable) and k.ndim == 0 else k for k in key
606+
k.data if isinstance(k, Variable) and k.ndim == 0 else k for k in key
607607
)
608-
# Convert a 0d-array to an integer
608+
# Convert a 0d numpy arrays to an integer
609+
# dask 0d arrays are passed through
609610
key = tuple(
610611
k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k for k in key
611612
)
@@ -681,10 +682,11 @@ def _validate_indexers(self, key):
681682
)
682683

683684
def _broadcast_indexes_outer(self, key):
685+
# drop dim if k is integer or if k is a 0d dask array
684686
dims = tuple(
685687
k.dims[0] if isinstance(k, Variable) else dim
686688
for k, dim in zip(key, self.dims)
687-
if not isinstance(k, integer_types)
689+
if (not isinstance(k, integer_types) and k.ndim > 0)
688690
)
689691

690692
new_key = []

xarray/tests/test_indexing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,6 @@ def test_indexing_dask_array():
846846
assert_identical(actual, expected)
847847

848848

849-
@pytest.mark.xfail
850849
@requires_dask
851850
def test_indexing_dask_array_scalar():
852851
# GH4276
@@ -856,7 +855,7 @@ def test_indexing_dask_array_scalar():
856855
da = DataArray(a, dims="x")
857856
x_selector = da.argmax(dim=...)
858857
actual = da.isel(x_selector)
859-
expected = da.isel(x=1)
858+
expected = da.isel(x=-1)
860859
assert_identical(actual, expected)
861860

862861

0 commit comments

Comments
 (0)