Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up isel and __getitem__ #3375

Merged
merged 12 commits into from
Oct 9, 2019
Merged
10 changes: 8 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,27 @@ Breaking changes

(:issue:`3222`, :issue:`3293`, :issue:`3340`, :issue:`3346`, :issue:`3358`).
By `Guido Imperiale <https://github.com/crusaderky>`_.
- Dropped the 'drop=False' optional parameter from :meth:`Variable.isel`.
It was unused and doesn't make sense for a Variable.
(:pull:`3375`) by `Guido Imperiale <https://github.com/crusaderky>`_.

New functions/methods
~~~~~~~~~~~~~~~~~~~~~

Enhancements
~~~~~~~~~~~~

- Add a repr for :py:class:`~xarray.core.GroupBy` objects (:issue:`3344`).
- Add a repr for :py:class:`~xarray.core.GroupBy` objects.
Example::

>>> da.groupby("time.season")
DataArrayGroupBy, grouped over 'season'
4 groups with labels 'DJF', 'JJA', 'MAM', 'SON'

By `Deepak Cherian <https://github.com/dcherian>`_.
(:issue:`3344`) by `Deepak Cherian <https://github.com/dcherian>`_.
- Speed up :meth:`Dataset.isel` up to 33% and :meth:`DataArray.isel` up to 25% for small
arrays (:issue:`2799`, :pull:`3375`) by
`Guido Imperiale <https://github.com/crusaderky>`_.

Bug fixes
~~~~~~~~~
Expand Down
108 changes: 59 additions & 49 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,8 +1745,8 @@ def maybe_chunk(name, var, chunks):
return self._replace(variables)

def _validate_indexers(
self, indexers: Mapping
) -> List[Tuple[Any, Union[slice, Variable]]]:
self, indexers: Mapping[Hashable, Any]
) -> Iterator[Tuple[Hashable, Union[int, slice, np.ndarray, Variable]]]:
""" Here we make sure
+ indexer has a valid keys
+ indexer is in a valid data type
Expand All @@ -1755,50 +1755,61 @@ def _validate_indexers(
"""
from .dataarray import DataArray

invalid = [k for k in indexers if k not in self.dims]
invalid = indexers.keys() - self.dims.keys()
if invalid:
raise ValueError("dimensions %r do not exist" % invalid)

# all indexers should be int, slice, np.ndarrays, or Variable
indexers_list: List[Tuple[Any, Union[slice, Variable]]] = []
for k, v in indexers.items():
if isinstance(v, slice):
indexers_list.append((k, v))
continue

if isinstance(v, Variable):
pass
if isinstance(v, (int, slice, Variable)):
yield k, v
elif isinstance(v, DataArray):
v = v.variable
yield k, v.variable
elif isinstance(v, tuple):
v = as_variable(v)
yield k, as_variable(v)
elif isinstance(v, Dataset):
raise TypeError("cannot use a Dataset as an indexer")
elif isinstance(v, Sequence) and len(v) == 0:
v = Variable((k,), np.zeros((0,), dtype="int64"))
yield k, np.empty((0,), dtype="int64")
else:
v = np.asarray(v)

if v.dtype.kind == "U" or v.dtype.kind == "S":
if v.dtype.kind in "US":
index = self.indexes[k]
if isinstance(index, pd.DatetimeIndex):
v = v.astype("datetime64[ns]")
elif isinstance(index, xr.CFTimeIndex):
v = _parse_array_of_cftime_strings(v, index.date_type)

if v.ndim == 0:
v = Variable((), v)
elif v.ndim == 1:
v = Variable((k,), v)
else:
if v.ndim > 1:
raise IndexError(
"Unlabeled multi-dimensional array cannot be "
"used for indexing: {}".format(k)
)
yield k, v
jhamman marked this conversation as resolved.
Show resolved Hide resolved

indexers_list.append((k, v))

return indexers_list
def _validate_interp_indexers(
self, indexers: Mapping[Hashable, Any]
) -> Iterator[Tuple[Hashable, Variable]]:
"""Variant of _validate_indexers to be used for interpolation
"""
for k, v in self._validate_indexers(indexers):
if isinstance(v, Variable):
if v.ndim == 1:
yield k, v.to_index_variable()
else:
yield k, v
elif isinstance(v, int):
yield k, Variable((), v)
elif isinstance(v, np.ndarray):
if v.ndim == 0:
yield k, Variable((), v)
elif v.ndim == 1:
yield k, IndexVariable((k,), v)
else:
raise AssertionError() # Already tested by _validate_indexers
else:
raise TypeError(type(v))

def _get_indexers_coords_and_indexes(self, indexers):
"""Extract coordinates and indexes from indexers.
Expand Down Expand Up @@ -1885,10 +1896,10 @@ def isel(
Dataset.sel
DataArray.isel
"""

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")

indexers_list = self._validate_indexers(indexers)
# Note: we need to preserve the original indexers variable in order to merge the
# coords below
indexers_list = list(self._validate_indexers(indexers))

variables = OrderedDict() # type: OrderedDict[Hashable, Variable]
indexes = OrderedDict() # type: OrderedDict[Hashable, pd.Index]
Expand All @@ -1904,19 +1915,21 @@ def isel(
)
if new_index is not None:
indexes[name] = new_index
else:
elif var_indexers:
new_var = var.isel(indexers=var_indexers)
else:
new_var = var.copy(deep=False)

variables[name] = new_var

coord_names = set(variables).intersection(self._coord_names)
coord_names = self._coord_names & variables.keys()
selected = self._replace_with_new_dims(variables, coord_names, indexes)

# Extract coordinates from indexers
coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(indexers)
variables.update(coord_vars)
indexes.update(new_indexes)
coord_names = set(variables).intersection(self._coord_names).union(coord_vars)
coord_names = self._coord_names & variables.keys() | coord_vars.keys()
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)

def sel(
Expand Down Expand Up @@ -2478,11 +2491,9 @@ def interp(

if kwargs is None:
kwargs = {}

coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
indexers = OrderedDict(
(k, v.to_index_variable() if isinstance(v, Variable) and v.ndim == 1 else v)
for k, v in self._validate_indexers(coords)
)
indexers = OrderedDict(self._validate_interp_indexers(coords))

obj = self if assume_sorted else self.sortby([k for k in coords])

Expand All @@ -2507,26 +2518,25 @@ def _validate_interp_indexer(x, new_x):
"strings or datetimes. "
"Instead got\n{}".format(new_x)
)
else:
return (x, new_x)
return x, new_x

variables = OrderedDict() # type: OrderedDict[Hashable, Variable]
for name, var in obj._variables.items():
if name not in indexers:
if var.dtype.kind in "uifc":
var_indexers = {
k: _validate_interp_indexer(maybe_variable(obj, k), v)
for k, v in indexers.items()
if k in var.dims
}
variables[name] = missing.interp(
var, var_indexers, method, **kwargs
)
elif all(d not in indexers for d in var.dims):
# keep unrelated object array
variables[name] = var
if name in indexers:
continue

if var.dtype.kind in "uifc":
var_indexers = {
k: _validate_interp_indexer(maybe_variable(obj, k), v)
for k, v in indexers.items()
if k in var.dims
}
variables[name] = missing.interp(var, var_indexers, method, **kwargs)
elif all(d not in indexers for d in var.dims):
# keep unrelated object array
variables[name] = var

coord_names = set(variables).intersection(obj._coord_names)
coord_names = obj._coord_names & variables.keys()
indexes = OrderedDict(
(k, v) for k, v in obj.indexes.items() if k not in indexers
)
Expand All @@ -2546,7 +2556,7 @@ def _validate_interp_indexer(x, new_x):
variables.update(coord_vars)
indexes.update(new_indexes)

coord_names = set(variables).intersection(obj._coord_names).union(coord_vars)
coord_names = obj._coord_names & variables.keys() | coord_vars.keys()
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)

def interp_like(
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict
from typing import Any, Hashable, Iterable, Mapping, Optional, Tuple, Union

import numpy as np
import pandas as pd

from . import formatting
Expand Down Expand Up @@ -63,7 +64,7 @@ def isel_variable_and_index(
name: Hashable,
variable: Variable,
index: pd.Index,
indexers: Mapping[Any, Union[slice, Variable]],
indexers: Mapping[Hashable, Union[int, slice, np.ndarray, Variable]],
) -> Tuple[Variable, Optional[pd.Index]]:
"""Index a Variable and pandas.Index together."""
if not indexers:
Expand Down
35 changes: 24 additions & 11 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import OrderedDict, defaultdict
from datetime import timedelta
from distutils.version import LooseVersion
from typing import Any, Hashable, Mapping, Union
from typing import Any, Hashable, Mapping, Union, TypeVar

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -41,6 +41,18 @@
# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore

VariableType = TypeVar("VariableType", bound="Variable")
jhamman marked this conversation as resolved.
Show resolved Hide resolved
"""Type annotation to be used when methods of Variable return self or a copy of self.
When called from an instance of a subclass, e.g. IndexVariable, mypy identifies the
output as an instance of the subclass.

Usage::

class Variable:
def f(self: VariableType, ...) -> VariableType:
...
"""


class MissingDimensionsError(ValueError):
"""Error class used when we can't safely guess a dimension name.
Expand Down Expand Up @@ -663,8 +675,8 @@ def _broadcast_indexes_vectorized(self, key):

return out_dims, VectorizedIndexer(tuple(out_key)), new_order

def __getitem__(self, key):
"""Return a new Array object whose contents are consistent with
def __getitem__(self: VariableType, key) -> VariableType:
"""Return a new Variable object whose contents are consistent with
getting the provided key from the underlying data.

NB. __getitem__ and __setitem__ implement xarray-style indexing,
Expand All @@ -682,7 +694,7 @@ def __getitem__(self, key):
data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order)
return self._finalize_indexing_result(dims, data)

def _finalize_indexing_result(self, dims, data):
def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType:
"""Used by IndexVariable to return IndexVariable objects when possible.
"""
return type(self)(dims, data, self._attrs, self._encoding, fastpath=True)
Expand Down Expand Up @@ -957,7 +969,11 @@ def chunk(self, chunks=None, name=None, lock=False):

return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True)

def isel(self, indexers=None, drop=False, **indexers_kwargs):
def isel(
self: VariableType,
indexers: Mapping[Hashable, Any] = None,
**indexers_kwargs: Any
) -> VariableType:
crusaderky marked this conversation as resolved.
Show resolved Hide resolved
"""Return a new array indexed along the specified dimension(s).

Parameters
Expand All @@ -976,15 +992,12 @@ def isel(self, indexers=None, drop=False, **indexers_kwargs):
"""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")

invalid = [k for k in indexers if k not in self.dims]
invalid = indexers.keys() - set(self.dims)
if invalid:
raise ValueError("dimensions %r do not exist" % invalid)

key = [slice(None)] * self.ndim
for i, dim in enumerate(self.dims):
if dim in indexers:
key[i] = indexers[dim]
return self[tuple(key)]
key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
return self[key]

def squeeze(self, dim=None):
"""Return a new object with squeezed data.
Expand Down