Micro optimizations to improve indexing (#9002)
* conda instead of mamba

* Make speedups using fastpath

* Change core logic to apply_indexes_fast

* Always have fastpath=True in one path

* Remove BasicIndexer fastpath=True

* Duplicate a comment

* Add comments

* Revert asv changes

* Avoid fastpath=True assignment

* Remove changes to BasicIndexer

* Do not use the fastpath for IndexVariable

* Remove one unnecessary change

* Remove one more fastpath

* Revert unneeded change to PandasIndexingAdapter

* Update xarray/core/indexes.py

* Update whats-new.rst

* Update whats-new.rst

* Fix whats-new

---------

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
Co-authored-by: Deepak Cherian <deepak@cherian.net>
3 people authored and andersy005 committed Jun 14, 2024
1 parent f36494e commit feb9aad
Showing 3 changed files with 65 additions and 10 deletions.
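
The changes are easiest to appreciate against the workload they target: many small ``isel`` calls, where per-call indexing overhead dominates. A minimal, hypothetical micro-benchmark (the dataset shape, variable names, and call count are illustrative, not taken from the PR):

```python
# Hypothetical micro-benchmark for repeated positional indexing; the absolute
# timing is irrelevant, only the per-call overhead that this commit trims.
import timeit

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"temperature": (("time", "x"), np.random.rand(1000, 100))},
    coords={"time": np.arange(1000), "x": np.arange(100)},
)

# Selecting one step at a time is dominated by the indexing machinery rather
# than by data movement, so it exercises the code paths touched below.
print(timeit.timeit(lambda: ds.isel(time=5), number=10_000))
```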
4 changes: 3 additions & 1 deletion doc/whats-new.rst
@@ -28,6 +28,8 @@ Performance

- Small optimization to the netCDF4 and h5netcdf backends (:issue:`9058`, :pull:`9067`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Small optimizations to help reduce the indexing overhead of datasets (:pull:`9002`).
By `Mark Harfouche <https://github.com/hmaarrfk>`_.


Breaking changes
@@ -2906,7 +2908,7 @@ Bug fixes
process (:issue:`4045`, :pull:`4684`). It also enables encoding and decoding standard
calendar dates with time units of nanoseconds (:pull:`4400`).
By `Spencer Clark <https://github.com/spencerkclark>`_ and `Mark Harfouche
<http://github.com/hmaarrfk>`_.
<https://github.com/hmaarrfk>`_.
- :py:meth:`DataArray.astype`, :py:meth:`Dataset.astype` and :py:meth:`Variable.astype` support
the ``order`` and ``subok`` parameters again. This fixes a regression introduced in version 0.16.1
(:issue:`4644`, :pull:`4683`).
66 changes: 59 additions & 7 deletions xarray/core/indexes.py
@@ -575,13 +575,24 @@ class PandasIndex(Index):

__slots__ = ("index", "dim", "coord_dtype")

def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None):
# make a shallow copy: cheap and because the index name may be updated
# here or in other constructors (cannot use pd.Index.rename as this
# constructor is also called from PandasMultiIndex)
index = safe_cast_to_index(array).copy()
def __init__(
self,
array: Any,
dim: Hashable,
coord_dtype: Any = None,
*,
fastpath: bool = False,
):
if fastpath:
index = array
else:
index = safe_cast_to_index(array)

if index.name is None:
# make a shallow copy: cheap and because the index name may be updated
# here or in other constructors (cannot use pd.Index.rename as this
# constructor is also called from PandasMultiIndex)
index = index.copy()
index.name = dim

self.index = index
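
The hunk above adds a keyword-only ``fastpath`` flag so internal callers that already hold a validated ``pd.Index`` can skip ``safe_cast_to_index``, and it defers the shallow copy to the only case that still needs one (renaming). A stand-alone sketch of the pattern, with made-up names rather than xarray's actual class:

```python
# Illustrative sketch of the fastpath pattern (not xarray's code): trust the
# caller when it already holds a pd.Index, validate otherwise.
from collections.abc import Hashable
from typing import Any

import pandas as pd


def build_index(array: Any, dim: Hashable, *, fastpath: bool = False) -> pd.Index:
    if fastpath:
        index = array  # caller guarantees this is already a pd.Index
    else:
        index = pd.Index(array)  # stands in for safe_cast_to_index(array)
    if index.name is None:
        # copy before renaming so the caller's object is left untouched
        index = index.copy()
        index.name = dim
    return index


named = pd.Index([1, 2, 3], name="x")
assert build_index(named, "x", fastpath=True) is named  # no cast, no copy
```

This is also why ``_replace`` in the next hunk can pass ``fastpath=True``: it only ever re-wraps an index that was already validated by an earlier constructor call.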
@@ -596,7 +607,7 @@ def _replace(self, index, dim=None, coord_dtype=None):
dim = self.dim
if coord_dtype is None:
coord_dtype = self.coord_dtype
return type(self)(index, dim, coord_dtype)
return type(self)(index, dim, coord_dtype, fastpath=True)

@classmethod
def from_variables(
@@ -642,6 +653,11 @@ def from_variables(

obj = cls(data, dim, coord_dtype=var.dtype)
assert not isinstance(obj.index, pd.MultiIndex)
# Rename safely
# make a shallow copy: cheap and because the index name may be updated
# here or in other constructors (cannot use pd.Index.rename as this
# constructor is also called from PandasMultiIndex)
obj.index = obj.index.copy()
obj.index.name = name

return obj
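
The copy added above matters because ``pd.Index`` objects are shared with the caller; renaming in place would leak the new name outward. A minimal illustration (plain pandas, not the xarray code path):

```python
# pd.Index.copy() is cheap here (a shallow copy; the underlying values are not
# duplicated), and renaming the copy leaves the original index untouched.
import pandas as pd

original = pd.Index([1, 2, 3])
renamed = original.copy()
renamed.name = "station"

assert original.name is None
assert renamed.name == "station"
```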
@@ -1773,6 +1789,36 @@ def check_variables():
return not not_equal


def _apply_indexes_fast(indexes: Indexes[Index], args: Mapping[Any, Any], func: str):
# This function avoids the call to indexes.group_by_index
# which is really slow when repeatedly iterating through
# an array. However, it fails to return the correct ID for
# multi-index arrays
indexes_fast, coords = indexes._indexes, indexes._variables

new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes_fast.items()}
new_index_variables: dict[Hashable, Variable] = {}
for name, index in indexes_fast.items():
coord = coords[name]
if hasattr(coord, "_indexes"):
index_vars = {n: coords[n] for n in coord._indexes}
else:
index_vars = {name: coord}
index_dims = {d for var in index_vars.values() for d in var.dims}
index_args = {k: v for k, v in args.items() if k in index_dims}

if index_args:
new_index = getattr(index, func)(index_args)
if new_index is not None:
new_indexes.update({k: new_index for k in index_vars})
new_index_vars = new_index.create_variables(index_vars)
new_index_variables.update(new_index_vars)
else:
for k in index_vars:
new_indexes.pop(k, None)
return new_indexes, new_index_variables


def _apply_indexes(
indexes: Indexes[Index],
args: Mapping[Any, Any],
@@ -1801,7 +1847,13 @@ def isel_indexes(
indexes: Indexes[Index],
indexers: Mapping[Any, Any],
) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
return _apply_indexes(indexes, indexers, "isel")
# TODO: remove if clause in the future. It should be unnecessary.
# See failure introduced when removed
# https://github.com/pydata/xarray/pull/9002#discussion_r1590443756
if any(isinstance(v, PandasMultiIndex) for v in indexes._indexes.values()):
return _apply_indexes(indexes, indexers, "isel")
else:
return _apply_indexes_fast(indexes, indexers, "isel")


def roll_indexes(
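
Taken together, ``_apply_indexes_fast`` skips the expensive ``indexes.group_by_index()`` grouping and reuses the already-known mapping from coordinate name to index, while ``isel_indexes`` only falls back to the original ``_apply_indexes`` when a ``PandasMultiIndex`` is present (the case the fast path does not yet handle correctly). A hedged illustration of which datasets take which route; the routing is internal, and the public ``isel`` result is identical either way:

```python
# Assumed example (not from the PR): a dataset backed only by plain PandasIndex
# objects should be served by _apply_indexes_fast, while a MultiIndex dimension
# triggers the fallback to the original _apply_indexes.
import numpy as np
import pandas as pd
import xarray as xr

plain = xr.Dataset(coords={"x": np.arange(4)})
midx = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["y", "z"])
stacked = xr.Dataset(coords=xr.Coordinates.from_pandas_multiindex(midx, "points"))

plain.isel(x=[0, 2])         # fast path
stacked.isel(points=[0, 3])  # grouped fallback path
```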
5 changes: 3 additions & 2 deletions xarray/core/indexing.py
@@ -297,15 +297,16 @@ def slice_slice(old_slice: slice, applied_slice: slice, size: int) -> slice:


def _index_indexer_1d(old_indexer, applied_indexer, size: int):
assert isinstance(applied_indexer, integer_types + (slice, np.ndarray))
if isinstance(applied_indexer, slice) and applied_indexer == slice(None):
# shortcut for the usual case
return old_indexer
if isinstance(old_indexer, slice):
if isinstance(applied_indexer, slice):
indexer = slice_slice(old_indexer, applied_indexer, size)
elif isinstance(applied_indexer, integer_types):
indexer = range(*old_indexer.indices(size))[applied_indexer] # type: ignore[assignment]
else:
indexer = _expand_slice(old_indexer, size)[applied_indexer] # type: ignore[assignment]
indexer = _expand_slice(old_indexer, size)[applied_indexer]
else:
indexer = old_indexer[applied_indexer]
return indexer
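
In the ``elif`` branch above, composing an integer applied indexer with the old slice through ``range(*old_indexer.indices(size))`` avoids materializing the slice as a full integer array, which is what ``_expand_slice`` does. A self-contained check of the equivalence:

```python
# Composing a slice with an integer via range() gives the same answer as
# expanding the slice into an array first, but without the allocation.
import numpy as np


def compose_slice_and_int(old: slice, i: int, size: int) -> int:
    # range(*old.indices(size)) enumerates exactly the positions the slice
    # selects from an axis of length `size`.
    return range(*old.indices(size))[i]


size = 1_000_000
old = slice(10, None, 3)
assert compose_slice_and_int(old, 5, size) == np.arange(size)[old][5]  # both 25
```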
