diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 8bcf15f8144..a6f6d2d7253 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -12,7 +12,7 @@ from .common import _maybe_promote from .indexing import get_indexer from .pycompat import iteritems, OrderedDict, suppress -from .utils import is_full_slice +from .utils import is_full_slice, is_dict_like from .variable import Variable, IndexVariable @@ -29,8 +29,12 @@ def _get_joiner(join): raise ValueError('invalid value for join: %s' % join) +_DEFAULT_EXCLUDE = frozenset() + + def align(*objects, **kwargs): - """align(*objects, join='inner', copy=True) + """align(*objects, join='inner', copy=True, indexes=None, + exclude=frozenset()) Given any number of Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -76,15 +80,18 @@ def align(*objects, **kwargs): join = kwargs.pop('join', 'inner') copy = kwargs.pop('copy', True) indexes = kwargs.pop('indexes', None) - exclude = kwargs.pop('exclude', None) + exclude = kwargs.pop('exclude', _DEFAULT_EXCLUDE) if indexes is None: indexes = {} - if exclude is None: - exclude = set() if kwargs: raise TypeError('align() got unexpected keyword arguments: %s' % list(kwargs)) + if not indexes and len(objects) == 1: + # fast path for the trivial case + obj, = objects + return (obj.copy(deep=copy),) + all_indexes = defaultdict(list) unlabeled_dim_sizes = defaultdict(set) for obj in objects: @@ -142,11 +149,72 @@ def align(*objects, **kwargs): for obj in objects: valid_indexers = {k: v for k, v in joined_indexes.items() if k in obj.dims} - result.append(obj.reindex(copy=copy, **valid_indexers)) + if not valid_indexers: + # fast path for no reindexing necessary + new_obj = obj.copy(deep=copy) + else: + new_obj = obj.reindex(copy=copy, **valid_indexers) + result.append(new_obj) return tuple(result) +def deep_align(objects, join='inner', copy=True, indexes=None, + exclude=frozenset(), raise_on_invalid=True): + """Align objects for merging, recursing into dictionary values. + + This function is not public API. + """ + if indexes is None: + indexes = {} + + def is_alignable(obj): + return hasattr(obj, 'indexes') and hasattr(obj, 'reindex') + + positions = [] + keys = [] + out = [] + targets = [] + no_key = object() + not_replaced = object() + for n, variables in enumerate(objects): + if is_alignable(variables): + positions.append(n) + keys.append(no_key) + targets.append(variables) + out.append(not_replaced) + elif is_dict_like(variables): + for k, v in variables.items(): + if is_alignable(v) and k not in indexes: + # Skip variables in indexes for alignment, because these + # should to be overwritten instead: + # https://github.com/pydata/xarray/issues/725 + positions.append(n) + keys.append(k) + targets.append(v) + out.append(OrderedDict(variables)) + elif raise_on_invalid: + raise ValueError('object to align is neither an xarray.Dataset, ' + 'an xarray.DataArray nor a dictionary: %r' + % variables) + else: + out.append(variables) + + aligned = align(*targets, join=join, copy=copy, indexes=indexes, + exclude=exclude) + + for position, key, aligned_obj in zip(positions, keys, aligned): + if key is no_key: + out[position] = aligned_obj + else: + out[position][key] = aligned_obj + + # something went wrong: we should have replaced all sentinel values + assert all(arg is not not_replaced for arg in out) + + return out + + def reindex_like_indexers(target, other): """Extract indexers to align target with other. diff --git a/xarray/core/computation.py b/xarray/core/computation.py new file mode 100644 index 00000000000..675c5cbe9f6 --- /dev/null +++ b/xarray/core/computation.py @@ -0,0 +1,713 @@ +"""Functions for applying functions that act on arrays to xarray's labeled data. + +NOT PUBLIC API. +""" +import collections +import functools +import itertools +import operator +import re + +from . import ops +from .alignment import deep_align +from .merge import expand_and_merge_variables +from .pycompat import OrderedDict, basestring, dask_array_type +from .utils import is_dict_like + + +_DEFAULT_FROZEN_SET = frozenset() + +# see http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html +DIMENSION_NAME = r'\w+' +CORE_DIMENSION_LIST = '(?:' + DIMENSION_NAME + '(?:,' + DIMENSION_NAME + ')*)?' +ARGUMENT = r'\(' + CORE_DIMENSION_LIST + r'\)' +ARGUMENT_LIST = ARGUMENT + '(?:,' + ARGUMENT + ')*' +SIGNATURE = '^' + ARGUMENT_LIST + '->' + ARGUMENT_LIST + '$' + + +def safe_tuple(x): + # type: Iterable -> tuple + if isinstance(x, basestring): + raise ValueError('cannot safely convert %r to a tuple') + return tuple(x) + + +class UFuncSignature(object): + """Core dimensions signature for a given function. + + Based on the signature provided by generalized ufuncs in NumPy. + + Attributes + ---------- + input_core_dims : list of tuples + A list of tuples of core dimension names on each input variable. + output_core_dims : list of tuples + A list of tuples of core dimension names on each output variable. + """ + def __init__(self, input_core_dims, output_core_dims=((),)): + self.input_core_dims = tuple(safe_tuple(a) for a in input_core_dims) + self.output_core_dims = tuple(safe_tuple(a) for a in output_core_dims) + self._all_input_core_dims = None + self._all_output_core_dims = None + self._all_core_dims = None + + @property + def all_input_core_dims(self): + if self._all_input_core_dims is None: + self._all_input_core_dims = frozenset( + dim for dims in self.input_core_dims for dim in dims) + return self._all_input_core_dims + + @property + def all_output_core_dims(self): + if self._all_output_core_dims is None: + self._all_output_core_dims = frozenset( + dim for dims in self.output_core_dims for dim in dims) + return self._all_output_core_dims + + @property + def all_core_dims(self): + if self._all_core_dims is None: + self._all_core_dims = (self.all_input_core_dims | + self.all_output_core_dims) + return self._all_core_dims + + @property + def n_inputs(self): + return len(self.input_core_dims) + + @property + def n_outputs(self): + return len(self.output_core_dims) + + @classmethod + def default(cls, n_inputs): + return cls([()] * n_inputs, [()]) + + @classmethod + def from_sequence(cls, nested): + if (not isinstance(nested, collections.Sequence) or + not len(nested) == 2 or + any(not isinstance(arg_list, collections.Sequence) + for arg_list in nested) or + any(isinstance(arg, basestring) or + not isinstance(arg, collections.Sequence) + for arg_list in nested for arg in arg_list)): + raise TypeError('functions signatures not provided as a string ' + 'must be a triply nested sequence providing the ' + 'list of core dimensions for each variable, for ' + 'both input and output.') + return cls(*nested) + + @classmethod + def from_string(cls, string): + """Create a UFuncSignature object from a NumPy gufunc signature. + + Parameters + ---------- + string : str + Signature string, e.g., (m,n),(n,p)->(m,p). + """ + if not re.match(SIGNATURE, string): + raise ValueError('not a valid gufunc signature: {}'.format(string)) + return cls(*[[re.findall(DIMENSION_NAME, arg) + for arg in re.findall(ARGUMENT, arg_list)] + for arg_list in string.split('->')]) + + def __eq__(self, other): + try: + return (self.input_core_dims == other.input_core_dims and + self.output_core_dims == other.output_core_dims) + except AttributeError: + return False + + def __ne__(self, other): + return not self == other + + def __repr__(self): + return ('%s(%r, %r)' + % (type(self).__name__, + list(self.input_core_dims), + list(self.output_core_dims))) + + +def result_name(objects): + # type: List[object] -> Any + # use the same naming heuristics as pandas: + # https://github.com/blaze/blaze/issues/458#issuecomment-51936356 + names = {getattr(obj, 'name', None) for obj in objects} + names.discard(None) + if len(names) == 1: + name, = names + else: + name = None + return name + + +_REPEAT_NONE = itertools.repeat(None) + + +def _get_coord_variables(args): + input_coords = [] + for arg in args: + try: + coords = arg.coords + except AttributeError: + pass # skip this argument + else: + coord_vars = getattr(coords, 'variables', coords) + input_coords.append(coord_vars) + return input_coords + + +def build_output_coords( + args, # type: list + signature, # type: UFuncSignature + exclude_dims=frozenset(), # type: set +): + # type: (...) -> List[OrderedDict[Any, Variable]] + input_coords = _get_coord_variables(args) + + if exclude_dims: + input_coords = [OrderedDict((k, v) for k, v in coord_vars.items() + if exclude_dims.isdisjoint(v.dims)) + for coord_vars in input_coords] + + if len(input_coords) == 1: + # we can skip the expensive merge + unpacked_input_coords, = input_coords + merged = OrderedDict(unpacked_input_coords) + else: + merged = expand_and_merge_variables(input_coords) + + output_coords = [] + for output_dims in signature.output_core_dims: + dropped_dims = signature.all_input_core_dims - set(output_dims) + if dropped_dims: + filtered = OrderedDict((k, v) for k, v in merged.items() + if dropped_dims.isdisjoint(v.dims)) + else: + filtered = merged + output_coords.append(filtered) + + return output_coords + + +def apply_dataarray_ufunc(func, *args, **kwargs): + """apply_dataarray_ufunc(func, *args, signature, join='inner', + exclude_dims=frozenset()) + """ + from .dataarray import DataArray + + signature = kwargs.pop('signature') + join = kwargs.pop('join', 'inner') + exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET) + if kwargs: + raise TypeError('apply_dataarray_ufunc() got unexpected keyword ' + 'arguments: %s' % list(kwargs)) + + if len(args) > 1: + args = deep_align(args, join=join, copy=False, exclude=exclude_dims, + raise_on_invalid=False) + + name = result_name(args) + result_coords = build_output_coords(args, signature, exclude_dims) + + data_vars = [getattr(a, 'variable', a) for a in args] + result_var = func(*data_vars) + + if signature.n_outputs > 1: + return tuple(DataArray(variable, coords, name=name, fastpath=True) + for variable, coords in zip(result_var, result_coords)) + else: + coords, = result_coords + return DataArray(result_var, coords, name=name, fastpath=True) + + +def ordered_set_union(all_keys): + # type: List[Iterable] -> Iterable + result_dict = OrderedDict() + for keys in all_keys: + for key in keys: + result_dict[key] = None + return result_dict.keys() + + +def ordered_set_intersection(all_keys): + # type: List[Iterable] -> Iterable + intersection = set(all_keys[0]) + for keys in all_keys[1:]: + intersection.intersection_update(keys) + return [key for key in all_keys[0] if key in intersection] + + +_JOINERS = { + 'inner': ordered_set_intersection, + 'outer': ordered_set_union, + 'left': operator.itemgetter(0), + 'right': operator.itemgetter(-1), +} + + +def join_dict_keys(objects, how='inner'): + # type: (Iterable[Union[Mapping, Any]], str) -> Iterable + joiner = _JOINERS[how] + all_keys = [obj.keys() for obj in objects if hasattr(obj, 'keys')] + return joiner(all_keys) + + +def collect_dict_values(objects, keys, fill_value=None): + # type: (Iterable[Union[Mapping, Any]], Iterable, Any) -> List[list] + return [[obj.get(key, fill_value) + if is_dict_like(obj) + else obj + for obj in objects] + for key in keys] + + +def _as_variables_or_variable(arg): + try: + return arg.variables + except AttributeError: + try: + return arg.variable + except AttributeError: + return arg + + +def _unpack_dict_tuples( + result_vars, # type: Mapping[Any, Tuple[Variable]] + n_outputs, # type: int +): + # type: (...) -> Tuple[Dict[Any, Variable]] + out = tuple(OrderedDict() for _ in range(n_outputs)) + for name, values in result_vars.items(): + for value, results_dict in zip(values, out): + results_dict[name] = value + return out + + +def apply_dict_of_variables_ufunc(func, *args, **kwargs): + """apply_dict_of_variables_ufunc(func, *args, signature, join='inner', + fill_value=None): + """ + signature = kwargs.pop('signature') + join = kwargs.pop('join', 'inner') + fill_value = kwargs.pop('fill_value', None) + if kwargs: + raise TypeError('apply_dict_of_variables_ufunc() got unexpected ' + 'keyword arguments: %s' % list(kwargs)) + + args = [_as_variables_or_variable(arg) for arg in args] + names = join_dict_keys(args, how=join) + grouped_by_name = collect_dict_values(args, names, fill_value) + + result_vars = OrderedDict() + for name, variable_args in zip(names, grouped_by_name): + result_vars[name] = func(*variable_args) + + if signature.n_outputs > 1: + return _unpack_dict_tuples(result_vars, signature.n_outputs) + else: + return result_vars + + +def _fast_dataset(variables, coord_variables): + # type: (OrderedDict[Any, Variable], Mapping[Any, Variable]) -> Dataset + """Create a dataset as quickly as possible. + + Beware: the `variables` OrderedDict is modified INPLACE. + """ + from .dataset import Dataset + variables.update(coord_variables) + coord_names = set(coord_variables) + return Dataset._from_vars_and_coord_names(variables, coord_names) + + +def apply_dataset_ufunc(func, *args, **kwargs): + """apply_dataset_ufunc(func, *args, signature, join='inner', + fill_value=None, exclude_dims=frozenset()): + """ + signature = kwargs.pop('signature') + join = kwargs.pop('join', 'inner') + fill_value = kwargs.pop('fill_value', None) + exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET) + if kwargs: + raise TypeError('apply_dataset_ufunc() got unexpected keyword ' + 'arguments: %s' % list(kwargs)) + + if len(args) > 1: + args = deep_align(args, join=join, copy=False, exclude=exclude_dims, + raise_on_invalid=False) + + list_of_coords = build_output_coords(args, signature, exclude_dims) + + args = [getattr(arg, 'data_vars', arg) for arg in args] + result_vars = apply_dict_of_variables_ufunc( + func, *args, signature=signature, join=join, fill_value=fill_value) + + if signature.n_outputs > 1: + return tuple(_fast_dataset(*args) + for args in zip(result_vars, list_of_coords)) + else: + coord_vars, = list_of_coords + return _fast_dataset(result_vars, coord_vars) + + +def _iter_over_selections(obj, dim, values): + """Iterate over selections of an xarray object in the provided order.""" + from .groupby import _dummy_copy + + dummy = None + for value in values: + try: + obj_sel = obj.sel(**{dim: value}) + except (KeyError, IndexError): + if dummy is None: + dummy = _dummy_copy(obj) + obj_sel = dummy + yield obj_sel + + +def apply_groupby_ufunc(func, *args): + from .groupby import GroupBy, peek_at + from .variable import Variable + + groupbys = [arg for arg in args if isinstance(arg, GroupBy)] + if not groupbys: + raise ValueError('must have at least one groupby to iterate over') + first_groupby = groupbys[0] + if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]): + raise ValueError('can only perform operations over multiple groupbys ' + 'at once if they are all grouped the same way') + + grouped_dim = first_groupby._group.name + unique_values = first_groupby._unique_coord.values + + iterators = [] + for arg in args: + if isinstance(arg, GroupBy): + iterator = (value for _, value in arg) + elif hasattr(arg, 'dims') and grouped_dim in arg.dims: + if isinstance(arg, Variable): + raise ValueError( + 'groupby operations cannot be performed with ' + 'xarray.Variable objects that share a dimension with ' + 'the grouped dimension') + iterator = _iter_over_selections(arg, grouped_dim, unique_values) + else: + iterator = itertools.repeat(arg) + iterators.append(iterator) + + applied = (func(*zipped_args) for zipped_args in zip(*iterators)) + applied_example, applied = peek_at(applied) + combine = first_groupby._combine + if isinstance(applied_example, tuple): + combined = tuple(combine(output) for output in zip(*applied)) + else: + combined = combine(applied) + return combined + + +def unified_dim_sizes(variables, exclude_dims=frozenset()): + # type: Iterable[Variable] -> OrderedDict[Any, int] + dim_sizes = OrderedDict() + + for var in variables: + if len(set(var.dims)) < len(var.dims): + raise ValueError('broadcasting cannot handle duplicate ' + 'dimensions: %r' % list(var.dims)) + for dim, size in zip(var.dims, var.shape): + if dim not in exclude_dims: + if dim not in dim_sizes: + dim_sizes[dim] = size + elif dim_sizes[dim] != size: + raise ValueError('operands cannot be broadcast together ' + 'with mismatched lengths for dimension ' + '%r: %s vs %s' + % (dim, dim_sizes[dim], size)) + return dim_sizes + + +SLICE_NONE = slice(None) + +# A = TypeVar('A', numpy.ndarray, dask.array.Array) + + +def broadcast_compat_data(variable, broadcast_dims, core_dims): + # type: (Variable[A], tuple, tuple) -> A + data = variable.data + + old_dims = variable.dims + new_dims = broadcast_dims + core_dims + + if new_dims == old_dims: + # optimize for the typical case + return data + + set_old_dims = set(old_dims) + missing_core_dims = [d for d in core_dims if d not in set_old_dims] + if missing_core_dims: + raise ValueError('operation requires dimensions missing on input ' + 'variable: %r' % missing_core_dims) + + set_new_dims = set(new_dims) + unexpected_dims = [d for d in old_dims if d not in set_new_dims] + if unexpected_dims: + raise ValueError('operation encountered unexpected dimensions %r ' + 'on input variable: these are core dimensions on ' + 'other input or output variables' % unexpected_dims) + + # for consistency with numpy, keep broadcast dimensions to the left + old_broadcast_dims = tuple(d for d in broadcast_dims if d in set_old_dims) + reordered_dims = old_broadcast_dims + core_dims + if reordered_dims != old_dims: + order = tuple(old_dims.index(d) for d in reordered_dims) + data = ops.transpose(data, order) + + if new_dims != reordered_dims: + key = tuple(SLICE_NONE if dim in set_old_dims else None + for dim in new_dims) + data = data[key] + + return data + + +def apply_variable_ufunc(func, *args, **kwargs): + """apply_variable_ufunc(func, *args, signature, exclude_dims=frozenset()) + """ + from .variable import Variable + + signature = kwargs.pop('signature') + exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET) + if kwargs: + raise TypeError('apply_variable_ufunc() got unexpected keyword ' + 'arguments: %s' % list(kwargs)) + + dim_sizes = unified_dim_sizes((a for a in args if hasattr(a, 'dims')), + exclude_dims=exclude_dims) + broadcast_dims = tuple(dim for dim in dim_sizes + if dim not in signature.all_core_dims) + output_dims = [broadcast_dims + out for out in signature.output_core_dims] + + input_data = [broadcast_compat_data(arg, broadcast_dims, core_dims) + if isinstance(arg, Variable) + else arg + for arg, core_dims in zip(args, signature.input_core_dims)] + + result_data = func(*input_data) + + if signature.n_outputs > 1: + output = [] + for dims, data in zip(output_dims, result_data): + output.append(Variable(dims, data)) + return tuple(output) + else: + dims, = output_dims + return Variable(dims, result_data) + + +def apply_array_ufunc(func, *args, **kwargs): + """apply_variable_ufunc(func, *args, dask_array='forbidden') + """ + dask_array = kwargs.pop('dask_array', 'forbidden') + if kwargs: + raise TypeError('apply_array_ufunc() got unexpected keyword ' + 'arguments: %s' % list(kwargs)) + + if any(isinstance(arg, dask_array_type) for arg in args): + # TODO: add a mode dask_array='auto' when dask.array gets a function + # for applying arbitrary gufuncs + if dask_array == 'forbidden': + raise ValueError('encountered dask array, but did not set ' + "dask_array='allowed'") + elif dask_array != 'allowed': + raise ValueError('unknown setting for dask array handling: %r' + % dask_array) + # fall through + return func(*args) + + +def apply_ufunc(func, *args, **kwargs): + """apply_ufunc(func, *args, signature=None, join='inner', + exclude_dims=frozenset(), dataset_fill_value=None, + kwargs=None, dask_array='forbidden') + + Apply a vectorized function for unlabeled arrays to xarray objects. + + The input arguments will be handled using xarray's standard rules for + labeled computation, including alignment, broadcasting, looping over + GroupBy/Dataset variables, and merging of coordinates. + + Parameters + ---------- + func : callable + Function to call like ``func(*args, **kwargs)`` on unlabeled arrays + (``.data``). If multiple arguments with non-matching dimensions are + supplied, this function is expected to vectorize (broadcast) over + axes of positional arguments in the style of NumPy universal + functions [1]_. + *args : Dataset, DataArray, GroupBy, Variable, numpy/dask arrays or scalars + Mix of labeled and/or unlabeled arrays to which to apply the function. + signature : string or triply nested sequence, optional + Object indicating core dimensions that should not be broadcast on + the input and outputs arguments. If omitted, inputs will be broadcast + to share all dimensions in common before calling ``func`` on their + values, and the output of ``func`` will be assumed to be a single array + with the same shape as the inputs. + + Two forms of signatures are accepted: + (a) A signature string of the form used by NumPy's generalized + universal functions [2]_, e.g., '(),(time)->()' indicating a + function that accepts two arguments and returns a single argument, + on which all dimensions should be broadcast except 'time' on the + second argument. + (a) A triply nested sequence providing lists of core dimensions for + each variable, for both input and output, e.g., + ``([(), ('time',)], [()])``. + + Core dimensions are automatically moved to the last axes of any input + variables, which facilitates using NumPy style generalized ufuncs (see + the examples below). + + Unlike the NumPy gufunc signature spec, the names of all dimensions + provided in signatures must be the names of actual dimensions on the + xarray objects. + join : {'outer', 'inner', 'left', 'right'}, optional + Method for joining the indexes of the passed objects along each + dimension, and the variables of Dataset objects with mismatched + data variables: + - 'outer': use the union of object indexes + - 'inner': use the intersection of object indexes + - 'left': use indexes from the first object with each dimension + - 'right': use indexes from the last object with each dimension + exclude_dims : set, optional + Dimensions to exclude from alignment and broadcasting. Any inputs + coordinates along these dimensions will be dropped. Each excluded + dimension must be a core dimension in the function signature. + dataset_fill_value : optional + Value used in place of missing variables on Dataset inputs when the + datasets do not share the exact same ``data_vars``. Only relevant if + ``join != 'inner'``. + kwargs: dict, optional + Optional keyword arguments passed directly on to call ``func``. + dask_array: 'forbidden' or 'allowed', optional + Whether or not to allow applying the ufunc to objects containing lazy + data in the form of dask arrays. By default, this is forbidden, to + avoid implicitly converting lazy data. + + Returns + ------- + Single value or tuple of Dataset, DataArray, Variable, dask.array.Array or + numpy.ndarray, the first type on that list to appear on an input. + + Examples + -------- + For illustrative purposes only, here are examples of how you could use + ``apply_ufunc`` to write functions to (very nearly) replicate existing + xarray functionality: + + Calculate the vector magnitude of two arguments: + + def magnitude(a, b): + func = lambda x, y: np.sqrt(x ** 2 + y ** 2) + return xr.apply_func(func, a, b) + + Compute the mean (``.mean``):: + + def mean(obj, dim): + # note: apply always moves core dimensions to the end + sig = ([(dim,)], [()]) + kwargs = {'axis': -1} + return apply_ufunc(np.mean, obj, signature=sig, kwargs=kwargs) + + Inner product over a specific dimension:: + + def _inner(x, y): + result = np.matmul(x[..., np.newaxis, :], y[..., :, np.newaxis]) + return result[..., 0, 0] + + def inner_product(a, b, dim): + sig = ([(dim,), (dim,)], [()]) + return apply_ufunc(_inner, a, b, signature=sig) + + Stack objects along a new dimension (like ``xr.concat``):: + + def stack(objects, dim, new_coord): + sig = ([()] * len(objects), [(dim,)]) + func = lambda *x: np.stack(x, axis=-1) + result = apply_ufunc(func, *objects, signature=sig, + join='outer', dataset_fill_value=np.nan) + result[dim] = new_coord + return result + + Most of NumPy's builtin functions already broadcast their inputs + appropriately for use in `apply`. You may find helper functions such as + numpy.broadcast_arrays or numpy.vectorize helpful in writing your function. + `apply_ufunc` also works well with numba's vectorize and guvectorize. + + See also + -------- + numpy.broadcast_arrays + numpy.vectorize + numba.vectorize + numba.guvectorize + + References + ---------- + .. [1] http://docs.scipy.org/doc/numpy/reference/ufuncs.html + .. [2] http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html + """ + from .groupby import GroupBy + from .dataarray import DataArray + from .variable import Variable + + signature = kwargs.pop('signature', None) + join = kwargs.pop('join', 'inner') + exclude_dims = kwargs.pop('exclude_dims', frozenset()) + dataset_fill_value = kwargs.pop('dataset_fill_value', None) + kwargs_ = kwargs.pop('kwargs', None) + dask_array = kwargs.pop('dask_array', 'forbidden') + if kwargs: + raise TypeError('apply_ufunc() got unexpected keyword arguments: %s' + % list(kwargs)) + + if signature is None: + signature = UFuncSignature.default(len(args)) + elif isinstance(signature, basestring): + signature = UFuncSignature.from_string(signature) + elif not isinstance(signature, UFuncSignature): + signature = UFuncSignature.from_sequence(signature) + + if exclude_dims and not exclude_dims <= signature.all_core_dims: + raise ValueError('each dimension in `exclude_dims` must also be a ' + 'core dimension in the function signature') + + if kwargs_: + func = functools.partial(func, **kwargs_) + + array_ufunc = functools.partial( + apply_array_ufunc, func, dask_array=dask_array) + + variables_ufunc = functools.partial( + apply_variable_ufunc, array_ufunc, signature=signature, + exclude_dims=exclude_dims) + + if any(isinstance(a, GroupBy) for a in args): + this_apply = functools.partial( + apply_ufunc, func, signature=signature, join=join, + dask_array=dask_array, exclude_dims=exclude_dims, + dataset_fill_value=dataset_fill_value) + return apply_groupby_ufunc(this_apply, *args) + elif any(is_dict_like(a) for a in args): + return apply_dataset_ufunc(variables_ufunc, *args, signature=signature, + join=join, exclude_dims=exclude_dims, + fill_value=dataset_fill_value) + elif any(isinstance(a, DataArray) for a in args): + return apply_dataarray_ufunc(variables_ufunc, *args, + signature=signature, + join=join, exclude_dims=exclude_dims) + elif any(isinstance(a, Variable) for a in args): + return variables_ufunc(*args) + else: + return array_ufunc(*args) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 41027676aa4..426e213d627 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -7,7 +7,8 @@ from . import formatting from .utils import Frozen -from .merge import merge_coords, merge_coords_without_align +from .merge import ( + merge_coords, expand_and_merge_variables, merge_coords_for_inplace_math) from .pycompat import OrderedDict @@ -68,7 +69,7 @@ def _merge_raw(self, other): variables = OrderedDict(self.variables) else: # don't align because we already called xarray.align - variables = merge_coords_without_align( + variables = expand_and_merge_variables( [self.variables, other.variables]) return variables @@ -82,7 +83,7 @@ def _merge_inplace(self, other): # first priority_vars = OrderedDict( (k, v) for k, v in self.variables.items() if k not in self.dims) - variables = merge_coords_without_align( + variables = merge_coords_for_inplace_math( [self.variables, other.variables], priority_vars=priority_vars) yield self._update_coords(variables) @@ -115,7 +116,7 @@ def merge(self, other): return self.to_dataset() else: other_vars = getattr(other, 'variables', other) - coords = merge_coords_without_align([self.variables, other_vars]) + coords = expand_and_merge_variables([self.variables, other_vars]) return Dataset._from_vars_and_coord_names(coords, set(coords)) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9d5618dfb7d..d7a3631ba6c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -278,6 +278,11 @@ def __getitem__(self, key): def __unicode__(self): return formatting.vars_repr(self) + @property + def variables(self): + all_variables = self._dataset.variables + return Frozen(OrderedDict((k, all_variables[k]) for k in self)) + class _LocIndexer(object): def __init__(self, dataset): @@ -490,9 +495,9 @@ def _construct_direct(cls, variables, coord_names, dims=None, attrs=None, __default_attrs = object() @classmethod - def _from_vars_and_coord_names(cls, variables, coord_names): + def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None): dims = dict(calculate_dimensions(variables)) - return cls._construct_direct(variables, coord_names, dims) + return cls._construct_direct(variables, coord_names, dims, attrs) def _replace_vars_and_dims(self, variables, coord_names=None, dims=None, attrs=__default_attrs, inplace=False): @@ -839,9 +844,11 @@ def reset_coords(self, names=None, drop=False, inplace=False): if isinstance(names, basestring): names = [names] self._assert_all_in_dataset(names) - _assert_empty( - set(names) & set(self.dims), - 'cannot remove index coordinates with reset_coords: %s') + bad_coords = set(names) & set(self.dims) + if bad_coords: + raise ValueError( + 'cannot remove index coordinates with reset_coords: %s' + % bad_coords) obj = self if inplace else self.copy() obj._coord_names.difference_update(names) if drop: @@ -1987,8 +1994,10 @@ def reduce(self, func, dim=None, keep_attrs=False, numeric_only=False, else: dims = set(dim) - _assert_empty([dim for dim in dims if dim not in self.dims], - 'Dataset does not contain the dimensions: %s') + missing_dimensions = [dim for dim in dims if dim not in self.dims] + if missing_dimensions: + raise ValueError('Dataset does not contain the dimensions: %s' + % missing_dimensions) variables = OrderedDict() for name, var in iteritems(self._variables): diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 5ea490a004b..7f02f72c6a9 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -3,10 +3,10 @@ from __future__ import print_function import pandas as pd -from .alignment import align +from .alignment import deep_align +from .pycompat import OrderedDict, basestring from .utils import Frozen -from .variable import (as_variable, assert_unique_multiindex_level_names) -from .pycompat import (basestring, OrderedDict) +from .variable import as_variable, assert_unique_multiindex_level_names PANDAS_TYPES = (pd.Series, pd.DataFrame, pd.Panel) @@ -121,23 +121,20 @@ def merge_variables( ---------- lists_of_variables_dicts : list of mappings with Variable values List of mappings for which each value is a xarray.Variable object. - priority_vars : mapping with Variable values, optional + priority_vars : mapping with Variable or None values, optional If provided, variables are always taken from this dict in preference to the input variable dictionaries, without checking for conflicts. compat : {'identical', 'equals', 'broadcast_equals', 'minimal', 'no_conflicts'}, optional - Type of equality check to use wben checking for conflicts. + Type of equality check to use when checking for conflicts. Returns ------- - OrderedDict keys given by the union of keys on list_of_variable_dicts - (unless compat=='minimal', in which case some variables with conflicting - values can be dropped), and Variable values corresponding to those that - should be found in the result. + OrderedDict with keys taken by the union of keys on list_of_variable_dicts, + and Variable values corresponding to those that should be found on the + merged result. """ if priority_vars is None: - # one of these arguments (e.g., the first for in-place - # arithmetic or the second for Dataset.update) takes priority priority_vars = {} _assert_compat_valid(compat) @@ -154,6 +151,8 @@ def merge_variables( for name, variables in lookup.items(): if name in priority_vars: + # one of these arguments (e.g., the first for in-place arithmetic + # or the second for Dataset.update) takes priority merged[name] = priority_vars[name] else: dim_variables = [var for var in variables if (name,) == var.dims] @@ -287,7 +286,7 @@ def coerce_pandas_values(objects): return out -def merge_coords_without_align(objs, priority_vars=None): +def merge_coords_for_inplace_math(objs, priority_vars=None): """Merge coordinate variables without worrying about alignment. This function is used for merging variables in coordinates.py. @@ -298,58 +297,12 @@ def merge_coords_without_align(objs, priority_vars=None): return variables -def _align_for_merge(input_objects, join, copy, indexes=None): - """Align objects for merging, recursing into dictionary values. - """ - if indexes is None: - indexes = {} - - def is_alignable(obj): - return hasattr(obj, 'indexes') and hasattr(obj, 'reindex') - - positions = [] - keys = [] - out = [] - targets = [] - no_key = object() - not_replaced = object() - for n, variables in enumerate(input_objects): - if is_alignable(variables): - positions.append(n) - keys.append(no_key) - targets.append(variables) - out.append(not_replaced) - else: - for k, v in variables.items(): - if is_alignable(v) and k not in indexes: - # Skip variables in indexes for alignment, because these - # should to be overwritten instead: - # https://github.com/pydata/xarray/issues/725 - positions.append(n) - keys.append(k) - targets.append(v) - out.append(OrderedDict(variables)) - - aligned = align(*targets, join=join, copy=copy, indexes=indexes) - - for position, key, aligned_obj in zip(positions, keys, aligned): - if key is no_key: - out[position] = aligned_obj - else: - out[position][key] = aligned_obj - - # something went wrong: we should have replaced all sentinel values - assert all(arg is not not_replaced for arg in out) - - return out - - def _get_priority_vars(objects, priority_arg, compat='equals'): """Extract the priority variable from a list of mappings. - We need this method because in some cases the priority argument itself might - have conflicting values (e.g., if it is a dict with two DataArray values - with conflicting coordinate values). + We need this method because in some cases the priority argument itself + might have conflicting values (e.g., if it is a dict with two DataArray + values with conflicting coordinate values). Parameters ---------- @@ -367,13 +320,24 @@ def _get_priority_vars(objects, priority_arg, compat='equals'): values indicating priority variables. """ if priority_arg is None: - priority_vars = None + priority_vars = {} else: expanded = expand_variable_dicts([objects[priority_arg]]) priority_vars = merge_variables(expanded, compat=compat) return priority_vars +def expand_and_merge_variables(objs, priority_arg=None): + """Merge coordinate variables without worrying about alignment. + + This function is used for merging variables in computation.py. + """ + expanded = expand_variable_dicts(objs) + priority_vars = _get_priority_vars(objs, priority_arg) + variables = merge_variables(expanded, priority_vars) + return variables + + def merge_coords(objs, compat='minimal', join='outer', priority_arg=None, indexes=None): """Merge coordinate variables. @@ -384,7 +348,7 @@ def merge_coords(objs, compat='minimal', join='outer', priority_arg=None, """ _assert_compat_valid(compat) coerced = coerce_pandas_values(objs) - aligned = _align_for_merge(coerced, join=join, copy=False, indexes=indexes) + aligned = deep_align(coerced, join=join, copy=False, indexes=indexes) expanded = expand_variable_dicts(aligned) priority_vars = _get_priority_vars(aligned, priority_arg, compat=compat) variables = merge_variables(expanded, priority_vars, compat=compat) @@ -445,7 +409,7 @@ def merge_core(objs, _assert_compat_valid(compat) coerced = coerce_pandas_values(objs) - aligned = _align_for_merge(coerced, join=join, copy=False, indexes=indexes) + aligned = deep_align(coerced, join=join, copy=False, indexes=indexes) expanded = expand_variable_dicts(aligned) coord_names, noncoord_names = determine_coords(coerced) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 32c26bd02c1..dffacaf2eb0 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -158,7 +158,7 @@ def remove_incompatible_items(first_dict, second_dict, compat=equivalent): def is_dict_like(value): - return hasattr(value, '__getitem__') and hasattr(value, 'keys') + return hasattr(value, 'keys') and hasattr(value, '__getitem__') def is_full_slice(value): diff --git a/xarray/test/test_computation.py b/xarray/test/test_computation.py new file mode 100644 index 00000000000..db9e51caac8 --- /dev/null +++ b/xarray/test/test_computation.py @@ -0,0 +1,525 @@ +import functools +import operator +from collections import OrderedDict + +import numpy as np +from numpy.testing import assert_array_equal + +import pytest + +import xarray as xr +from xarray.core.computation import ( + UFuncSignature, broadcast_compat_data, collect_dict_values, + join_dict_keys, ordered_set_intersection, ordered_set_union, + unified_dim_sizes, apply_ufunc) + +from . import requires_dask + + +def assert_identical(a, b): + if hasattr(a, 'identical'): + msg = 'not identical:\n%r\n%r' % (a, b) + assert a.identical(b), msg + else: + assert_array_equal(a, b) + + +def test_parse_signature(): + assert (UFuncSignature([['x']]) == + UFuncSignature.from_string('(x)->()')) + assert (UFuncSignature([['x', 'y']]) == + UFuncSignature.from_string('(x,y)->()')) + assert (UFuncSignature([['x'], ['y']]) == + UFuncSignature.from_string('(x),(y)->()')) + assert (UFuncSignature([['x']], [['y'], []]) == + UFuncSignature.from_string('(x)->(y),()')) + with pytest.raises(ValueError): + UFuncSignature.from_string('(x)(y)->()') + with pytest.raises(ValueError): + UFuncSignature.from_string('(x),(y)->') + with pytest.raises(ValueError): + UFuncSignature.from_string('((x))->(x)') + + +def test_signature_properties(): + sig = UFuncSignature.from_string('(x),(x,y)->(z)') + assert sig.input_core_dims == (('x',), ('x', 'y')) + assert sig.output_core_dims == (('z',),) + assert sig.all_input_core_dims == frozenset(['x', 'y']) + assert sig.all_output_core_dims == frozenset(['z']) + assert sig.n_inputs == 2 + assert sig.n_outputs == 1 + # dimension names matter + assert UFuncSignature([['x']]) != UFuncSignature([['y']]) + + +def test_ordered_set_union(): + assert list(ordered_set_union([[1, 2]])) == [1, 2] + assert list(ordered_set_union([[1, 2], [2, 1]])) == [1, 2] + assert list(ordered_set_union([[0], [1, 2], [1, 3]])) == [0, 1, 2, 3] + + +def test_ordered_set_intersection(): + assert list(ordered_set_intersection([[1, 2]])) == [1, 2] + assert list(ordered_set_intersection([[1, 2], [2, 1]])) == [1, 2] + assert list(ordered_set_intersection([[1, 2], [1, 3]])) == [1] + assert list(ordered_set_intersection([[1, 2], [2]])) == [2] + + +def test_join_dict_keys(): + dicts = [OrderedDict.fromkeys(keys) for keys in [['x', 'y'], ['y', 'z']]] + assert list(join_dict_keys(dicts, 'left')) == ['x', 'y'] + assert list(join_dict_keys(dicts, 'right')) == ['y', 'z'] + assert list(join_dict_keys(dicts, 'inner')) == ['y'] + assert list(join_dict_keys(dicts, 'outer')) == ['x', 'y', 'z'] + with pytest.raises(KeyError): + join_dict_keys(dicts, 'foobar') + + +def test_collect_dict_values(): + dicts = [{'x': 1, 'y': 2, 'z': 3}, {'z': 4}, 5] + expected = [[1, 0, 5], [2, 0, 5], [3, 4, 5]] + collected = collect_dict_values(dicts, ['x', 'y', 'z'], fill_value=0) + assert collected == expected + + +def test_apply_identity(): + array = np.arange(10) + variable = xr.Variable('x', array) + data_array = xr.DataArray(variable, [('x', -array)]) + dataset = xr.Dataset({'y': variable}, {'x': -array}) + + identity = functools.partial(apply_ufunc, lambda x: x) + + assert_identical(array, identity(array)) + assert_identical(variable, identity(variable)) + assert_identical(data_array, identity(data_array)) + assert_identical(data_array, identity(data_array.groupby('x'))) + assert_identical(dataset, identity(dataset)) + assert_identical(dataset, identity(dataset.groupby('x'))) + + +def add(a, b): + return apply_ufunc(operator.add, a, b) + + +def test_apply_two_inputs(): + array = np.array([1, 2, 3]) + variable = xr.Variable('x', array) + data_array = xr.DataArray(variable, [('x', -array)]) + dataset = xr.Dataset({'y': variable}, {'x': -array}) + + zero_array = np.zeros_like(array) + zero_variable = xr.Variable('x', zero_array) + zero_data_array = xr.DataArray(zero_variable, [('x', -array)]) + zero_dataset = xr.Dataset({'y': zero_variable}, {'x': -array}) + + assert_identical(array, add(array, zero_array)) + assert_identical(array, add(zero_array, array)) + + assert_identical(variable, add(variable, zero_array)) + assert_identical(variable, add(variable, zero_variable)) + assert_identical(variable, add(zero_array, variable)) + assert_identical(variable, add(zero_variable, variable)) + + assert_identical(data_array, add(data_array, zero_array)) + assert_identical(data_array, add(data_array, zero_variable)) + assert_identical(data_array, add(data_array, zero_data_array)) + assert_identical(data_array, add(zero_array, data_array)) + assert_identical(data_array, add(zero_variable, data_array)) + assert_identical(data_array, add(zero_data_array, data_array)) + + assert_identical(dataset, add(dataset, zero_array)) + assert_identical(dataset, add(dataset, zero_variable)) + assert_identical(dataset, add(dataset, zero_data_array)) + assert_identical(dataset, add(dataset, zero_dataset)) + assert_identical(dataset, add(zero_array, dataset)) + assert_identical(dataset, add(zero_variable, dataset)) + assert_identical(dataset, add(zero_data_array, dataset)) + assert_identical(dataset, add(zero_dataset, dataset)) + + assert_identical(data_array, add(data_array.groupby('x'), zero_data_array)) + assert_identical(data_array, add(zero_data_array, data_array.groupby('x'))) + + assert_identical(dataset, add(data_array.groupby('x'), zero_dataset)) + assert_identical(dataset, add(zero_dataset, data_array.groupby('x'))) + + assert_identical(dataset, add(dataset.groupby('x'), zero_data_array)) + assert_identical(dataset, add(dataset.groupby('x'), zero_dataset)) + assert_identical(dataset, add(zero_data_array, dataset.groupby('x'))) + assert_identical(dataset, add(zero_dataset, dataset.groupby('x'))) + + +def test_apply_1d_and_0d(): + array = np.array([1, 2, 3]) + variable = xr.Variable('x', array) + data_array = xr.DataArray(variable, [('x', -array)]) + dataset = xr.Dataset({'y': variable}, {'x': -array}) + + zero_array = 0 + zero_variable = xr.Variable((), zero_array) + zero_data_array = xr.DataArray(zero_variable) + zero_dataset = xr.Dataset({'y': zero_variable}) + + assert_identical(array, add(array, zero_array)) + assert_identical(array, add(zero_array, array)) + + assert_identical(variable, add(variable, zero_array)) + assert_identical(variable, add(variable, zero_variable)) + assert_identical(variable, add(zero_array, variable)) + assert_identical(variable, add(zero_variable, variable)) + + assert_identical(data_array, add(data_array, zero_array)) + assert_identical(data_array, add(data_array, zero_variable)) + assert_identical(data_array, add(data_array, zero_data_array)) + assert_identical(data_array, add(zero_array, data_array)) + assert_identical(data_array, add(zero_variable, data_array)) + assert_identical(data_array, add(zero_data_array, data_array)) + + assert_identical(dataset, add(dataset, zero_array)) + assert_identical(dataset, add(dataset, zero_variable)) + assert_identical(dataset, add(dataset, zero_data_array)) + assert_identical(dataset, add(dataset, zero_dataset)) + assert_identical(dataset, add(zero_array, dataset)) + assert_identical(dataset, add(zero_variable, dataset)) + assert_identical(dataset, add(zero_data_array, dataset)) + assert_identical(dataset, add(zero_dataset, dataset)) + + assert_identical(data_array, add(data_array.groupby('x'), zero_data_array)) + assert_identical(data_array, add(zero_data_array, data_array.groupby('x'))) + + assert_identical(dataset, add(data_array.groupby('x'), zero_dataset)) + assert_identical(dataset, add(zero_dataset, data_array.groupby('x'))) + + assert_identical(dataset, add(dataset.groupby('x'), zero_data_array)) + assert_identical(dataset, add(dataset.groupby('x'), zero_dataset)) + assert_identical(dataset, add(zero_data_array, dataset.groupby('x'))) + assert_identical(dataset, add(zero_dataset, dataset.groupby('x'))) + + +def test_apply_two_outputs(): + array = np.arange(5) + variable = xr.Variable('x', array) + data_array = xr.DataArray(variable, [('x', -array)]) + dataset = xr.Dataset({'y': variable}, {'x': -array}) + + def twice(obj): + func = lambda x: (x, x) + signature = '()->(),()' + return apply_ufunc(func, obj, signature=signature) + + out0, out1 = twice(array) + assert_identical(out0, array) + assert_identical(out1, array) + + out0, out1 = twice(variable) + assert_identical(out0, variable) + assert_identical(out1, variable) + + out0, out1 = twice(data_array) + assert_identical(out0, data_array) + assert_identical(out1, data_array) + + out0, out1 = twice(dataset) + assert_identical(out0, dataset) + assert_identical(out1, dataset) + + out0, out1 = twice(data_array.groupby('x')) + assert_identical(out0, data_array) + assert_identical(out1, data_array) + + out0, out1 = twice(dataset.groupby('x')) + assert_identical(out0, dataset) + assert_identical(out1, dataset) + + +def test_apply_input_core_dimension(): + + def first_element(obj, dim): + func = lambda x: x[..., 0] + sig = ([(dim,)], [()]) + return apply_ufunc(func, obj, signature=sig) + + array = np.array([[1, 2], [3, 4]]) + variable = xr.Variable(['x', 'y'], array) + data_array = xr.DataArray(variable, {'x': ['a', 'b'], 'y': [-1, -2]}) + dataset = xr.Dataset({'data': data_array}) + + expected_variable_x = xr.Variable(['y'], [1, 2]) + expected_data_array_x = xr.DataArray(expected_variable_x, {'y': [-1, -2]}) + expected_dataset_x = xr.Dataset({'data': expected_data_array_x}) + + expected_variable_y = xr.Variable(['x'], [1, 3]) + expected_data_array_y = xr.DataArray(expected_variable_y, + {'x': ['a', 'b']}) + expected_dataset_y = xr.Dataset({'data': expected_data_array_y}) + + assert_identical(expected_variable_x, first_element(variable, 'x')) + assert_identical(expected_variable_y, first_element(variable, 'y')) + + assert_identical(expected_data_array_x, first_element(data_array, 'x')) + assert_identical(expected_data_array_y, first_element(data_array, 'y')) + + assert_identical(expected_dataset_x, first_element(dataset, 'x')) + assert_identical(expected_dataset_y, first_element(dataset, 'y')) + + assert_identical(expected_data_array_x, + first_element(data_array.groupby('y'), 'x')) + assert_identical(expected_dataset_x, + first_element(dataset.groupby('y'), 'x')) + + +def test_apply_output_core_dimension(): + + def stack_negative(obj): + func = lambda x: xr.core.npcompat.stack([x, -x], axis=-1) + sig = ([()], [('sign',)]) + result = apply_ufunc(func, obj, signature=sig) + if isinstance(result, (xr.Dataset, xr.DataArray)): + result.coords['sign'] = [1, -1] + return result + + array = np.array([[1, 2], [3, 4]]) + variable = xr.Variable(['x', 'y'], array) + data_array = xr.DataArray(variable, {'x': ['a', 'b'], 'y': [-1, -2]}) + dataset = xr.Dataset({'data': data_array}) + + stacked_array = np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]) + stacked_variable = xr.Variable(['x', 'y', 'sign'], stacked_array) + stacked_coords = {'x': ['a', 'b'], 'y': [-1, -2], 'sign': [1, -1]} + stacked_data_array = xr.DataArray(stacked_variable, stacked_coords) + stacked_dataset = xr.Dataset({'data': stacked_data_array}) + + assert_identical(stacked_array, stack_negative(array)) + assert_identical(stacked_variable, stack_negative(variable)) + assert_identical(stacked_data_array, stack_negative(data_array)) + assert_identical(stacked_dataset, stack_negative(dataset)) + assert_identical(stacked_data_array, + stack_negative(data_array.groupby('x'))) + assert_identical(stacked_dataset, + stack_negative(dataset.groupby('x'))) + + def original_and_stack_negative(obj): + func = lambda x: (x, xr.core.npcompat.stack([x, -x], axis=-1)) + sig = ([()], [(), ('sign',)]) + result = apply_ufunc(func, obj, signature=sig) + if isinstance(result[1], (xr.Dataset, xr.DataArray)): + result[1].coords['sign'] = [1, -1] + return result + + out0, out1 = original_and_stack_negative(array) + assert_identical(array, out0) + assert_identical(stacked_array, out1) + + out0, out1 = original_and_stack_negative(variable) + assert_identical(variable, out0) + assert_identical(stacked_variable, out1) + + out0, out1 = original_and_stack_negative(data_array) + assert_identical(data_array, out0) + assert_identical(stacked_data_array, out1) + + out0, out1 = original_and_stack_negative(dataset) + assert_identical(dataset, out0) + assert_identical(stacked_dataset, out1) + + out0, out1 = original_and_stack_negative(data_array.groupby('x')) + assert_identical(data_array, out0) + assert_identical(stacked_data_array, out1) + + out0, out1 = original_and_stack_negative(dataset.groupby('x')) + assert_identical(dataset, out0) + assert_identical(stacked_dataset, out1) + + def stack_invalid(obj): + func = lambda x: xr.core.npcompat.stack([x, -x], axis=-1) + sig = ([()], [('sign',)]) + # no new_coords + return apply_ufunc(func, obj, signature=sig) + + +def test_apply_exclude(): + + def concatenate(objects, dim='x'): + sig = ([(dim,)] * len(objects), [(dim,)]) + new_coord = np.concatenate( + [obj.coords[dim] if hasattr(obj, 'coords') else [] + for obj in objects]) + func = lambda *x: np.concatenate(x, axis=-1) + result = apply_ufunc(func, *objects, signature=sig, exclude_dims={dim}) + if isinstance(result, (xr.Dataset, xr.DataArray)): + result.coords[dim] = new_coord + return result + + arrays = [np.array([1]), np.array([2, 3])] + variables = [xr.Variable('x', a) for a in arrays] + data_arrays = [xr.DataArray(v, {'x': c, 'y': ('x', range(len(c)))}) + for v, c in zip(variables, [['a'], ['b', 'c']])] + datasets = [xr.Dataset({'data': data_array}) for data_array in data_arrays] + + expected_array = np.array([1, 2, 3]) + expected_variable = xr.Variable('x', expected_array) + expected_data_array = xr.DataArray(expected_variable, [('x', list('abc'))]) + expected_dataset = xr.Dataset({'data': expected_data_array}) + + assert_identical(expected_array, concatenate(arrays)) + assert_identical(expected_variable, concatenate(variables)) + assert_identical(expected_data_array, concatenate(data_arrays)) + assert_identical(expected_dataset, concatenate(datasets)) + + identity = lambda x: x + # must also be a core dimension + with pytest.raises(ValueError): + apply_ufunc(identity, variables[0], exclude_dims={'x'}) + + +def test_apply_groupby_add(): + array = np.arange(5) + variable = xr.Variable('x', array) + coords = {'x': -array, 'y': ('x', [0, 0, 1, 1, 2])} + data_array = xr.DataArray(variable, coords, dims='x') + dataset = xr.Dataset({'z': variable}, coords) + + other_variable = xr.Variable('y', [0, 10]) + other_data_array = xr.DataArray(other_variable, dims='y') + other_dataset = xr.Dataset({'z': other_variable}) + + expected_variable = xr.Variable('x', [0, 1, 12, 13, np.nan]) + expected_data_array = xr.DataArray(expected_variable, coords, dims='x') + expected_dataset = xr.Dataset({'z': expected_variable}, coords) + + assert_identical(expected_data_array, + add(data_array.groupby('y'), other_data_array)) + assert_identical(expected_dataset, + add(data_array.groupby('y'), other_dataset)) + assert_identical(expected_dataset, + add(dataset.groupby('y'), other_data_array)) + assert_identical(expected_dataset, + add(dataset.groupby('y'), other_dataset)) + + # cannot be performed with xarray.Variable objects that share a dimension + with pytest.raises(ValueError): + add(data_array.groupby('y'), other_variable) + + # if they are all grouped the same way + with pytest.raises(ValueError): + add(data_array.groupby('y'), data_array[:4].groupby('y')) + with pytest.raises(ValueError): + add(data_array.groupby('y'), data_array[1:].groupby('y')) + with pytest.raises(ValueError): + add(data_array.groupby('y'), other_data_array.groupby('y')) + with pytest.raises(ValueError): + add(data_array.groupby('y'), data_array.groupby('x')) + + +def test_unified_dim_sizes(): + assert unified_dim_sizes([xr.Variable((), 0)]) == OrderedDict() + assert (unified_dim_sizes([xr.Variable('x', [1]), + xr.Variable('x', [1])]) == + OrderedDict([('x', 1)])) + assert (unified_dim_sizes([xr.Variable('x', [1]), + xr.Variable('y', [1, 2])]) == + OrderedDict([('x', 1), ('y', 2)])) + assert (unified_dim_sizes([xr.Variable(('x', 'z'), [[1]]), + xr.Variable(('y', 'z'), [[1, 2], [3, 4]])], + exclude_dims={'z'}) == + OrderedDict([('x', 1), ('y', 2)])) + + # duplicate dimensions + with pytest.raises(ValueError): + unified_dim_sizes([xr.Variable(('x', 'x'), [[1]])]) + + # mismatched lengths + with pytest.raises(ValueError): + unified_dim_sizes( + [xr.Variable('x', [1]), xr.Variable('x', [1, 2])]) + + +def test_broadcast_compat_data_1d(): + data = np.arange(5) + var = xr.Variable('x', data) + + assert_identical(data, broadcast_compat_data(var, ('x',), ())) + assert_identical(data, broadcast_compat_data(var, (), ('x',))) + assert_identical(data[None, :], broadcast_compat_data(var, ('w',), ('x',))) + assert_identical(data[None, :, None], + broadcast_compat_data(var, ('w', 'x', 'y'), ())) + + with pytest.raises(ValueError): + broadcast_compat_data(var, ('x',), ('w',)) + + with pytest.raises(ValueError): + broadcast_compat_data(var, (), ()) + + +def test_broadcast_compat_data_2d(): + data = np.arange(12).reshape(3, 4) + var = xr.Variable(['x', 'y'], data) + + assert_identical(data, broadcast_compat_data(var, ('x', 'y'), ())) + assert_identical(data, broadcast_compat_data(var, ('x',), ('y',))) + assert_identical(data, broadcast_compat_data(var, (), ('x', 'y'))) + assert_identical(data.T, broadcast_compat_data(var, ('y', 'x'), ())) + assert_identical(data.T, broadcast_compat_data(var, ('y',), ('x',))) + assert_identical(data[None, :, :], + broadcast_compat_data(var, ('w', 'x'), ('y',))) + assert_identical(data[None, :, :], + broadcast_compat_data(var, ('w',), ('x', 'y'))) + assert_identical(data.T[None, :, :], + broadcast_compat_data(var, ('w',), ('y', 'x'))) + assert_identical(data[None, :, :, None], + broadcast_compat_data(var, ('w', 'x', 'y', 'z'), ())) + assert_identical(data.T[None, :, :, None], + broadcast_compat_data(var, ('w', 'y', 'x', 'z'), ())) + + +class _NoCacheVariable(xr.Variable): + """Subclass of Variable for testing that does not cache values.""" + # TODO: remove this class when we change the default behavior for caching + # dask.array objects. + def _data_cached(self): + return np.asarray(self._data) + + +@requires_dask +def test_apply_dask(): + import dask.array as da + + array = da.ones((2,), chunks=2) + variable = _NoCacheVariable('x', array) + coords = xr.DataArray(variable).coords.variables + data_array = xr.DataArray(variable, coords, fastpath=True) + dataset = xr.Dataset({'y': variable}) + + identity = lambda x: x + + # encountered dask array, but did not set dask_array='allowed' + with pytest.raises(ValueError): + apply_ufunc(identity, array) + with pytest.raises(ValueError): + apply_ufunc(identity, variable) + with pytest.raises(ValueError): + apply_ufunc(identity, data_array) + with pytest.raises(ValueError): + apply_ufunc(identity, dataset) + + # unknown setting for dask array handling + with pytest.raises(ValueError): + apply_ufunc(identity, array, dask_array='auto') + + def dask_safe_identity(x): + return apply_ufunc(identity, x, dask_array='allowed') + + assert array is dask_safe_identity(array) + + actual = dask_safe_identity(variable) + assert isinstance(actual.data, da.Array) + assert_identical(variable, actual) + + actual = dask_safe_identity(data_array) + assert isinstance(actual.data, da.Array) + assert_identical(data_array, actual) + + actual = dask_safe_identity(dataset) + assert isinstance(actual['y'].data, da.Array) + assert_identical(dataset, actual)