Skip to content

Commit

Permalink
Explicitly keep track of indexes with merging (#3234)
Browse files Browse the repository at this point in the history
* Explicitly keep track of indexes in merge.py

* Typing fixes

* More tying fixes

* more typing fixes

* fixup
  • Loading branch information
shoyer authored Oct 4, 2019
1 parent 86fb71d commit dfdeef7
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 280 deletions.
26 changes: 15 additions & 11 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def align(
all_indexes[dim].append(index)

if join == "override":
objects = _override_indexes(list(objects), all_indexes, exclude)
objects = _override_indexes(objects, all_indexes, exclude)

# We don't reindex over dimensions with all equal indexes for two reasons:
# - It's faster for the usual case (already aligned objects).
Expand Down Expand Up @@ -365,26 +365,27 @@ def is_alignable(obj):
targets = []
no_key = object()
not_replaced = object()
for n, variables in enumerate(objects):
for position, variables in enumerate(objects):
if is_alignable(variables):
positions.append(n)
positions.append(position)
keys.append(no_key)
targets.append(variables)
out.append(not_replaced)
elif is_dict_like(variables):
current_out = OrderedDict()
for k, v in variables.items():
if is_alignable(v) and k not in indexes:
# Skip variables in indexes for alignment, because these
# should to be overwritten instead:
# https://github.com/pydata/xarray/issues/725
positions.append(n)
if is_alignable(v):
positions.append(position)
keys.append(k)
targets.append(v)
out.append(OrderedDict(variables))
current_out[k] = not_replaced
else:
current_out[k] = v
out.append(current_out)
elif raise_on_invalid:
raise ValueError(
"object to align is neither an xarray.Dataset, "
"an xarray.DataArray nor a dictionary: %r" % variables
"an xarray.DataArray nor a dictionary: {!r}".format(variables)
)
else:
out.append(variables)
Expand All @@ -405,7 +406,10 @@ def is_alignable(obj):
out[position][key] = aligned_obj

# something went wrong: we should have replaced all sentinel values
assert all(arg is not not_replaced for arg in out)
for arg in out:
assert arg is not not_replaced
if is_dict_like(arg):
assert all(value is not not_replaced for value in arg.values())

return out

Expand Down
39 changes: 18 additions & 21 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@

from . import duck_array_ops, utils
from .alignment import deep_align
from .merge import expand_and_merge_variables
from .merge import merge_coordinates_without_align
from .pycompat import dask_array_type
from .utils import is_dict_like
from .variable import Variable

if TYPE_CHECKING:
from .coordinates import Coordinates # noqa
from .dataset import Dataset

_DEFAULT_FROZEN_SET = frozenset() # type: frozenset
Expand Down Expand Up @@ -152,17 +153,16 @@ def result_name(objects: list) -> Any:
return name


def _get_coord_variables(args):
input_coords = []
def _get_coords_list(args) -> List["Coordinates"]:
coords_list = []
for arg in args:
try:
coords = arg.coords
except AttributeError:
pass # skip this argument
else:
coord_vars = getattr(coords, "variables", coords)
input_coords.append(coord_vars)
return input_coords
coords_list.append(coords)
return coords_list


def build_output_coords(
Expand All @@ -185,32 +185,29 @@ def build_output_coords(
-------
OrderedDict of Variable objects with merged coordinates.
"""
input_coords = _get_coord_variables(args)
coords_list = _get_coords_list(args)

if exclude_dims:
input_coords = [
OrderedDict(
(k, v) for k, v in coord_vars.items() if exclude_dims.isdisjoint(v.dims)
)
for coord_vars in input_coords
]

if len(input_coords) == 1:
if len(coords_list) == 1 and not exclude_dims:
# we can skip the expensive merge
unpacked_input_coords, = input_coords
merged = OrderedDict(unpacked_input_coords)
unpacked_coords, = coords_list
merged_vars = OrderedDict(unpacked_coords.variables)
else:
merged = expand_and_merge_variables(input_coords)
# TODO: save these merged indexes, instead of re-computing them later
merged_vars, unused_indexes = merge_coordinates_without_align(
coords_list, exclude_dims=exclude_dims
)

output_coords = []
for output_dims in signature.output_core_dims:
dropped_dims = signature.all_input_core_dims - set(output_dims)
if dropped_dims:
filtered = OrderedDict(
(k, v) for k, v in merged.items() if dropped_dims.isdisjoint(v.dims)
(k, v)
for k, v in merged_vars.items()
if dropped_dims.isdisjoint(v.dims)
)
else:
filtered = merged
filtered = merged_vars
output_coords.append(filtered)

return output_coords
Expand Down
88 changes: 54 additions & 34 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
Hashable,
Iterator,
Mapping,
Sequence,
Set,
Sequence,
Tuple,
Union,
cast,
Expand All @@ -17,11 +17,7 @@

from . import formatting, indexing
from .indexes import Indexes
from .merge import (
expand_and_merge_variables,
merge_coords,
merge_coords_for_inplace_math,
)
from .merge import merge_coords, merge_coordinates_without_align
from .utils import Frozen, ReprObject, either_dict_or_kwargs
from .variable import Variable

Expand All @@ -34,7 +30,7 @@
_THIS_ARRAY = ReprObject("<this-array>")


class AbstractCoordinates(Mapping[Hashable, "DataArray"]):
class Coordinates(Mapping[Hashable, "DataArray"]):
__slots__ = ()

def __getitem__(self, key: Hashable) -> "DataArray":
Expand All @@ -57,10 +53,10 @@ def indexes(self) -> Indexes:

@property
def variables(self):
raise NotImplementedError()
raise NotImplementedError

def _update_coords(self, coords):
raise NotImplementedError()
def _update_coords(self, coords, indexes):
raise NotImplementedError

def __iter__(self) -> Iterator["Hashable"]:
# needs to be in the same order as the dataset variables
Expand Down Expand Up @@ -116,38 +112,38 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:

def update(self, other: Mapping[Hashable, Any]) -> None:
other_vars = getattr(other, "variables", other)
coords = merge_coords(
coords, indexes = merge_coords(
[self.variables, other_vars], priority_arg=1, indexes=self.indexes
)
self._update_coords(coords)
self._update_coords(coords, indexes)

def _merge_raw(self, other):
"""For use with binary arithmetic."""
if other is None:
variables = OrderedDict(self.variables)
indexes = OrderedDict(self.indexes)
else:
# don't align because we already called xarray.align
variables = expand_and_merge_variables([self.variables, other.variables])
return variables
variables, indexes = merge_coordinates_without_align([self, other])
return variables, indexes

@contextmanager
def _merge_inplace(self, other):
"""For use with in-place binary arithmetic."""
if other is None:
yield
else:
# don't include indexes in priority_vars, because we didn't align
# first
priority_vars = OrderedDict(
kv for kv in self.variables.items() if kv[0] not in self.dims
)
variables = merge_coords_for_inplace_math(
[self.variables, other.variables], priority_vars=priority_vars
# don't include indexes in prioritized, because we didn't align
# first and we want indexes to be checked
prioritized = {
k: (v, None) for k, v in self.variables.items() if k not in self.indexes
}
variables, indexes = merge_coordinates_without_align(
[self, other], prioritized
)
yield
self._update_coords(variables)
self._update_coords(variables, indexes)

def merge(self, other: "AbstractCoordinates") -> "Dataset":
def merge(self, other: "Coordinates") -> "Dataset":
"""Merge two sets of coordinates to create a new Dataset
The method implements the logic used for joining coordinates in the
Expand All @@ -173,13 +169,19 @@ def merge(self, other: "AbstractCoordinates") -> "Dataset":

if other is None:
return self.to_dataset()
else:
other_vars = getattr(other, "variables", other)
coords = expand_and_merge_variables([self.variables, other_vars])
return Dataset._from_vars_and_coord_names(coords, set(coords))

if not isinstance(other, Coordinates):
other = Dataset(coords=other).coords

coords, indexes = merge_coordinates_without_align([self, other])
coord_names = set(coords)
merged = Dataset._construct_direct(
variables=coords, coord_names=coord_names, indexes=indexes
)
return merged


class DatasetCoordinates(AbstractCoordinates):
class DatasetCoordinates(Coordinates):
"""Dictionary like container for Dataset coordinates.
Essentially an immutable OrderedDict with keys given by the array's
Expand Down Expand Up @@ -218,7 +220,11 @@ def to_dataset(self) -> "Dataset":
"""
return self._data._copy_listed(self._names)

def _update_coords(self, coords: Mapping[Hashable, Any]) -> None:
def _update_coords(
self,
coords: "OrderedDict[Hashable, Variable]",
indexes: Mapping[Hashable, pd.Index],
) -> None:
from .dataset import calculate_dimensions

variables = self._data._variables.copy()
Expand All @@ -234,7 +240,12 @@ def _update_coords(self, coords: Mapping[Hashable, Any]) -> None:
self._data._variables = variables
self._data._coord_names.update(new_coord_names)
self._data._dims = dims
self._data._indexes = None

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = OrderedDict(self._data.indexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes

def __delitem__(self, key: Hashable) -> None:
if key in self:
Expand All @@ -251,7 +262,7 @@ def _ipython_key_completions_(self):
]


class DataArrayCoordinates(AbstractCoordinates):
class DataArrayCoordinates(Coordinates):
"""Dictionary like container for DataArray coordinates.
Essentially an OrderedDict with keys given by the array's
Expand All @@ -274,7 +285,11 @@ def _names(self) -> Set[Hashable]:
def __getitem__(self, key: Hashable) -> "DataArray":
return self._data._getitem_coord(key)

def _update_coords(self, coords) -> None:
def _update_coords(
self,
coords: "OrderedDict[Hashable, Variable]",
indexes: Mapping[Hashable, pd.Index],
) -> None:
from .dataset import calculate_dimensions

coords_plus_data = coords.copy()
Expand All @@ -285,7 +300,12 @@ def _update_coords(self, coords) -> None:
"cannot add coordinates with new dimensions to " "a DataArray"
)
self._data._coords = coords
self._data._indexes = None

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = OrderedDict(self._data.indexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes

@property
def variables(self):
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2519,7 +2519,7 @@ def func(self, other):
if not reflexive
else f(other_variable, self.variable)
)
coords = self.coords._merge_raw(other_coords)
coords, indexes = self.coords._merge_raw(other_coords)
name = self._result_name(other)

return self._replace(variable, coords, name)
Expand Down
Loading

0 comments on commit dfdeef7

Please sign in to comment.