Skip to content

(scipy 2022 branch) Add an "options" argument to Index.from_variables() #6800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2233,7 +2233,12 @@ def reset_index(
ds = self._to_temp_dataset().reset_index(dims_or_levels, drop=drop)
return self._from_temp_dataset(ds)

def set_xindex(self, coord_names, index_cls, **kwargs):
def set_xindex(
self: T_DataArray,
coord_names: Hashable | Sequence[Hashable],
index_cls: type[Index],
**options,
) -> T_DataArray:
"""Temporary API for creating and setting a new, custom index from
existing coordinate(s).

Expand All @@ -2244,11 +2249,11 @@ def set_xindex(self, coord_names, index_cls, **kwargs):
If several names are given, their order matters.
index_cls : class
Xarray index subclass.
**kwargs
Options passed to the index constructor. Not working for now
(not sure yet how to do it).
**options
Options forwarded to ``Dataset.set_xindex``, which passes them to the index constructor.

"""
ds = self._to_temp_dataset().set_xindex(coord_names, index_cls, **kwargs)
ds = self._to_temp_dataset().set_xindex(coord_names, index_cls, **options)
return self._from_temp_dataset(ds)

def reorder_levels(
Expand Down
47 changes: 25 additions & 22 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4020,7 +4020,7 @@ def set_index(
f"dimension mismatch: try setting an index for dimension {dim!r} with "
f"variable {var_name!r} that has dimensions {var.dims}"
)
idx = PandasIndex.from_variables({dim: var})
idx = PandasIndex.from_variables({dim: var}, {})
idx_vars = idx.create_variables({var_name: var})
else:
if append:
Expand Down Expand Up @@ -4138,7 +4138,12 @@ def reset_index(

return self._replace(variables, coord_names=coord_names, indexes=indexes)

def set_xindex(self, coord_names, index_cls, **kwargs):
def set_xindex(
self: T_Dataset,
coord_names: Hashable | Sequence[Hashable],
index_cls: type[Index],
**options,
) -> T_Dataset:
"""Temporary API for creating and setting a new, custom index from
existing coordinate(s).

Expand All @@ -4149,9 +4154,8 @@ def set_xindex(self, coord_names, index_cls, **kwargs):
If several names are given, their order matters.
index_cls : class
Xarray index subclass.
**kwargs
Options passed to the index constructor. Not working for now
(not sure yet how to do it).
**options
Options passed (as a single mapping) to the index constructor, ``index_cls.from_variables``.

"""
warnings.warn("This is temporary API to experiment with custom indexes")
Expand All @@ -4161,7 +4165,8 @@ def set_xindex(self, coord_names, index_cls, **kwargs):
f"{index_cls} is not a subclass of xarray.core.indexes.Index"
)

if isinstance(coord_names, str):
# the Sequence check is required for mypy
if is_scalar(coord_names) or not isinstance(coord_names, Sequence):
coord_names = [coord_names]

invalid_coords = set(coord_names) - self._coord_names
Expand All @@ -4180,11 +4185,11 @@ def set_xindex(self, coord_names, index_cls, **kwargs):
f"those coordinates already have an index: {indexed_coords}"
)

coord_vars = {k: self._variables[k] for k in coord_names}
coord_vars = {name: self._variables[name] for name in coord_names}

# note: extra checks (e.g., all coordinates must have the same dimension(s))
# should be done in the implementation of Index.from_variables
index = index_cls.from_variables(coord_vars)
index = index_cls.from_variables(coord_vars, options)

# in case there are index coordinate variable wrappers
# (e.g., for PandasIndex we create coordinate variables that wrap pd.Index).
Expand All @@ -4193,21 +4198,19 @@ def set_xindex(self, coord_names, index_cls, **kwargs):

# reorder variables and indexes so that coordinates having the same index
# are next to each other
variables = {}
for k, v in self._variables.items():
if k not in coord_names:
variables[k] = v

for k in coord_names:
variables[k] = new_coord_vars.get(k, self._variables[k])
variables: dict[Hashable, Variable] = {}
for name, var in self._variables.items():
if name not in coord_names:
variables[name] = var

indexes = {}
for k, v in self._indexes.items():
if k not in coord_names:
indexes[k] = v
indexes: dict[Hashable, Index] = {}
for name, idx in self._indexes.items():
if name not in coord_names:
indexes[name] = idx

for k in coord_names:
indexes[k] = index
for name in coord_names:
variables[name] = new_coord_vars.get(name, self._variables[name])
indexes[name] = index

return self._construct_direct(
variables=variables,
Expand Down Expand Up @@ -7844,7 +7847,7 @@ def pad(
# reset default index of dimension coordinates
if (name,) == var.dims:
dim_var = {name: variables[name]}
index = PandasIndex.from_variables(dim_var)
index = PandasIndex.from_variables(dim_var, {})
index_vars = index.create_variables(dim_var)
indexes[name] = index
variables[name] = index_vars[name]
Expand Down
20 changes: 16 additions & 4 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ class Index:
"""Base class inherited by all xarray-compatible indexes."""

@classmethod
def from_variables(cls, variables: Mapping[Any, Variable]) -> Index:
def from_variables(
cls,
variables: Mapping[Any, Variable],
options: Mapping[str, Any],
) -> Index:
raise NotImplementedError()

@classmethod
Expand Down Expand Up @@ -247,7 +251,11 @@ def _replace(self, index, dim=None, coord_dtype=None):
return type(self)(index, dim, coord_dtype)

@classmethod
def from_variables(cls, variables: Mapping[Any, Variable]) -> PandasIndex:
def from_variables(
cls,
variables: Mapping[Any, Variable],
options: Mapping[str, Any],
) -> PandasIndex:
if len(variables) != 1:
raise ValueError(
f"PandasIndex only accepts one variable, found {len(variables)} variables"
Expand Down Expand Up @@ -570,7 +578,11 @@ def _replace(self, index, dim=None, level_coords_dtype=None) -> PandasMultiIndex
return type(self)(index, dim, level_coords_dtype)

@classmethod
def from_variables(cls, variables: Mapping[Any, Variable]) -> PandasMultiIndex:
def from_variables(
cls,
variables: Mapping[Any, Variable],
options: Mapping[str, Any],
) -> PandasMultiIndex:
_check_dim_compat(variables)
dim = next(iter(variables.values())).dims[0]

Expand Down Expand Up @@ -995,7 +1007,7 @@ def create_default_index_implicit(
)
else:
dim_var = {name: dim_variable}
index = PandasIndex.from_variables(dim_var)
index = PandasIndex.from_variables(dim_var, {})
index_vars = index.create_variables(dim_var)

return index, index_vars
Expand Down
18 changes: 10 additions & 8 deletions xarray/tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def index(self) -> CustomIndex:

def test_from_variables(self) -> None:
with pytest.raises(NotImplementedError):
Index.from_variables({})
Index.from_variables({}, {})

def test_concat(self) -> None:
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -132,27 +132,27 @@ def test_from_variables(self) -> None:
"x", data, attrs={"unit": "m"}, encoding={"dtype": np.float64}
)

index = PandasIndex.from_variables({"x": var})
index = PandasIndex.from_variables({"x": var}, {})
assert index.dim == "x"
assert index.index.equals(pd.Index(data))
assert index.coord_dtype == data.dtype

var2 = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]])
with pytest.raises(ValueError, match=r".*only accepts one variable.*"):
PandasIndex.from_variables({"x": var, "foo": var2})
PandasIndex.from_variables({"x": var, "foo": var2}, {})

with pytest.raises(
ValueError, match=r".*only accepts a 1-dimensional variable.*"
):
PandasIndex.from_variables({"foo": var2})
PandasIndex.from_variables({"foo": var2}, {})

def test_from_variables_index_adapter(self) -> None:
# test index type is preserved when variable wraps a pd.Index
data = pd.Series(["foo", "bar"], dtype="category")
pd_idx = pd.Index(data)
var = xr.Variable("x", pd_idx)

index = PandasIndex.from_variables({"x": var})
index = PandasIndex.from_variables({"x": var}, {})
assert isinstance(index.index, pd.CategoricalIndex)

def test_concat_periods(self):
Expand Down Expand Up @@ -355,7 +355,7 @@ def test_from_variables(self) -> None:
)

index = PandasMultiIndex.from_variables(
{"level1": v_level1, "level2": v_level2}
{"level1": v_level1, "level2": v_level2}, {}
)

expected_idx = pd.MultiIndex.from_arrays([v_level1.data, v_level2.data])
Expand All @@ -368,13 +368,15 @@ def test_from_variables(self) -> None:
with pytest.raises(
ValueError, match=r".*only accepts 1-dimensional variables.*"
):
PandasMultiIndex.from_variables({"var": var})
PandasMultiIndex.from_variables({"var": var}, {})

v_level3 = xr.Variable("y", [4, 5, 6])
with pytest.raises(
ValueError, match=r"unmatched dimensions for multi-index variables.*"
):
PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3})
PandasMultiIndex.from_variables(
{"level1": v_level1, "level3": v_level3}, {}
)

def test_concat(self) -> None:
pd_midx = pd.MultiIndex.from_product(
Expand Down