Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor index vs. coordinate variable(s) #5636

Merged
merged 14 commits into from
Aug 9, 2021
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ Documentation
Internal Changes
~~~~~~~~~~~~~~~~

- Explicit indexes refactor: decouple ``xarray.Index`` from ``xarray.Variable`` (:pull:`5636`).
By `Benoit Bovy <https://github.com/benbovy>`_.

.. _whats-new.0.19.0:

v0.19.0 (23 July 2021)
Expand Down
46 changes: 30 additions & 16 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pandas as pd

from . import dtypes
from .indexes import Index, PandasIndex, get_indexer_nd, wrap_pandas_index
from .indexes import Index, PandasIndex, get_indexer_nd
from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index
from .variable import IndexVariable, Variable

Expand Down Expand Up @@ -53,7 +53,10 @@ def _get_joiner(join, index_cls):
def _override_indexes(objects, all_indexes, exclude):
for dim, dim_indexes in all_indexes.items():
if dim not in exclude:
lengths = {index.size for index in dim_indexes}
lengths = {
getattr(index, "size", index.to_pandas_index().size)
for index in dim_indexes
}
if len(lengths) != 1:
raise ValueError(
f"Indexes along dimension {dim!r} don't have the same length."
Expand Down Expand Up @@ -300,16 +303,14 @@ def align(
joined_indexes = {}
for dim, matching_indexes in all_indexes.items():
if dim in indexes:
# TODO: benbovy - flexible indexes. maybe move this logic in util func
if isinstance(indexes[dim], Index):
index = indexes[dim]
else:
index = PandasIndex(safe_cast_to_index(indexes[dim]))
index, _ = PandasIndex.from_pandas_index(
safe_cast_to_index(indexes[dim]), dim
)
if (
any(not index.equals(other) for other in matching_indexes)
or dim in unlabeled_dim_sizes
):
joined_indexes[dim] = index
joined_indexes[dim] = indexes[dim]
else:
if (
any(
Expand All @@ -323,17 +324,18 @@ def align(
joiner = _get_joiner(join, type(matching_indexes[0]))
index = joiner(matching_indexes)
# make sure str coords are not cast to object
index = maybe_coerce_to_str(index, all_coords[dim])
index = maybe_coerce_to_str(index.to_pandas_index(), all_coords[dim])
joined_indexes[dim] = index
else:
index = all_coords[dim][0]

if dim in unlabeled_dim_sizes:
unlabeled_sizes = unlabeled_dim_sizes[dim]
# TODO: benbovy - flexible indexes: expose a size property for xarray.Index?
# Some indexes may not have a defined size (e.g., built from multiple coords of
# different sizes)
labeled_size = index.size
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
if isinstance(index, PandasIndex):
labeled_size = index.to_pandas_index().size
else:
labeled_size = index.size
if len(unlabeled_sizes | {labeled_size}) > 1:
raise ValueError(
f"arguments without labels along dimension {dim!r} cannot be "
Expand All @@ -350,7 +352,14 @@ def align(

result = []
for obj in objects:
valid_indexers = {k: v for k, v in joined_indexes.items() if k in obj.dims}
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
valid_indexers = {}
for k, index in joined_indexes.items():
if k in obj.dims:
if isinstance(index, Index):
valid_indexers[k] = index.to_pandas_index()
else:
valid_indexers[k] = index
if not valid_indexers:
# fast path for no reindexing necessary
new_obj = obj.copy(deep=copy)
Expand Down Expand Up @@ -471,7 +480,11 @@ def reindex_like_indexers(
ValueError
If any dimensions without labels have different sizes.
"""
indexers = {k: v for k, v in other.xindexes.items() if k in target.dims}
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
# this doesn't support yet indexes other than pd.Index
indexers = {
k: v.to_pandas_index() for k, v in other.xindexes.items() if k in target.dims
}

for dim in other.dims:
if dim not in indexers and dim in target.dims:
Expand Down Expand Up @@ -560,7 +573,8 @@ def reindex_variables(
"from that to be indexed along {:s}".format(str(indexer.dims), dim)
)

target = new_indexes[dim] = wrap_pandas_index(safe_cast_to_index(indexers[dim]))
target = safe_cast_to_index(indexers[dim])
new_indexes[dim] = PandasIndex(target, dim)

if dim in indexes:
# TODO (benbovy - flexible indexes): support other indexes than pd.Index?
Expand Down
5 changes: 2 additions & 3 deletions xarray/core/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ def _infer_concat_order_from_coords(datasets):
"inferring concatenation order"
)

# TODO (benbovy, flexible indexes): all indexes should be Pandas.Index
# get pd.Index objects from Index objects
indexes = [index.array for index in indexes]
# TODO (benbovy, flexible indexes): support flexible indexes?
indexes = [index.to_pandas_index() for index in indexes]

# If dimension coordinate values are same on every dataset then
# should be leaving this dimension alone (it's just a "bystander")
Expand Down
22 changes: 5 additions & 17 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,7 @@
)
from .dataset import Dataset, split_indexes
from .formatting import format_item
from .indexes import (
Index,
Indexes,
default_indexes,
propagate_indexes,
wrap_pandas_index,
)
from .indexes import Index, Indexes, default_indexes, propagate_indexes
from .indexing import is_fancy_indexer
from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords
from .options import OPTIONS, _get_keep_attrs
Expand Down Expand Up @@ -473,15 +467,14 @@ def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
return self
coords = self._coords.copy()
for name, idx in indexes.items():
coords[name] = IndexVariable(name, idx)
coords[name] = IndexVariable(name, idx.to_pandas_index())
obj = self._replace(coords=coords)

# switch from dimension to level names, if necessary
dim_names: Dict[Any, str] = {}
for dim, idx in indexes.items():
# TODO: benbovy - flexible indexes: update when MultiIndex has its own class
pd_idx = idx.array
if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim:
pd_idx = idx.to_pandas_index()
if not isinstance(idx, pd.MultiIndex) and pd_idx.name != dim:
dim_names[dim] = idx.name
if dim_names:
obj = obj.rename(dim_names)
Expand Down Expand Up @@ -1046,12 +1039,7 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
if self._indexes is None:
indexes = self._indexes
else:
# TODO: benbovy: flexible indexes: support all xarray indexes (not just pandas.Index)
# xarray Index needs a copy method.
indexes = {
k: wrap_pandas_index(v.to_pandas_index().copy(deep=deep))
for k, v in self._indexes.items()
}
indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()}
return self._replace(variable, coords, indexes=indexes)

def __copy__(self) -> "DataArray":
Expand Down
52 changes: 32 additions & 20 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
propagate_indexes,
remove_unused_levels_categories,
roll_index,
wrap_pandas_index,
)
from .indexing import is_fancy_indexer
from .merge import (
Expand Down Expand Up @@ -1184,7 +1183,7 @@ def _overwrite_indexes(self, indexes: Mapping[Any, Index]) -> "Dataset":
variables = self._variables.copy()
new_indexes = dict(self.xindexes)
for name, idx in indexes.items():
variables[name] = IndexVariable(name, idx)
variables[name] = IndexVariable(name, idx.to_pandas_index())
new_indexes[name] = idx
obj = self._replace(variables, indexes=new_indexes)

Expand Down Expand Up @@ -2474,6 +2473,10 @@ def sel(
pos_indexers, new_indexes = remap_label_indexers(
self, indexers=indexers, method=method, tolerance=tolerance
)
# TODO: benbovy - flexible indexes: also use variables returned by Index.query
# (temporary dirty fix).
new_indexes = {k: v[0] for k, v in new_indexes.items()}

result = self.isel(indexers=pos_indexers, drop=drop)
return result._overwrite_indexes(new_indexes)

Expand Down Expand Up @@ -3297,20 +3300,21 @@ def _rename_dims(self, name_dict):
return {name_dict.get(k, k): v for k, v in self.dims.items()}

def _rename_indexes(self, name_dict, dims_set):
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5645
if self._indexes is None:
return None
indexes = {}
for k, v in self.xindexes.items():
# TODO: benbovy - flexible indexes: make it compatible with any xarray Index
index = v.to_pandas_index()
for k, v in self.indexes.items():
new_name = name_dict.get(k, k)
if new_name not in dims_set:
continue
if isinstance(index, pd.MultiIndex):
new_names = [name_dict.get(k, k) for k in index.names]
indexes[new_name] = PandasMultiIndex(index.rename(names=new_names))
if isinstance(v, pd.MultiIndex):
new_names = [name_dict.get(k, k) for k in v.names]
indexes[new_name] = PandasMultiIndex(
v.rename(names=new_names), new_name
)
else:
indexes[new_name] = PandasIndex(index.rename(new_name))
indexes[new_name] = PandasIndex(v.rename(new_name), new_name)
return indexes

def _rename_all(self, name_dict, dims_dict):
Expand Down Expand Up @@ -3539,7 +3543,10 @@ def swap_dims(
if new_index.nlevels == 1:
# make sure index name matches dimension name
new_index = new_index.rename(k)
indexes[k] = wrap_pandas_index(new_index)
if isinstance(new_index, pd.MultiIndex):
indexes[k] = PandasMultiIndex(new_index, k)
else:
indexes[k] = PandasIndex(new_index, k)
else:
var = v.to_base_variable()
var.dims = dims
Expand Down Expand Up @@ -3812,7 +3819,7 @@ def reorder_levels(
raise ValueError(f"coordinate {dim} has no MultiIndex")
new_index = index.reorder_levels(order)
variables[dim] = IndexVariable(coord.dims, new_index)
indexes[dim] = PandasMultiIndex(new_index)
indexes[dim] = PandasMultiIndex(new_index, dim)

return self._replace(variables, indexes=indexes)

Expand Down Expand Up @@ -3840,7 +3847,7 @@ def _stack_once(self, dims, new_dim):
coord_names = set(self._coord_names) - set(dims) | {new_dim}

indexes = {k: v for k, v in self.xindexes.items() if k not in dims}
indexes[new_dim] = wrap_pandas_index(idx)
indexes[new_dim] = PandasMultiIndex(idx, new_dim)

return self._replace_with_new_dims(
variables, coord_names=coord_names, indexes=indexes
Expand Down Expand Up @@ -4029,8 +4036,9 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
variables[name] = var

for name, lev in zip(index.names, index.levels):
variables[name] = IndexVariable(name, lev)
indexes[name] = PandasIndex(lev)
idx, idx_vars = PandasIndex.from_pandas_index(lev, name)
variables[name] = idx_vars[name]
indexes[name] = idx

coord_names = set(self._coord_names) - {dim} | set(index.names)

Expand Down Expand Up @@ -4068,8 +4076,9 @@ def _unstack_full_reindex(
variables[name] = var

for name, lev in zip(new_dim_names, index.levels):
variables[name] = IndexVariable(name, lev)
indexes[name] = PandasIndex(lev)
idx, idx_vars = PandasIndex.from_pandas_index(lev, name)
variables[name] = idx_vars[name]
indexes[name] = idx

coord_names = set(self._coord_names) - {dim} | set(new_dim_names)

Expand Down Expand Up @@ -5839,10 +5848,13 @@ def diff(self, dim, n=1, label="upper"):

indexes = dict(self.xindexes)
if dim in indexes:
# TODO: benbovy - flexible indexes: check slicing of xarray indexes?
# or only allow this for pandas indexes?
index = indexes[dim].to_pandas_index()
indexes[dim] = PandasIndex(index[kwargs_new[dim]])
if isinstance(indexes[dim], PandasIndex):
# maybe optimize? (pandas index already indexed above with var.isel)
new_index = indexes[dim].index[kwargs_new[dim]]
if isinstance(new_index, pd.MultiIndex):
indexes[dim] = PandasMultiIndex(new_index, dim)
else:
indexes[dim] = PandasIndex(new_index, dim)

difference = self._replace_with_new_dims(variables, indexes=indexes)

Expand Down
Loading