From a551c7f05abf90a492fb59068b59ebb2bac8cb4c Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 12 Aug 2021 14:37:42 +0200 Subject: [PATCH] fix multi-index selection regression See https://github.com/pydata/xarray/issues/5691 --- xarray/core/indexes.py | 22 +++++++++++++++------- xarray/tests/test_dataarray.py | 14 ++++++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 9223631915f..f1cc7dfaed1 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -129,6 +129,15 @@ def _is_nested_tuple(possible_tuple): ) +def normalize_label(value, extract_scalar=False): + if getattr(value, "ndim", 1) <= 1: + value = _asarray_tuplesafe(value) + if extract_scalar: + # see https://github.com/pydata/xarray/pull/4292 for details + value = value[()] if value.dtype.kind in "mM" else value.item() + return value + + def get_indexer_nd(index, labels, method=None, tolerance=None): """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional labels @@ -207,14 +216,9 @@ def query(self, labels, method=None, tolerance=None): "a dimension that does not have a MultiIndex" ) else: - label = ( - label - if getattr(label, "ndim", 1) > 1 # vectorized-indexing - else _asarray_tuplesafe(label) - ) + label = normalize_label(label) if label.ndim == 0: - # see https://github.com/pydata/xarray/pull/4292 for details - label_value = label[()] if label.dtype.kind in "mM" else label.item() + label_value = normalize_label(label, extract_scalar=True) if isinstance(self.index, pd.CategoricalIndex): if method is not None: raise ValueError( @@ -336,6 +340,10 @@ def query(self, labels, method=None, tolerance=None): # label(s) given for multi-index level(s) if all([lbl in self.index.names for lbl in labels]): is_nested_vals = _is_nested_tuple(tuple(labels.values())) + labels = { + k: normalize_label(v, extract_scalar=True) for k, v in labels.items() + } + if len(labels) == self.index.nlevels and not is_nested_vals: indexer = self.index.get_loc(tuple(labels[k] for k in self.index.names)) else: diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8ab8bc872da..5205c1b59ab 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1006,6 +1006,20 @@ def test_sel_float(self): assert_equal(expected_scalar, actual_scalar) assert_equal(expected_16, actual_16) + def test_sel_float_multiindex(self): + # regression test https://github.com/pydata/xarray/issues/5691 + midx = pd.MultiIndex.from_arrays( + [["a", "a", "b", "b"], [0.1, 0.2, 0.3, 0.4]], names=["lvl1", "lvl2"] + ) + da = xr.DataArray([1, 2, 3, 4], coords={"x": midx}, dims="x") + + actual = da.sel(lvl1="a", lvl2=0.1) + expected = da.isel(x=0) + + assert_equal(actual, expected) + + # TODO: test multi-index created from coordinates, one with dtype=float32 + def test_sel_no_index(self): array = DataArray(np.arange(10), dims="x") assert_identical(array[0], array.sel(x=0))