-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
combine_first by using apply_ufunc in ops.fillna #1204
Changes from 1 commit
8c6a9ab
2fdd135
c190f84
02a4a14
f0c0866
b172eda
6e65c8b
67a599f
8c46c51
3c59009
a72e2a1
8bf856e
56c752a
4bf4efe
367e3ca
793552e
f5ebf78
998cd71
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -207,7 +207,6 @@ def apply_dataarray_ufunc(func, *args, **kwargs): | |
join = kwargs.pop('join', 'inner') | ||
exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET) | ||
keep_attrs = kwargs.pop('keep_attrs', False) | ||
first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True | ||
if kwargs: | ||
raise TypeError('apply_dataarray_ufunc() got unexpected keyword ' | ||
'arguments: %s' % list(kwargs)) | ||
|
@@ -229,11 +228,11 @@ def apply_dataarray_ufunc(func, *args, **kwargs): | |
coords, = result_coords | ||
out = DataArray(result_var, coords, name=name, fastpath=True) | ||
|
||
if keep_attrs and isinstance(first_obj, DataArray): | ||
if keep_attrs and isinstance(args[0], DataArray): | ||
if isinstance(out, tuple): | ||
out = tuple(ds._copy_attrs_from(first_obj) for ds in out) | ||
out = tuple(ds._copy_attrs_from(args[0]) for ds in out) | ||
else: | ||
out._copy_attrs_from(first_obj) | ||
out._copy_attrs_from(args[0]) | ||
return out | ||
|
||
def ordered_set_union(all_keys): | ||
|
@@ -270,9 +269,6 @@ def join_dict_keys(objects, how='inner'): | |
|
||
def collect_dict_values(objects, keys, fill_value=None): | ||
# type: (Iterable[Union[Mapping, Any]], Iterable, Any) -> List[list] | ||
if fill_value is _DEFAULT_FILL_VALUE: | ||
raise ValueError('Inappropriate fill value for Dataset: {}' | ||
.format(fill_value)) | ||
return [[obj.get(key, fill_value) | ||
if is_dict_like(obj) | ||
else obj | ||
|
@@ -346,11 +342,17 @@ def apply_dataset_ufunc(func, *args, **kwargs): | |
from .dataset import Dataset | ||
signature = kwargs.pop('signature') | ||
join = kwargs.pop('join', 'inner') | ||
data_vars_join = kwargs.pop('data_vars_join', 'inner') | ||
dataset_join = kwargs.pop('dataset_join', 'inner') | ||
fill_value = kwargs.pop('fill_value', None) | ||
exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET) | ||
keep_attrs = kwargs.pop('keep_attrs', False) | ||
first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True | ||
|
||
if dataset_join != 'inner' and fill_value is _DEFAULT_FILL_VALUE: | ||
raise TypeError('To apply an operation to datasets with different ', | ||
'data variables, you must supply the ', | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. add […] |
||
'dataset_fill_value argument.') | ||
|
||
if kwargs: | ||
raise TypeError('apply_dataset_ufunc() got unexpected keyword ' | ||
'arguments: %s' % list(kwargs)) | ||
|
@@ -362,7 +364,7 @@ def apply_dataset_ufunc(func, *args, **kwargs): | |
args = [getattr(arg, 'data_vars', arg) for arg in args] | ||
|
||
result_vars = apply_dict_of_variables_ufunc( | ||
func, *args, signature=signature, join=data_vars_join, | ||
func, *args, signature=signature, join=dataset_join, | ||
fill_value=fill_value) | ||
|
||
if signature.n_outputs > 1: | ||
|
@@ -607,23 +609,23 @@ def apply_ufunc(func, *args, **kwargs): | |
- 'inner': use the intersection of object indexes | ||
- 'left': use indexes from the first object with each dimension | ||
- 'right': use indexes from the last object with each dimension | ||
data_vars_join : {'outer', 'inner', 'left', 'right'}, optional | ||
dataset_join : {'outer', 'inner', 'left', 'right'}, optional | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you also fix the function signature at the top of the docstring to add in the new arguments? |
||
Method for joining variables of Dataset objects with mismatched | ||
data variables. | ||
- 'outer': take variables from both Dataset objects | ||
- 'inner': take only overlapped variables | ||
- 'left': take only variables from the first object | ||
- 'right': take only variables from the last object | ||
dataset_fill_value : optional | ||
Value used in place of missing variables on Dataset inputs when the | ||
datasets do not share the exact same ``data_vars``. Only relevant if | ||
``dataset_join != 'inner'``. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe switch the last sentence to: "Required if ``dataset_join != 'inner'``, otherwise ignored." |
||
keep_attrs: boolean, Optional | ||
Whether to copy attributes from the first argument to the output. | ||
exclude_dims : set, optional | ||
Dimensions to exclude from alignment and broadcasting. Any inputs | ||
coordinates along these dimensions will be dropped. Each excluded | ||
dimension must be a core dimension in the function signature. | ||
dataset_fill_value : optional | ||
Value used in place of missing variables on Dataset inputs when the | ||
datasets do not share the exact same ``data_vars``. Only relevant if | ||
``join != 'inner'``. | ||
kwargs: dict, optional | ||
Optional keyword arguments passed directly on to call ``func``. | ||
dask_array: 'forbidden' or 'allowed', optional | ||
|
@@ -699,7 +701,7 @@ def stack(objects, dim, new_coord): | |
|
||
signature = kwargs.pop('signature', None) | ||
join = kwargs.pop('join', 'inner') | ||
data_vars_join = kwargs.pop('data_vars_join', 'inner') | ||
dataset_join = kwargs.pop('dataset_join', 'inner') | ||
keep_attrs = kwargs.pop('keep_attrs', False) | ||
exclude_dims = kwargs.pop('exclude_dims', frozenset()) | ||
dataset_fill_value = kwargs.pop('dataset_fill_value', _DEFAULT_FILL_VALUE) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we should rename this argument to […]? Or maybe the pair should be […]. |
||
|
@@ -735,14 +737,14 @@ def stack(objects, dim, new_coord): | |
apply_ufunc, func, signature=signature, join=join, | ||
dask_array=dask_array, exclude_dims=exclude_dims, | ||
dataset_fill_value=dataset_fill_value, | ||
data_vars_join=data_vars_join, | ||
dataset_join=dataset_join, | ||
keep_attrs=keep_attrs) | ||
return apply_groupby_ufunc(this_apply, *args) | ||
elif any(is_dict_like(a) for a in args): | ||
return apply_dataset_ufunc(variables_ufunc, *args, signature=signature, | ||
join=join, exclude_dims=exclude_dims, | ||
fill_value=dataset_fill_value, | ||
data_vars_join=data_vars_join, | ||
dataset_join=dataset_join, | ||
keep_attrs=keep_attrs) | ||
elif any(isinstance(a, DataArray) for a in args): | ||
return apply_dataarray_ufunc(variables_ufunc, *args, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,8 +89,7 @@ def test_apply_identity(): | |
data_array = xr.DataArray(variable, [('x', -array)]) | ||
dataset = xr.Dataset({'y': variable}, {'x': -array}) | ||
|
||
identity = functools.partial(apply_ufunc, lambda x: x, | ||
dataset_fill_value=np.nan) | ||
identity = functools.partial(apply_ufunc, lambda x: x) | ||
|
||
assert_identical(array, identity(array)) | ||
assert_identical(variable, identity(variable)) | ||
|
@@ -101,7 +100,7 @@ def test_apply_identity(): | |
|
||
|
||
def add(a, b): | ||
return apply_ufunc(operator.add, a, b, dataset_fill_value=np.nan) | ||
return apply_ufunc(operator.add, a, b) | ||
|
||
|
||
def test_apply_two_inputs(): | ||
|
@@ -207,8 +206,7 @@ def test_apply_two_outputs(): | |
def twice(obj): | ||
func = lambda x: (x, x) | ||
signature = '()->(),()' | ||
return apply_ufunc(func, obj, signature=signature, | ||
dataset_fill_value=np.nan) | ||
return apply_ufunc(func, obj, signature=signature) | ||
|
||
out0, out1 = twice(array) | ||
assert_identical(out0, array) | ||
|
@@ -240,8 +238,7 @@ def test_apply_input_core_dimension(): | |
def first_element(obj, dim): | ||
func = lambda x: x[..., 0] | ||
sig = ([(dim,)], [()]) | ||
return apply_ufunc(func, obj, signature=sig, | ||
dataset_fill_value=np.nan) | ||
return apply_ufunc(func, obj, signature=sig) | ||
|
||
array = np.array([[1, 2], [3, 4]]) | ||
variable = xr.Variable(['x', 'y'], array) | ||
|
@@ -277,8 +274,7 @@ def test_apply_output_core_dimension(): | |
def stack_negative(obj): | ||
func = lambda x: xr.core.npcompat.stack([x, -x], axis=-1) | ||
sig = ([()], [('sign',)]) | ||
result = apply_ufunc(func, obj, signature=sig, | ||
dataset_fill_value=np.nan) | ||
result = apply_ufunc(func, obj, signature=sig) | ||
if isinstance(result, (xr.Dataset, xr.DataArray)): | ||
result.coords['sign'] = [1, -1] | ||
return result | ||
|
@@ -306,8 +302,7 @@ def stack_negative(obj): | |
def original_and_stack_negative(obj): | ||
func = lambda x: (x, xr.core.npcompat.stack([x, -x], axis=-1)) | ||
sig = ([()], [(), ('sign',)]) | ||
result = apply_ufunc(func, obj, signature=sig, | ||
dataset_fill_value=np.nan) | ||
result = apply_ufunc(func, obj, signature=sig) | ||
if isinstance(result[1], (xr.Dataset, xr.DataArray)): | ||
result[1].coords['sign'] = [1, -1] | ||
return result | ||
|
@@ -352,8 +347,7 @@ def concatenate(objects, dim='x'): | |
[obj.coords[dim] if hasattr(obj, 'coords') else [] | ||
for obj in objects]) | ||
func = lambda *x: np.concatenate(x, axis=-1) | ||
result = apply_ufunc(func, *objects, signature=sig, exclude_dims={dim}, | ||
dataset_fill_value=np.nan) | ||
result = apply_ufunc(func, *objects, signature=sig, exclude_dims={dim}) | ||
if isinstance(result, (xr.Dataset, xr.DataArray)): | ||
result.coords[dim] = new_coord | ||
return result | ||
|
@@ -480,14 +474,46 @@ def test_broadcast_compat_data_2d(): | |
broadcast_compat_data(var, ('w', 'y', 'x', 'z'), ())) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would be good to add a test in here for […]. |
||
|
||
def test_data_vars_join(): | ||
def test_keep_attrs(): | ||
|
||
def add(a, b, keep_attrs): | ||
if keep_attrs: | ||
return apply_ufunc(operator.add, a, b, keep_attrs=keep_attrs) | ||
else: | ||
return apply_ufunc(operator.add, a, b) | ||
|
||
a = xr.DataArray([0, 1], [('x', [0, 1])]) | ||
a.attrs['attr'] = 'da' | ||
b = xr.DataArray([1, 2], [('x', [0, 1])]) | ||
|
||
actual = add(a, b, keep_attrs=False) | ||
assert not actual.attrs | ||
actual = add(a, b, keep_attrs=True) | ||
assert_identical(actual.attrs, a.attrs) | ||
|
||
a = xr.Dataset({'x': ('x', [1, 2]), 'x': [0, 1]}) | ||
a.attrs['attr'] = 'ds' | ||
a.x.attrs['attr'] = 'da' | ||
b = xr.Dataset({'x': ('x', [1, 1]), 'x': [0, 1]}) | ||
|
||
actual = add(a, b, keep_attrs=False) | ||
assert not actual.attrs | ||
actual = add(a, b, keep_attrs=True) | ||
assert_identical(actual.attrs, a.attrs) | ||
assert_identical(actual.x.attrs, a.x.attrs) | ||
|
||
|
||
def test_dataset_join(): | ||
import numpy as np | ||
ds0 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]}) | ||
ds1 = xr.Dataset({'a': ('x', [99, 3]), 'x': [1, 2]}) | ||
|
||
def add(a, b, join, data_vars_join): | ||
with pytest.raises(TypeError): | ||
apply_ufunc(operator.add, ds0, ds1, dataset_join='outer') | ||
|
||
def add(a, b, join, dataset_join): | ||
return apply_ufunc(operator.add, a, b, join=join, | ||
data_vars_join=data_vars_join, | ||
dataset_join=dataset_join, | ||
dataset_fill_value=np.nan) | ||
|
||
actual = add(ds0, ds1, 'outer', 'inner') | ||
|
@@ -546,8 +572,7 @@ def test_apply_dask(): | |
apply_ufunc(identity, array, dask_array='auto') | ||
|
||
def dask_safe_identity(x): | ||
return apply_ufunc(identity, x, dask_array='allowed', | ||
dataset_fill_value=np.nan) | ||
return apply_ufunc(identity, x, dask_array='allowed') | ||
|
||
assert array is dask_safe_identity(array) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you lost "see combine" here