
[WIP] Custom fill value for reindex, align, and merge operations (pydata#2920)

* add fill_value option to align and reindex functions

* add fill_value tests for reindex and align

* add fill_value option for merge functions

* add tests for fill_value merge implementation

* implement and test fill_value option in dataarray reindex methods

* fix PEP 8 issue

* move function signature onto function

* Add fill_value enhancement note
zdgriffith authored and shoyer committed May 5, 2019
1 parent ccd0b04 commit 5aaa654
Showing 8 changed files with 159 additions and 48 deletions.
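
For orientation, a minimal usage sketch of the behaviour this commit adds (the array, coordinates, and values below are illustrative, not taken from the changeset):

import xarray as xr

da = xr.DataArray([1, 2, 3], dims='x', coords={'x': [0, 1, 2]})

# Previously, labels introduced by reindexing were always filled with NaN;
# with the new argument they can be filled with a caller-chosen scalar.
padded = da.reindex(x=[0, 1, 2, 3], fill_value=-1)
print(padded.values)  # expected: [ 1  2  3 -1]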
3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -21,6 +21,9 @@ v0.12.2 (unreleased)
Enhancements
~~~~~~~~~~~~

- Add ``fill_value`` argument for reindex, align, and merge operations
to enable custom fill values. (:issue:`2876`)
By `Zach Griffith <https://github.com/zdgriffith>`_.
- Character arrays' character dimension name decoding and encoding handled by
``var.encoding['char_dim_name']`` (:issue:`2895`)
By `James McCreight <https://github.com/jmccreight>`_.
41 changes: 19 additions & 22 deletions xarray/core/alignment.py
@@ -8,7 +8,7 @@
import numpy as np
import pandas as pd

from . import utils
from . import utils, dtypes
from .indexing import get_indexer_nd
from .utils import is_dict_like, is_full_slice
from .variable import IndexVariable, Variable
@@ -31,20 +31,17 @@ def _get_joiner(join):
raise ValueError('invalid value for join: %s' % join)


_DEFAULT_EXCLUDE = frozenset() # type: frozenset


def align(*objects, **kwargs):
"""align(*objects, join='inner', copy=True, indexes=None,
exclude=frozenset())
def align(*objects, join='inner', copy=True, indexes=None, exclude=frozenset(),
fill_value=dtypes.NA):
"""
Given any number of Dataset and/or DataArray objects, returns new
objects with aligned indexes and dimension sizes.
Arrays from the aligned objects are suitable as input to mathematical
operators, because along each dimension they have the same index and size.
Missing values (if ``join != 'inner'``) are filled with NaN.
Missing values (if ``join != 'inner'``) are filled with ``fill_value``.
The default fill value is NaN.
Parameters
----------
@@ -65,11 +62,13 @@ def align(*objects, **kwargs):
``copy=False`` and reindexing is unnecessary, or can be performed with
only slice operations, then the output may share memory with the input.
In either case, new xarray objects are always returned.
exclude : sequence of str, optional
Dimensions that must be excluded from alignment
indexes : dict-like, optional
Any indexes explicitly provided with the `indexes` argument should be
used in preference to the aligned indexes.
exclude : sequence of str, optional
Dimensions that must be excluded from alignment
fill_value : scalar, optional
Value to use for newly missing values
Returns
-------
@@ -82,15 +81,8 @@ def align(*objects, **kwargs):
If any dimensions without labels on the arguments have different sizes,
or a different size than the size of the aligned dimension labels.
"""
join = kwargs.pop('join', 'inner')
copy = kwargs.pop('copy', True)
indexes = kwargs.pop('indexes', None)
exclude = kwargs.pop('exclude', _DEFAULT_EXCLUDE)
if indexes is None:
indexes = {}
if kwargs:
raise TypeError('align() got unexpected keyword arguments: %s'
% list(kwargs))

if not indexes and len(objects) == 1:
# fast path for the trivial case
@@ -162,15 +154,17 @@ def align(*objects, **kwargs):
# fast path for no reindexing necessary
new_obj = obj.copy(deep=copy)
else:
new_obj = obj.reindex(copy=copy, **valid_indexers)
new_obj = obj.reindex(copy=copy, fill_value=fill_value,
**valid_indexers)
new_obj.encoding = obj.encoding
result.append(new_obj)

return tuple(result)


def deep_align(objects, join='inner', copy=True, indexes=None,
exclude=frozenset(), raise_on_invalid=True):
exclude=frozenset(), raise_on_invalid=True,
fill_value=dtypes.NA):
"""Align objects for merging, recursing into dictionary values.
This function is not public API.
@@ -214,7 +208,7 @@ def is_alignable(obj):
out.append(variables)

aligned = align(*targets, join=join, copy=copy, indexes=indexes,
exclude=exclude)
exclude=exclude, fill_value=fill_value)

for position, key, aligned_obj in zip(positions, keys, aligned):
if key is no_key:
@@ -270,6 +264,7 @@ def reindex_variables(
method: Optional[str] = None,
tolerance: Any = None,
copy: bool = True,
fill_value: Optional[Any] = dtypes.NA,
) -> 'Tuple[OrderedDict[Any, Variable], OrderedDict[Any, pd.Index]]':
"""Conform a dictionary of aligned variables onto a new set of variables,
filling in missing values with NaN.
@@ -305,6 +300,8 @@ def reindex_variables(
``copy=False`` and reindexing is unnecessary, or can be performed
with only slice operations, then the output may share memory with
the input. In either case, new xarray objects are always returned.
fill_value : scalar, optional
Value to use for newly missing values
Returns
-------
@@ -380,7 +377,7 @@ def reindex_variables(
needs_masking = any(d in masked_dims for d in var.dims)

if needs_masking:
new_var = var._getitem_with_mask(key)
new_var = var._getitem_with_mask(key, fill_value=fill_value)
elif all(is_full_slice(k) for k in key):
# no reindexing necessary
# here we need to manually deal with copying data, since
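
A short sketch of how the new ``fill_value`` argument to ``align`` would be exercised, assuming the signature shown above (the example arrays are hypothetical):

import xarray as xr

a = xr.DataArray([1, 2], dims='x', coords={'x': [0, 1]})
b = xr.DataArray([3, 4], dims='x', coords={'x': [1, 2]})

# With join='outer', labels present in only one object used to become NaN;
# fill_value now chooses the placeholder used for those positions instead.
a2, b2 = xr.align(a, b, join='outer', fill_value=0)
print(a2.values)  # expected: [1 2 0]
print(b2.values)  # expected: [0 3 4]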
22 changes: 14 additions & 8 deletions xarray/core/dataarray.py
@@ -879,9 +879,10 @@ def sel_points(self, dim='points', method=None, tolerance=None,
dim=dim, method=method, tolerance=tolerance, **indexers)
return self._from_temp_dataset(ds)

def reindex_like(self, other, method=None, tolerance=None, copy=True):
"""Conform this object onto the indexes of another object, filling
in missing values with NaN.
def reindex_like(self, other, method=None, tolerance=None, copy=True,
fill_value=dtypes.NA):
"""Conform this object onto the indexes of another object, filling in
missing values with ``fill_value``. The default fill value is NaN.
Parameters
----------
@@ -910,6 +911,8 @@ def reindex_like(self, other, method=None, tolerance=None, copy=True):
``copy=False`` and reindexing is unnecessary, or can be performed
with only slice operations, then the output may share memory with
the input. In either case, a new xarray object is always returned.
fill_value : scalar, optional
Value to use for newly missing values
Returns
-------
@@ -924,12 +927,12 @@ def reindex_like(self, other, method=None, tolerance=None, copy=True):
"""
indexers = reindex_like_indexers(self, other)
return self.reindex(method=method, tolerance=tolerance, copy=copy,
**indexers)
fill_value=fill_value, **indexers)

def reindex(self, indexers=None, method=None, tolerance=None, copy=True,
**indexers_kwargs):
"""Conform this object onto a new set of indexes, filling in
missing values with NaN.
fill_value=dtypes.NA, **indexers_kwargs):
"""Conform this object onto the indexes of another object, filling in
missing values with ``fill_value``. The default fill value is NaN.
Parameters
----------
@@ -956,6 +959,8 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True,
Maximum distance between original and new labels for inexact
matches. The values of the index at the matching locations must
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
fill_value : scalar, optional
Value to use for newly missing values
**indexers_kwargs : {dim: indexer, ...}, optional
The keyword arguments form of ``indexers``.
One of indexers or indexers_kwargs must be provided.
@@ -974,7 +979,8 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True,
indexers = either_dict_or_kwargs(
indexers, indexers_kwargs, 'reindex')
ds = self._to_temp_dataset().reindex(
indexers=indexers, method=method, tolerance=tolerance, copy=copy)
indexers=indexers, method=method, tolerance=tolerance, copy=copy,
fill_value=fill_value)
return self._from_temp_dataset(ds)

def interp(self, coords=None, method='linear', assume_sorted=False,
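
A sketch of the DataArray-level behaviour, assuming the signatures above (the array and template below are made up for illustration):

import xarray as xr

da = xr.DataArray([10, 20], dims='x', coords={'x': [0, 1]})
template = xr.DataArray([0, 0, 0], dims='x', coords={'x': [0, 1, 2]})

# reindex_like forwards fill_value to reindex, so the padded entry becomes
# -999 instead of NaN (NaN would also force integer data to float).
result = da.reindex_like(template, fill_value=-999)
print(result.values)  # expected: [10 20 -999]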
25 changes: 16 additions & 9 deletions xarray/core/dataset.py
@@ -1932,9 +1932,10 @@ def sel_points(self, dim='points', method=None, tolerance=None,
)
return self.isel_points(dim=dim, **pos_indexers)

def reindex_like(self, other, method=None, tolerance=None, copy=True):
"""Conform this object onto the indexes of another object, filling
in missing values with NaN.
def reindex_like(self, other, method=None, tolerance=None, copy=True,
fill_value=dtypes.NA):
"""Conform this object onto the indexes of another object, filling in
missing values with ``fill_value``. The default fill value is NaN.
Parameters
----------
@@ -1963,6 +1964,8 @@ def reindex_like(self, other, method=None, tolerance=None, copy=True):
``copy=False`` and reindexing is unnecessary, or can be performed
with only slice operations, then the output may share memory with
the input. In either case, a new xarray object is always returned.
fill_value : scalar, optional
Value to use for newly missing values
Returns
-------
@@ -1977,12 +1980,12 @@ def reindex_like(self, other, method=None, tolerance=None, copy=True):
"""
indexers = alignment.reindex_like_indexers(self, other)
return self.reindex(indexers=indexers, method=method, copy=copy,
tolerance=tolerance)
fill_value=fill_value, tolerance=tolerance)

def reindex(self, indexers=None, method=None, tolerance=None, copy=True,
**indexers_kwargs):
fill_value=dtypes.NA, **indexers_kwargs):
"""Conform this object onto a new set of indexes, filling in
missing values with NaN.
missing values with ``fill_value``. The default fill value is NaN.
Parameters
----------
Expand Down Expand Up @@ -2010,6 +2013,8 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True,
``copy=False`` and reindexing is unnecessary, or can be performed
with only slice operations, then the output may share memory with
the input. In either case, a new xarray object is always returned.
fill_value : scalar, optional
Value to use for newly missing values
**indexers_kwargs : {dim: indexer, ...}, optional
Keyword arguments in the same form as ``indexers``.
One of indexers or indexers_kwargs must be provided.
@@ -2034,7 +2039,7 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True,

variables, indexes = alignment.reindex_variables(
self.variables, self.sizes, self.indexes, indexers, method,
tolerance, copy=copy)
tolerance, copy=copy, fill_value=fill_value)
coord_names = set(self._coord_names)
coord_names.update(indexers)
return self._replace_with_new_dims(
@@ -2752,7 +2757,7 @@ def update(self, other, inplace=None):
inplace=inplace)

def merge(self, other, inplace=None, overwrite_vars=frozenset(),
compat='no_conflicts', join='outer'):
compat='no_conflicts', join='outer', fill_value=dtypes.NA):
"""Merge the arrays of two datasets into a single dataset.
This method generally does not allow for overriding data, with the exception
@@ -2790,6 +2795,8 @@ def merge(self, other, inplace=None, overwrite_vars=frozenset(),
- 'left': use indexes from ``self``
- 'right': use indexes from ``other``
- 'exact': error instead of aligning non-equal indexes
fill_value : scalar, optional
Value to use for newly missing values
Returns
-------
@@ -2804,7 +2811,7 @@ def merge(self, other, inplace=None, overwrite_vars=frozenset(),
inplace = _check_inplace(inplace)
variables, coord_names, dims = dataset_merge_method(
self, other, overwrite_vars=overwrite_vars, compat=compat,
join=join)
join=join, fill_value=fill_value)

return self._replace_vars_and_dims(variables, coord_names, dims,
inplace=inplace)
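
A sketch of Dataset.merge with the new argument (the datasets below are hypothetical):

import xarray as xr

ds1 = xr.Dataset({'a': ('x', [1, 2])}, coords={'x': [0, 1]})
ds2 = xr.Dataset({'b': ('x', [3, 4])}, coords={'x': [1, 2]})

# The default join='outer' aligns both datasets onto x = [0, 1, 2]; positions
# not covered by 'a' or 'b' are filled with 0 rather than the default NaN.
merged = ds1.merge(ds2, fill_value=0)
print(merged['a'].values)  # expected: [1 2 0]
print(merged['b'].values)  # expected: [0 3 4]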
27 changes: 19 additions & 8 deletions xarray/core/merge.py
@@ -4,6 +4,7 @@

import pandas as pd

from . import dtypes
from .alignment import deep_align
from .pycompat import TYPE_CHECKING
from .utils import Frozen
@@ -349,7 +350,7 @@ def expand_and_merge_variables(objs, priority_arg=None):


def merge_coords(objs, compat='minimal', join='outer', priority_arg=None,
indexes=None):
indexes=None, fill_value=dtypes.NA):
"""Merge coordinate variables.
See merge_core below for argument descriptions. This works similarly to
@@ -358,7 +359,8 @@ def merge_coords(objs, compat='minimal', join='outer', priority_arg=None,
"""
_assert_compat_valid(compat)
coerced = coerce_pandas_values(objs)
aligned = deep_align(coerced, join=join, copy=False, indexes=indexes)
aligned = deep_align(coerced, join=join, copy=False, indexes=indexes,
fill_value=fill_value)
expanded = expand_variable_dicts(aligned)
priority_vars = _get_priority_vars(aligned, priority_arg, compat=compat)
variables = merge_variables(expanded, priority_vars, compat=compat)
@@ -404,7 +406,8 @@ def merge_core(objs,
join='outer',
priority_arg=None,
explicit_coords=None,
indexes=None):
indexes=None,
fill_value=dtypes.NA):
"""Core logic for merging labeled objects.
This is not public API.
@@ -423,6 +426,8 @@ def merge_core(objs,
An explicit list of variables from `objs` that are coordinates.
indexes : dict, optional
Dictionary with values given by pandas.Index objects.
fill_value : scalar, optional
Value to use for newly missing values
Returns
-------
@@ -442,7 +447,8 @@ def merge_core(objs,
_assert_compat_valid(compat)

coerced = coerce_pandas_values(objs)
aligned = deep_align(coerced, join=join, copy=False, indexes=indexes)
aligned = deep_align(coerced, join=join, copy=False, indexes=indexes,
fill_value=fill_value)
expanded = expand_variable_dicts(aligned)

coord_names, noncoord_names = determine_coords(coerced)
@@ -470,7 +476,7 @@ def merge_core(objs,
return variables, coord_names, dict(dims)


def merge(objects, compat='no_conflicts', join='outer'):
def merge(objects, compat='no_conflicts', join='outer', fill_value=dtypes.NA):
"""Merge any number of xarray objects into a single Dataset as variables.
Parameters
@@ -492,6 +498,8 @@ def merge(objects, compat='no_conflicts', join='outer'):
of all non-null values.
join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
How to combine objects with different indexes.
fill_value : scalar, optional
Value to use for newly missing values
Returns
-------
@@ -529,15 +537,17 @@ def merge(objects, compat='no_conflicts', join='outer'):
obj.to_dataset() if isinstance(obj, DataArray) else obj
for obj in objects]

variables, coord_names, dims = merge_core(dict_like_objects, compat, join)
variables, coord_names, dims = merge_core(dict_like_objects, compat, join,
fill_value=fill_value)
# TODO: don't always recompute indexes
merged = Dataset._construct_direct(
variables, coord_names, dims, indexes=None)

return merged


def dataset_merge_method(dataset, other, overwrite_vars, compat, join):
def dataset_merge_method(dataset, other, overwrite_vars, compat, join,
fill_value=dtypes.NA):
"""Guts of the Dataset.merge method."""

# we are locked into supporting overwrite_vars for the Dataset.merge
@@ -565,7 +575,8 @@ def dataset_merge_method(dataset, other, overwrite_vars, compat, join):
objs = [dataset, other_no_overwrite, other_overwrite]
priority_arg = 2

return merge_core(objs, compat, join, priority_arg=priority_arg)
return merge_core(objs, compat, join, priority_arg=priority_arg,
fill_value=fill_value)


def dataset_update_method(dataset, other):
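
Finally, a sketch of the top-level merge function, which threads ``fill_value`` through ``merge_core`` and ``deep_align`` down to ``reindex`` (the named arrays below are illustrative):

import xarray as xr

da1 = xr.DataArray([1, 2], dims='x', coords={'x': [0, 1]}, name='a')
da2 = xr.DataArray([3, 4], dims='x', coords={'x': [1, 2]}, name='b')

# fill_value is passed merge -> merge_core -> deep_align -> align -> reindex,
# so the padded entries are 0 rather than NaN.
merged = xr.merge([da1, da2], fill_value=0)
print(merged['a'].values)  # expected: [1 2 0]
print(merged['b'].values)  # expected: [0 3 4]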
