Skip to content

Commit

Permalink
Refactor merge alignment into _align_for_merge; remove skip_single_ta…
Browse files Browse the repository at this point in the history
…rget

This is a cleaner fix for the issues from GH943

xref GH927
  • Loading branch information
shoyer committed Aug 8, 2016
1 parent 1c31332 commit b76a2f9
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 90 deletions.
63 changes: 0 additions & 63 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,10 @@ def partial_align(*objects, **kwargs):
exclude = kwargs.pop('exclude', None)
if exclude is None:
exclude = set()
skip_single_target = kwargs.pop('skip_single_target', False)
if kwargs:
raise TypeError('align() got unexpected keyword arguments: %s'
% list(kwargs))

if len(objects) == 1:
obj, = objects
if (indexes is None or
(skip_single_target and
all(obj.indexes[k].equals(v) for k, v in indexes.items()
if k in obj.indexes))):
# We don't need to align, so don't bother with reindexing, which
# fails for non-unique indexes.
# `skip_single_target` is a hack so we can skip alignment of a
# single object in merge.
return (obj.copy() if copy else obj,)

joined_indexes = _join_indexes(join, objects, exclude=exclude)
if indexes is not None:
joined_indexes.update(indexes)
Expand All @@ -126,56 +113,6 @@ def partial_align(*objects, **kwargs):
return tuple(result)


def is_alignable(obj):
return hasattr(obj, 'indexes') and hasattr(obj, 'reindex')


def deep_align(list_of_variable_maps, join='outer', copy=True, indexes=None,
skip_single_target=False):
"""Align objects, recursing into dictionary values.
"""
if indexes is None:
indexes = {}

# We use keys to identify arguments to align. Integers indicate single
# arguments, while (int, variable_name) pairs indicate variables in ordered
# dictionaries.
keys = []
out = []
targets = []
sentinel = object()
for n, variables in enumerate(list_of_variable_maps):
if is_alignable(variables):
keys.append(n)
targets.append(variables)
out.append(sentinel)
elif is_dict_like(variables):
for k, v in variables.items():
if is_alignable(v) and k not in indexes:
# don't align dict-like variables that are already fixed
# indexes: we might be overwriting these index variables
keys.append((n, k))
targets.append(v)
out.append(OrderedDict(variables))
else:
out.append(variables)

aligned = partial_align(*targets, join=join, copy=copy, indexes=indexes,
skip_single_target=skip_single_target)

for key, aligned_obj in zip(keys, aligned):
if isinstance(key, tuple):
n, k = key
out[n][k] = aligned_obj
else:
out[key] = aligned_obj

# something went wrong: we should have replaced all sentinel values
assert all(arg is not sentinel for arg in out)

return out


def reindex_variables(variables, indexes, indexers, method=None,
tolerance=None, copy=True):
"""Conform a dictionary of aligned variables onto a new set of variables,
Expand Down
15 changes: 8 additions & 7 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from . import formatting
from .utils import Frozen
from .merge import merge_coords_only, align_and_merge_coords
from .merge import merge_coords, merge_coords_without_align
from .pycompat import iteritems, basestring, OrderedDict
from .variable import default_index_coordinate

Expand Down Expand Up @@ -62,9 +62,9 @@ def to_index(self, ordered_dims=None):

def update(self, other):
other_vars = getattr(other, 'variables', other)
coords = align_and_merge_coords([self.variables, other_vars],
priority_arg=1,
indexes=self.indexes)
coords = merge_coords([self.variables, other_vars],
priority_arg=1, indexes=self.indexes,
indexes_from_arg=0)
self._update_coords(coords)

def _merge_raw(self, other):
Expand All @@ -73,7 +73,8 @@ def _merge_raw(self, other):
variables = OrderedDict(self.variables)
else:
# don't align because we already called xarray.align
variables = merge_coords_only([self.variables, other.variables])
variables = merge_coords_without_align(
[self.variables, other.variables])
return variables

@contextmanager
Expand All @@ -86,7 +87,7 @@ def _merge_inplace(self, other):
# first
priority_vars = OrderedDict(
(k, v) for k, v in self.variables.items() if k not in self.dims)
variables = merge_coords_only(
variables = merge_coords_without_align(
[self.variables, other.variables], priority_vars=priority_vars)
yield
self._update_coords(variables)
Expand Down Expand Up @@ -119,7 +120,7 @@ def merge(self, other):
return self.to_dataset()
else:
other_vars = getattr(other, 'variables', other)
coords = merge_coords_only([self.variables, other_vars])
coords = merge_coords_without_align([self.variables, other_vars])
return Dataset._from_vars_and_coord_names(coords, set(coords))


Expand Down
102 changes: 82 additions & 20 deletions xarray/core/merge.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd

from .alignment import deep_align
from .utils import Frozen
from .alignment import partial_align
from .utils import Frozen, is_dict_like
from .variable import as_variable, default_index_coordinate
from .pycompat import (basestring, OrderedDict)

Expand Down Expand Up @@ -173,18 +173,18 @@ def expand_variable_dicts(list_of_variable_dicts):
Returns
-------
A list of ordered dictionaries with all values given by those found on the
original dictionaries or coordinates, all of which are xarray.Variable
objects.
A list of ordered dictionaries corresponding to inputs, or coordinates from
an input's values. The values of each ordered dictionary are all
xarray.Variable objects.
"""
var_dicts = []

for variables in list_of_variable_dicts:
if hasattr(variables, 'variables'): # duck-type Dataset
sanitized_vars = variables.variables
else:
# append sanitized_vars before filling it up because we want coords
# to appear afterwards
# append coords to var_dicts before appending sanitized_vars,
# because we want coords to appear first
sanitized_vars = OrderedDict()

for name, var in variables.items():
Expand Down Expand Up @@ -271,7 +271,7 @@ def coerce_pandas_values(objects):
return out


def merge_coords_only(objs, priority_vars=None):
def merge_coords_without_align(objs, priority_vars=None):
"""Merge coordinate variables without worrying about alignment.
This function is used for merging variables in coordinates.py.
Expand All @@ -281,6 +281,61 @@ def merge_coords_only(objs, priority_vars=None):
return variables


def _align_for_merge(input_objects, join, copy, indexes=None,
indexes_from_arg=None):
"""Align objects for merging, recursing into dictionary values.
"""
if indexes is None:
indexes = {}

def is_alignable(obj):
return hasattr(obj, 'indexes') and hasattr(obj, 'reindex')

positions = []
keys = []
out = []
targets = []
no_key = object()
not_replaced = object()
for n, variables in enumerate(input_objects):
if is_alignable(variables):
positions.append(n)
keys.append(no_key)
targets.append(variables)
out.append(not_replaced)
else:
for k, v in variables.items():
if is_alignable(v) and k not in indexes:
# Skip variables in indexes for alignment, because these
# should to be overwritten instead:
# https://github.com/pydata/xarray/issues/725
positions.append(n)
keys.append(k)
targets.append(v)
out.append(OrderedDict(variables))

if not targets or (len(targets) == 1 and
(indexes is None or positions == [indexes_from_arg])):
# Don't align if we only have one object to align and it's already
# aligned to itself. This ensures it's possible to overwrite index
# coordinates with non-unique values, which cannot be reindexed:
# https://github.com/pydata/xarray/issues/943
return input_objects

aligned = partial_align(*targets, join=join, copy=copy, indexes=indexes)

for position, key, aligned_obj in zip(positions, keys, aligned):
if key is no_key:
out[position] = aligned_obj
else:
out[position][key] = aligned_obj

# something went wrong: we should have replaced all sentinel values
assert all(arg is not not_replaced for arg in out)

return out


def _get_priority_vars(objects, priority_arg, compat='equals'):
"""Extract the priority variable from a list of mappings.
Expand Down Expand Up @@ -310,13 +365,18 @@ def _get_priority_vars(objects, priority_arg, compat='equals'):
return priority_vars


def align_and_merge_coords(objs, compat='minimal', join='outer',
priority_arg=None, indexes=None):
"""Align and merge coordinate variables."""
def merge_coords(objs, compat='minimal', join='outer', priority_arg=None,
indexes=None, indexes_from_arg=False):
"""Merge coordinate variables.
See merge_core below for argument descriptions. This works similarly to
merge_core, except everything we don't worry about whether variables are
coordinates or not.
"""
_assert_compat_valid(compat)
coerced = coerce_pandas_values(objs)
aligned = deep_align(coerced, join=join, copy=False, indexes=indexes,
skip_single_target=True)
aligned = _align_for_merge(coerced, join=join, copy=False, indexes=indexes,
indexes_from_arg=indexes_from_arg)
expanded = expand_variable_dicts(aligned)
priority_vars = _get_priority_vars(aligned, priority_arg, compat=compat)
variables = merge_variables(expanded, priority_vars, compat=compat)
Expand All @@ -333,7 +393,7 @@ def merge_data_and_coords(data, coords, compat='broadcast_equals',


def merge_core(objs, compat='broadcast_equals', join='outer', priority_arg=None,
explicit_coords=None, indexes=None):
explicit_coords=None, indexes=None, indexes_from_arg=False):
"""Core logic for merging labeled objects.
This is not public API.
Expand All @@ -352,6 +412,10 @@ def merge_core(objs, compat='broadcast_equals', join='outer', priority_arg=None,
An explicit list of variables from `objs` that are coordinates.
indexes : dict, optional
Dictionary with values given by pandas.Index objects.
indexes_from_arg : boolean, optional
If indexes were provided, were these taken from one of the objects in
``objs``? This allows us to skip alignment if this object is the only
one to be aligned.
Returns
-------
Expand All @@ -371,8 +435,8 @@ def merge_core(objs, compat='broadcast_equals', join='outer', priority_arg=None,
_assert_compat_valid(compat)

coerced = coerce_pandas_values(objs)
aligned = deep_align(coerced, join=join, copy=False, indexes=indexes,
skip_single_target=True)
aligned = _align_for_merge(coerced, join=join, copy=False, indexes=indexes,
indexes_from_arg=indexes_from_arg)
expanded = expand_variable_dicts(aligned)

coord_names, noncoord_names = determine_coords(coerced)
Expand Down Expand Up @@ -488,7 +552,5 @@ def dataset_merge_method(dataset, other, overwrite_vars=frozenset(),

def dataset_update_method(dataset, other):
"""Guts of the Dataset.update method"""
objs = [dataset, other]
priority_arg = 1
indexes = dataset.indexes
return merge_core(objs, priority_arg=priority_arg, indexes=indexes)
return merge_core([dataset, other], priority_arg=1, indexes=dataset.indexes,
indexes_from_arg=0)

0 comments on commit b76a2f9

Please sign in to comment.