Skip to content

Allow for consistently ordered coords #4755

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
Dict,
Hashable,
Iterator,
List,
Mapping,
Sequence,
Set,
Tuple,
Union,
cast,
Expand Down Expand Up @@ -40,7 +40,7 @@ def __setitem__(self, key: Hashable, value: Any) -> None:
self.update({key: value})

@property
def _names(self) -> Set[Hashable]:
def _names(self) -> List[Hashable]:
raise NotImplementedError()

@property
Expand Down Expand Up @@ -195,7 +195,7 @@ def __init__(self, dataset: "Dataset"):
self._data = dataset

@property
def _names(self) -> Set[Hashable]:
def _names(self) -> List[Hashable]:
return self._data._coord_names

@property
Expand Down Expand Up @@ -229,13 +229,17 @@ def _update_coords(

# check for inconsistent state *before* modifying anything in-place
dims = calculate_dimensions(variables)
new_coord_names = set(coords)
for dim, size in dims.items():
if dim in variables:
new_coord_names.add(dim)
new_coord_names = list(coords)
for dim, _ in dims.items():
if dim in variables and dim not in new_coord_names:
new_coord_names.append(dim)

self._data._variables = variables
self._data._coord_names.update(new_coord_names)
self._data._coord_names += [
x for x in new_coord_names if x not in self._data._coord_names
]
# TODO: remove `sort` if possible
self._data._coord_names.sort()
self._data._dims = dims

# TODO(shoyer): once ._indexes is always populated by a dict, modify
Expand Down Expand Up @@ -276,8 +280,8 @@ def dims(self) -> Tuple[Hashable, ...]:
return self._data.dims

@property
def _names(self) -> Set[Hashable]:
return set(self._data._coords)
def _names(self) -> List[Hashable]:
return list(sorted(self._data._coords))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`sorted` always returns a list, so the outer `list(...)` call is redundant:

Suggested change
return list(sorted(self._data._coords))
return sorted(self._data._coords)


def __getitem__(self, key: Hashable) -> "DataArray":
return self._data._getitem_coord(key)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1896,7 +1896,7 @@ def reset_index(
DataArray.set_index
"""
coords, _ = split_indexes(
dims_or_levels, self._coords, set(), self._level_coords, drop=drop
dims_or_levels, self._coords, [], self._level_coords, drop=drop
)
return self._replace(coords=coords)

Expand Down
71 changes: 43 additions & 28 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ def calculate_dimensions(variables: Mapping[Hashable, Variable]) -> Dict[Hashabl
def merge_indexes(
indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]],
variables: Mapping[Hashable, Variable],
coord_names: Set[Hashable],
coord_names: List[Hashable],
append: bool = False,
) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]:
) -> Tuple[Dict[Hashable, Variable], List[Hashable]]:
"""Merge variables into multi-indexes.

Not public API. Used in Dataset and DataArray set_index
Expand Down Expand Up @@ -291,18 +291,18 @@ def merge_indexes(
if any(d in dims_to_replace for d in v.dims):
new_dims = [dims_to_replace.get(d, d) for d in v.dims]
new_variables[k] = v._replace(dims=new_dims)
new_coord_names = coord_names | set(vars_to_replace)
new_coord_names -= set(vars_to_remove)
new_coord_names = coord_names + [x for x in vars_to_replace if x not in coord_names]
new_coord_names = [x for x in new_coord_names if x not in vars_to_remove]
Comment on lines +294 to +295
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if we can come up with something more efficient (or if being efficient is really necessary), but this seems a bit wasteful: we iterate over `set(vars_to_replace) - set(coord_names)` twice. Maybe something like this:

Suggested change
new_coord_names = coord_names + [x for x in vars_to_replace if x not in coord_names]
new_coord_names = [x for x in new_coord_names if x not in vars_to_remove]
new_coord_names = [x for x in coord_names if x not in vars_to_remove]
new_coord_names += [
x
for x in vars_to_replace
if x not in coord_names and x not in vars_to_remove
]

?

return new_variables, new_coord_names


def split_indexes(
dims_or_levels: Union[Hashable, Sequence[Hashable]],
variables: Mapping[Hashable, Variable],
coord_names: Set[Hashable],
coord_names: List[Hashable],
level_coords: Mapping[Hashable, Hashable],
drop: bool = False,
) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]:
) -> Tuple[Dict[Hashable, Variable], List[Hashable]]:
"""Extract (multi-)indexes (levels) as variables.

Not public API. Used in Dataset and DataArray reset_index
Expand Down Expand Up @@ -349,7 +349,8 @@ def split_indexes(
del new_variables[v]
new_variables.update(vars_to_replace)
new_variables.update(vars_to_create)
new_coord_names = (coord_names | set(vars_to_create)) - set(vars_to_remove)
new_coord_names = coord_names + [x for x in vars_to_create if x not in coord_names]
new_coord_names = [x for x in new_coord_names if x not in vars_to_remove]
Comment on lines +352 to +353
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here:

Suggested change
new_coord_names = coord_names + [x for x in vars_to_create if x not in coord_names]
new_coord_names = [x for x in new_coord_names if x not in vars_to_remove]
new_coord_names = [x for x in coord_names if x not in vars_to_remove]
new_coord_names += [
x
for x in vars_to_create
if x not in coord_names and x not in vars_to_remove
]


return new_variables, new_coord_names

Expand Down Expand Up @@ -634,7 +635,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):

_attrs: Optional[Dict[Hashable, Any]]
_cache: Dict[str, Any]
_coord_names: Set[Hashable]
_coord_names: List[Hashable]
_dims: Dict[Hashable, int]
_encoding: Optional[Dict[Hashable, Any]]
_close: Optional[Callable[[], None]]
Expand Down Expand Up @@ -692,7 +693,8 @@ def __init__(
self._close = None
self._encoding = None
self._variables = variables
self._coord_names = coord_names
# TODO: can we remove `sorted` and let it be user-defined?
self._coord_names = sorted(coord_names)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure a user-defined order is a good idea for Dataset; see #4649. Using `sorted` would have the advantage of always being predictable.

self._dims = dims
self._indexes = indexes

Expand Down Expand Up @@ -1029,7 +1031,7 @@ def _construct_direct(
def _replace(
self,
variables: Dict[Hashable, Variable] = None,
coord_names: Set[Hashable] = None,
coord_names: List[Hashable] = None,
dims: Dict[Any, int] = None,
attrs: Union[Dict[Hashable, Any], None, Default] = _default,
indexes: Union[Dict[Any, pd.Index], None, Default] = _default,
Expand Down Expand Up @@ -1079,7 +1081,7 @@ def _replace(
def _replace_with_new_dims(
self,
variables: Dict[Hashable, Variable],
coord_names: set = None,
coord_names: List = None,
attrs: Union[Dict[Hashable, Any], None, Default] = _default,
indexes: Union[Dict[Hashable, pd.Index], None, Default] = _default,
inplace: bool = False,
Expand All @@ -1093,7 +1095,7 @@ def _replace_with_new_dims(
def _replace_vars_and_dims(
self,
variables: Dict[Hashable, Variable],
coord_names: set = None,
coord_names: List = None,
dims: Dict[Hashable, int] = None,
attrs: Union[Dict[Hashable, Any], None, Default] = _default,
inplace: bool = False,
Expand Down Expand Up @@ -1273,7 +1275,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset":
the all relevant coordinates. Skips all validation.
"""
variables: Dict[Hashable, Variable] = {}
coord_names = set()
coord_names = []
indexes: Dict[Hashable, pd.Index] = {}

for name in names:
Expand All @@ -1284,8 +1286,12 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset":
self._variables, name, self._level_coords, self.dims
)
variables[var_name] = var
if ref_name in self._coord_names or ref_name in self.dims:
coord_names.add(var_name)
if (
ref_name in self._coord_names
or ref_name in self.dims
and var_name not in coord_names
):
coord_names.append(var_name)
if (var_name,) == var.dims:
indexes[var_name] = var.to_index()

Expand All @@ -1302,7 +1308,8 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset":

if set(self.variables[k].dims) <= needed_dims:
variables[k] = self._variables[k]
coord_names.add(k)
if k not in coord_names:
coord_names.append(k)
if k in self.indexes:
indexes[k] = self.indexes[k]

Expand Down Expand Up @@ -1442,7 +1449,8 @@ def __setitem__(self, key: Hashable, value) -> None:
def __delitem__(self, key: Hashable) -> None:
"""Remove a variable from this dataset."""
del self._variables[key]
self._coord_names.discard(key)
if key in self._coord_names:
self._coord_names.remove(key)
if key in self.indexes:
assert self._indexes is not None
del self._indexes[key]
Expand Down Expand Up @@ -1563,7 +1571,7 @@ def set_coords(self, names: "Union[Hashable, Iterable[Hashable]]") -> "Dataset":
names = list(names)
self._assert_all_in_dataset(names)
obj = self.copy()
obj._coord_names.update(names)
obj._coord_names += names
return obj

def reset_coords(
Expand All @@ -1587,7 +1595,7 @@ def reset_coords(
Dataset
"""
if names is None:
names = self._coord_names - set(self.dims)
names = [x for x in self._coord_names if x not in self.dims]
else:
if isinstance(names, str) or not isinstance(names, Iterable):
names = [names]
Expand All @@ -1600,7 +1608,7 @@ def reset_coords(
"cannot remove index coordinates with reset_coords: %s" % bad_coords
)
obj = self.copy()
obj._coord_names.difference_update(names)
obj._coord_names = [x for x in obj._coord_names if x not in names]
if drop:
for name in names:
del obj._variables[name]
Expand Down Expand Up @@ -2159,7 +2167,7 @@ def _isel_fancy(

variables[name] = new_var

coord_names = self._coord_names & variables.keys()
coord_names = [x for x in self._coord_names if x in variables.keys()]
selected = self._replace_with_new_dims(variables, coord_names, indexes)

# Extract coordinates from indexers
Expand Down Expand Up @@ -2720,8 +2728,8 @@ def _reindex(
fill_value=fill_value,
sparse=sparse,
)
coord_names = set(self._coord_names)
coord_names.update(indexers)
coord_names = self._coord_names
coord_names += [x for x in indexers if x not in coord_names]
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)

def interp(
Expand Down Expand Up @@ -3238,7 +3246,7 @@ def swap_dims(
result_dims = {dims_dict.get(dim, dim) for dim in self.dims}

coord_names = self._coord_names.copy()
coord_names.update({dim for dim in dims_dict.values() if dim in self.variables})
coord_names += {dim for dim in dims_dict.values() if dim in self.variables}

variables: Dict[Hashable, Variable] = {}
indexes: Dict[Hashable, pd.Index] = {}
Expand Down Expand Up @@ -3347,7 +3355,7 @@ def expand_dims(
# value within the dim dict to the length of the iterable
# for later use.
variables[k] = xr.IndexVariable((k,), v)
coord_names.add(k)
coord_names += [k]
dim[k] = variables[k].size
elif isinstance(v, int):
pass # Do nothing if the dimensions value is just an int
Expand Down Expand Up @@ -3780,7 +3788,10 @@ def _unstack_full_reindex(
variables[name] = IndexVariable(name, lev)
indexes[name] = lev

coord_names = set(self._coord_names) - {dim} | set(new_dim_names)
coord_names = self._coord_names + [
x for x in new_dim_names if x not in self._coord_names
]
coord_names.remove(dim)

return self._replace_with_new_dims(
variables, coord_names=coord_names, indexes=indexes
Expand Down Expand Up @@ -3976,6 +3987,10 @@ def merge(
join=join,
fill_value=fill_value,
)
# FIXME: remove
assert merge_result._asdict()["coord_names"] == sorted(
merge_result._asdict()["coord_names"]
)
return self._replace(**merge_result._asdict())

def _assert_all_in_dataset(
Expand Down Expand Up @@ -4018,7 +4033,7 @@ def drop_vars(
self._assert_all_in_dataset(names)

variables = {k: v for k, v in self._variables.items() if k not in names}
coord_names = {k for k in self._coord_names if k in variables}
coord_names = [k for k in self._coord_names if k in variables]
indexes = {k: v for k, v in self.indexes.items() if k not in names}
return self._replace_with_new_dims(
variables, coord_names=coord_names, indexes=indexes
Expand Down Expand Up @@ -4723,7 +4738,7 @@ def reduce(
**kwargs,
)

coord_names = {k for k in self.coords if k in variables}
coord_names = [k for k in self.coords if k in variables]
indexes = {k: v for k, v in self.indexes.items() if k in variables}
attrs = self.attrs if keep_attrs else None
return self._replace_with_new_dims(
Expand Down
Loading