diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index e71c4a6f073..22bdb0777d2 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -161,7 +161,10 @@ def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: raise NotImplementedError() def create_variables( - self, variables: Mapping[Any, Variable] | None = None + self, + variables: Mapping[Any, Variable] | None = None, + *, + fastpath=False, ) -> IndexVars: """Maybe create new coordinate variables from this index. @@ -575,13 +578,19 @@ class PandasIndex(Index): __slots__ = ("index", "dim", "coord_dtype") - def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None): - # make a shallow copy: cheap and because the index name may be updated - # here or in other constructors (cannot use pd.Index.rename as this - # constructor is also called from PandasMultiIndex) - index = safe_cast_to_index(array).copy() + def __init__( + self, array: Any, dim: Hashable, coord_dtype: Any = None, *, fastpath=False + ): + if fastpath: + index = array + else: + index = safe_cast_to_index(array) if index.name is None: + # make a shallow copy: cheap and because the index name may be updated + # here or in other constructors (cannot use pd.Index.rename as this + # constructor is also called from PandasMultiIndex) + index = index.copy() index.name = dim self.index = index @@ -596,7 +605,7 @@ def _replace(self, index, dim=None, coord_dtype=None): dim = self.dim if coord_dtype is None: coord_dtype = self.coord_dtype - return type(self)(index, dim, coord_dtype) + return type(self)(index, dim, coord_dtype, fastpath=True) @classmethod def from_variables( @@ -641,6 +650,8 @@ def from_variables( obj = cls(data, dim, coord_dtype=var.dtype) assert not isinstance(obj.index, pd.MultiIndex) + # Rename safely + obj.index = obj.index.copy() obj.index.name = name return obj @@ -684,7 +695,7 @@ def concat( return cls(new_pd_index, dim=dim, coord_dtype=coord_dtype) def create_variables( - self, variables: Mapping[Any, Variable] | None = None + self, variables: Mapping[Any, Variable] | None = None, *, fastpath=False ) -> IndexVars: from xarray.core.variable import IndexVariable @@ -701,7 +712,9 @@ def create_variables( encoding = None data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype) - var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding) + var = IndexVariable( + self.dim, data, attrs=attrs, encoding=encoding, fastpath=fastpath + ) return {name: var} def to_pandas_index(self) -> pd.Index: @@ -1122,7 +1135,7 @@ def reorder_levels( return self._replace(index, level_coords_dtype=level_coords_dtype) def create_variables( - self, variables: Mapping[Any, Variable] | None = None + self, variables: Mapping[Any, Variable] | None = None, *, fastpath=False ) -> IndexVars: from xarray.core.variable import IndexVariable @@ -1772,6 +1785,37 @@ def check_variables(): return not not_equal +def _apply_indexes_fast(indexes: Indexes[Index], args: Mapping[Any, Any], func: str): + # This function avoids the call to indexes.group_by_index + # which is really slow when repeatidly iterating through + # an array. However, it fails to return the correct ID for + # multi-index arrays + indexes_fast, coords = indexes._indexes, indexes._variables + + new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes_fast.items()} + new_index_variables: dict[Hashable, Variable] = {} + for name, index in indexes_fast.items(): + coord = coords[name] + if hasattr(coord, "_indexes"): + index_vars = {n: coords[n] for n in coord._indexes} + else: + index_vars = {name: coord} + index_dims = {d for var in index_vars.values() for d in var.dims} + index_args = {k: v for k, v in args.items() if k in index_dims} + + if index_args: + new_index = getattr(index, func)(index_args) + if new_index is not None: + new_indexes.update({k: new_index for k in index_vars}) + new_index_vars = new_index.create_variables(index_vars, fastpath=True) + new_index_variables.update(new_index_vars) + new_index_variables.update(new_index_vars) + else: + for k in index_vars: + new_indexes.pop(k, None) + return new_indexes, new_index_variables + + def _apply_indexes( indexes: Indexes[Index], args: Mapping[Any, Any], @@ -1800,7 +1844,10 @@ def isel_indexes( indexes: Indexes[Index], indexers: Mapping[Any, Any], ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: - return _apply_indexes(indexes, indexers, "isel") + if any(isinstance(v, PandasMultiIndex) for v in indexes._indexes.values()): + return _apply_indexes(indexes, indexers, "isel") + else: + return _apply_indexes_fast(indexes, indexers, "isel") def roll_indexes( diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 4909d506d7b..02f1fb262c5 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1662,10 +1662,13 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array", "_dtype") - def __init__(self, array: pd.Index, dtype: DTypeLike = None): + def __init__(self, array: pd.Index, dtype: DTypeLike = None, *, fastpath=False): from xarray.core.indexes import safe_cast_to_index - self.array = safe_cast_to_index(array) + if fastpath: + self.array = array + else: + self.array = safe_cast_to_index(array) if dtype is None: self._dtype = get_valid_numpy_dtype(array)