Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2x~5x speed up for isel() in most cases #3533

Merged
merged 9 commits into from
Dec 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ Bug fixes
- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`, :pull:`3441`)
By `Deepak Cherian <https://github.com/dcherian>`_.


Documentation
~~~~~~~~~~~~~
- Switch doc examples to use nbsphinx and replace sphinx_gallery with
Expand All @@ -58,8 +57,10 @@ Documentation

Internal Changes
~~~~~~~~~~~~~~~~


- 2x to 5x speed boost (on small arrays) for :py:meth:`Dataset.isel`,
:py:meth:`DataArray.isel`, and :py:meth:`DataArray.__getitem__` when indexing by int,
slice, list of int, scalar ndarray, or 1-dimensional ndarray.
(:pull:`3533`) by `Guido Imperiale <https://github.com/crusaderky>`_.
- Removed internal method ``Dataset._from_vars_and_coord_names``,
which was dominated by ``Dataset._construct_direct``. (:pull:`3565`)
By `Maximilian Roos <https://github.com/max-sixty>`_
Expand Down Expand Up @@ -190,6 +191,7 @@ Documentation

Internal Changes
~~~~~~~~~~~~~~~~

- Added integration tests against `pint <https://pint.readthedocs.io/>`_.
(:pull:`3238`, :pull:`3447`, :pull:`3493`, :pull:`3508`)
by `Justus Magin <https://github.com/keewis>`_.
Expand Down
2 changes: 1 addition & 1 deletion xarray/coding/cftime_offsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

import re
from datetime import timedelta
from distutils.version import LooseVersion
from functools import partial
from typing import ClassVar, Optional

Expand All @@ -50,7 +51,6 @@
from ..core.pdcompat import count_not_none
from .cftimeindex import CFTimeIndex, _parse_iso8601_with_reso
from .times import format_cftime_datetime
from distutils.version import LooseVersion


def get_date_type(calendar):
Expand Down
26 changes: 23 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
)
from .dataset import Dataset, split_indexes
from .formatting import format_item
from .indexes import Indexes, propagate_indexes, default_indexes
from .indexes import Indexes, default_indexes, propagate_indexes
from .indexing import is_fancy_indexer
from .merge import PANDAS_TYPES, _extract_indexes_from_coords
from .options import OPTIONS
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
Expand Down Expand Up @@ -1027,8 +1028,27 @@ def isel(
DataArray.sel
"""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
ds = self._to_temp_dataset().isel(drop=drop, indexers=indexers)
return self._from_temp_dataset(ds)
if any(is_fancy_indexer(idx) for idx in indexers.values()):
ds = self._to_temp_dataset()._isel_fancy(indexers, drop=drop)
return self._from_temp_dataset(ds)

# Much faster algorithm for when all indexers are ints, slices, one-dimensional
# lists, or zero or one-dimensional np.ndarray's

variable = self._variable.isel(indexers)

coords = {}
for coord_name, coord_value in self._coords.items():
coord_indexers = {
k: v for k, v in indexers.items() if k in coord_value.dims
}
if coord_indexers:
coord_value = coord_value.isel(coord_indexers)
if drop and coord_value.ndim == 0:
continue
coords[coord_name] = coord_value

return self._replace(variable=variable, coords=coords)

def sel(
self,
Expand Down
45 changes: 44 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
propagate_indexes,
roll_index,
)
from .indexing import is_fancy_indexer
from .merge import (
dataset_merge_method,
dataset_update_method,
Expand All @@ -78,8 +79,8 @@
Default,
Frozen,
SortedKeysDict,
_default,
_check_inplace,
_default,
decode_numpy_dict_values,
either_dict_or_kwargs,
hashable,
Expand Down Expand Up @@ -1907,6 +1908,48 @@ def isel(
DataArray.isel
"""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
if any(is_fancy_indexer(idx) for idx in indexers.values()):
return self._isel_fancy(indexers, drop=drop)

# Much faster algorithm for when all indexers are ints, slices, one-dimensional
# lists, or zero or one-dimensional np.ndarray's
invalid = indexers.keys() - self.dims.keys()
if invalid:
raise ValueError("dimensions %r do not exist" % invalid)

variables = {}
dims: Dict[Hashable, Tuple[int, ...]] = {}
coord_names = self._coord_names.copy()
indexes = self._indexes.copy() if self._indexes is not None else None

for var_name, var_value in self._variables.items():
var_indexers = {k: v for k, v in indexers.items() if k in var_value.dims}
if var_indexers:
var_value = var_value.isel(var_indexers)
if drop and var_value.ndim == 0 and var_name in coord_names:
coord_names.remove(var_name)
if indexes:
indexes.pop(var_name, None)
continue
if indexes and var_name in indexes:
if var_value.ndim == 1:
indexes[var_name] = var_value.to_index()
else:
del indexes[var_name]
variables[var_name] = var_value
dims.update(zip(var_value.dims, var_value.shape))

return self._construct_direct(
variables=variables,
coord_names=coord_names,
dims=dims,
attrs=self._attrs,
indexes=indexes,
encoding=self._encoding,
file_obj=self._file_obj,
)

def _isel_fancy(self, indexers: Mapping[Hashable, Any], *, drop: bool) -> "Dataset":
# Note: we need to preserve the original indexers variable in order to merge the
# coords below
indexers_list = list(self._validate_indexers(indexers))
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/formatting_html.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import uuid
import pkg_resources
from collections import OrderedDict
from functools import partial
from html import escape

from .formatting import inline_variable_array_repr, short_data_repr
import pkg_resources

from .formatting import inline_variable_array_repr, short_data_repr

CSS_FILE_PATH = "/".join(("static", "css", "style.css"))
CSS_STYLE = pkg_resources.resource_string("xarray", CSS_FILE_PATH).decode("utf8")
Expand Down
13 changes: 13 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,19 @@ def posify_mask_indexer(indexer):
return type(indexer)(key)


def is_fancy_indexer(indexer: Any) -> bool:
"""Return False if indexer is a int, slice, a 1-dimensional list, or a 0 or
1-dimensional ndarray; in all other cases return True
"""
if isinstance(indexer, (int, slice)):
return False
if isinstance(indexer, np.ndarray):
return indexer.ndim > 1
if isinstance(indexer, list):
return bool(indexer) and not isinstance(indexer[0], int)
return True


class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a NumPy array to use explicit indexing."""

Expand Down
5 changes: 4 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,10 @@ def _broadcast_indexes_outer(self, key):
k = k.data
if not isinstance(k, BASIC_INDEXING_TYPES):
k = np.asarray(k)
if k.dtype.kind == "b":
if k.size == 0:
# Slice by empty list; numpy could not infer the dtype
k = k.astype(int)
elif k.dtype.kind == "b":
(k,) = np.nonzero(k)
new_key.append(k)

Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from xarray.testing import assert_chunks_equal
from xarray.tests import mock

from ..core.duck_array_ops import lazy_array_equiv
from . import (
assert_allclose,
assert_array_equal,
Expand All @@ -25,7 +26,6 @@
raises_regex,
requires_scipy_or_netCDF4,
)
from ..core.duck_array_ops import lazy_array_equiv
from .test_backends import create_tmp_file

dask = pytest.importorskip("dask")
Expand Down
1 change: 0 additions & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from xarray.core.common import full_like
from xarray.core.indexes import propagate_indexes
from xarray.core.utils import is_scalar

from xarray.tests import (
LooseVersion,
ReturnItem,
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
NumpyInterpolator,
ScipyInterpolator,
SplineInterpolator,
get_clean_interp_index,
_get_nan_block_lengths,
get_clean_interp_index,
)
from xarray.core.pycompat import dask_array_type
from xarray.tests import (
Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,26 @@ def test_items(self):
def test_getitem_basic(self):
v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]])

# int argument
v_new = v[0]
assert v_new.dims == ("y",)
assert_array_equal(v_new, v._data[0])

# slice argument
v_new = v[:2]
assert v_new.dims == ("x", "y")
assert_array_equal(v_new, v._data[:2])

# list arguments
v_new = v[[0]]
assert v_new.dims == ("x", "y")
assert_array_equal(v_new, v._data[[0]])

v_new = v[[]]
assert v_new.dims == ("x", "y")
assert_array_equal(v_new, v._data[[]])

# dict arguments
v_new = v[dict(x=0)]
assert v_new.dims == ("y",)
assert_array_equal(v_new, v._data[0])
Expand Down Expand Up @@ -1196,6 +1216,8 @@ def test_isel(self):
assert_identical(v.isel(time=0), v[0])
assert_identical(v.isel(time=slice(0, 3)), v[:3])
assert_identical(v.isel(x=0), v[:, 0])
assert_identical(v.isel(x=[0, 2]), v[:, [0, 2]])
assert_identical(v.isel(time=[]), v[[]])
with raises_regex(ValueError, "do not exist"):
v.isel(not_a_dim=0)

Expand Down