Skip to content

Commit 6dc8b60

Browse files
zdgriffithshoyer
authored andcommitted
Add fill_value for concat and auto_combine (#2964)
* add fill_value option for concat and auto_combine * add tests for fill_value in concat and auto_combine * remove errant whitespace * add fill_value description to doc-string * add missing assert
1 parent 7edf2e2 commit 6dc8b60

File tree

2 files changed

+77
-20
lines changed

2 files changed

+77
-20
lines changed

xarray/core/combine.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pandas as pd
66

7-
from . import utils
7+
from . import utils, dtypes
88
from .alignment import align
99
from .merge import merge
1010
from .variable import IndexVariable, Variable, as_variable
@@ -14,7 +14,7 @@
1414

1515
def concat(objs, dim=None, data_vars='all', coords='different',
1616
compat='equals', positions=None, indexers=None, mode=None,
17-
concat_over=None):
17+
concat_over=None, fill_value=dtypes.NA):
1818
"""Concatenate xarray objects along a new or existing dimension.
1919
2020
Parameters
@@ -66,6 +66,8 @@ def concat(objs, dim=None, data_vars='all', coords='different',
6666
List of integer arrays which specifies the integer positions to which
6767
to assign each dataset along the concatenated dimension. If not
6868
supplied, objects are concatenated in the provided order.
69+
fill_value : scalar, optional
70+
Value to use for newly missing values
6971
indexers, mode, concat_over : deprecated
7072
7173
Returns
@@ -117,7 +119,7 @@ def concat(objs, dim=None, data_vars='all', coords='different',
117119
else:
118120
raise TypeError('can only concatenate xarray Dataset and DataArray '
119121
'objects, got %s' % type(first_obj))
120-
return f(objs, dim, data_vars, coords, compat, positions)
122+
return f(objs, dim, data_vars, coords, compat, positions, fill_value)
121123

122124

123125
def _calc_concat_dim_coord(dim):
@@ -212,7 +214,8 @@ def process_subset_opt(opt, subset):
212214
return concat_over, equals
213215

214216

215-
def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
217+
def _dataset_concat(datasets, dim, data_vars, coords, compat, positions,
218+
fill_value=dtypes.NA):
216219
"""
217220
Concatenate a sequence of datasets along a new or existing dimension
218221
"""
@@ -225,7 +228,8 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
225228
dim, coord = _calc_concat_dim_coord(dim)
226229
# Make sure we're working on a copy (we'll be loading variables)
227230
datasets = [ds.copy() for ds in datasets]
228-
datasets = align(*datasets, join='outer', copy=False, exclude=[dim])
231+
datasets = align(*datasets, join='outer', copy=False, exclude=[dim],
232+
fill_value=fill_value)
229233

230234
concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)
231235

@@ -317,7 +321,7 @@ def ensure_common_dims(vars):
317321

318322

319323
def _dataarray_concat(arrays, dim, data_vars, coords, compat,
320-
positions):
324+
positions, fill_value=dtypes.NA):
321325
arrays = list(arrays)
322326

323327
if data_vars != 'all':
@@ -336,14 +340,15 @@ def _dataarray_concat(arrays, dim, data_vars, coords, compat,
336340
datasets.append(arr._to_temp_dataset())
337341

338342
ds = _dataset_concat(datasets, dim, data_vars, coords, compat,
339-
positions)
343+
positions, fill_value)
340344
result = arrays[0]._from_temp_dataset(ds, name)
341345

342346
result.name = result_name(arrays)
343347
return result
344348

345349

346-
def _auto_concat(datasets, dim=None, data_vars='all', coords='different'):
350+
def _auto_concat(datasets, dim=None, data_vars='all', coords='different',
351+
fill_value=dtypes.NA):
347352
if len(datasets) == 1 and dim is None:
348353
# There is nothing more to combine, so kick out early.
349354
return datasets[0]
@@ -366,7 +371,8 @@ def _auto_concat(datasets, dim=None, data_vars='all', coords='different'):
366371
'supply the ``concat_dim`` argument '
367372
'explicitly')
368373
dim, = concat_dims
369-
return concat(datasets, dim=dim, data_vars=data_vars, coords=coords)
374+
return concat(datasets, dim=dim, data_vars=data_vars,
375+
coords=coords, fill_value=fill_value)
370376

371377

372378
_CONCAT_DIM_DEFAULT = utils.ReprObject('<inferred>')
@@ -442,7 +448,8 @@ def _check_shape_tile_ids(combined_tile_ids):
442448

443449

444450
def _combine_nd(combined_ids, concat_dims, data_vars='all',
445-
coords='different', compat='no_conflicts'):
451+
coords='different', compat='no_conflicts',
452+
fill_value=dtypes.NA):
446453
"""
447454
Concatenates and merges an N-dimensional structure of datasets.
448455
@@ -472,13 +479,14 @@ def _combine_nd(combined_ids, concat_dims, data_vars='all',
472479
dim=concat_dim,
473480
data_vars=data_vars,
474481
coords=coords,
475-
compat=compat)
482+
compat=compat,
483+
fill_value=fill_value)
476484
combined_ds = list(combined_ids.values())[0]
477485
return combined_ds
478486

479487

480488
def _auto_combine_all_along_first_dim(combined_ids, dim, data_vars,
481-
coords, compat):
489+
coords, compat, fill_value=dtypes.NA):
482490
# Group into lines of datasets which must be combined along dim
483491
# need to sort by _new_tile_id first for groupby to work
484492
# TODO remove all these sorted OrderedDicts once python >= 3.6 only
@@ -490,7 +498,8 @@ def _auto_combine_all_along_first_dim(combined_ids, dim, data_vars,
490498
combined_ids = OrderedDict(sorted(group))
491499
datasets = combined_ids.values()
492500
new_combined_ids[new_id] = _auto_combine_1d(datasets, dim, compat,
493-
data_vars, coords)
501+
data_vars, coords,
502+
fill_value)
494503
return new_combined_ids
495504

496505

@@ -500,18 +509,20 @@ def vars_as_keys(ds):
500509

501510
def _auto_combine_1d(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
502511
compat='no_conflicts',
503-
data_vars='all', coords='different'):
512+
data_vars='all', coords='different',
513+
fill_value=dtypes.NA):
504514
# This is just the old auto_combine function (which only worked along 1D)
505515
if concat_dim is not None:
506516
dim = None if concat_dim is _CONCAT_DIM_DEFAULT else concat_dim
507517
sorted_datasets = sorted(datasets, key=vars_as_keys)
508518
grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys)
509519
concatenated = [_auto_concat(list(ds_group), dim=dim,
510-
data_vars=data_vars, coords=coords)
520+
data_vars=data_vars, coords=coords,
521+
fill_value=fill_value)
511522
for id, ds_group in grouped_by_vars]
512523
else:
513524
concatenated = datasets
514-
merged = merge(concatenated, compat=compat)
525+
merged = merge(concatenated, compat=compat, fill_value=fill_value)
515526
return merged
516527

517528

@@ -521,7 +532,7 @@ def _new_tile_id(single_id_ds_pair):
521532

522533

523534
def _auto_combine(datasets, concat_dims, compat, data_vars, coords,
524-
infer_order_from_coords, ids):
535+
infer_order_from_coords, ids, fill_value=dtypes.NA):
525536
"""
526537
Calls logic to decide concatenation order before concatenating.
527538
"""
@@ -550,12 +561,14 @@ def _auto_combine(datasets, concat_dims, compat, data_vars, coords,
550561

551562
# Repeatedly concatenate then merge along each dimension
552563
combined = _combine_nd(combined_ids, concat_dims, compat=compat,
553-
data_vars=data_vars, coords=coords)
564+
data_vars=data_vars, coords=coords,
565+
fill_value=fill_value)
554566
return combined
555567

556568

557569
def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
558-
compat='no_conflicts', data_vars='all', coords='different'):
570+
compat='no_conflicts', data_vars='all', coords='different',
571+
fill_value=dtypes.NA):
559572
"""Attempt to auto-magically combine the given datasets into one.
560573
This method attempts to combine a list of datasets into a single entity by
561574
inspecting metadata and using a combination of concat and merge.
@@ -596,6 +609,8 @@ def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
596609
Details are in the documentation of concat
597610
coords : {'minimal', 'different', 'all' or list of str}, optional
598611
Details are in the documentation of conca
612+
fill_value : scalar, optional
613+
Value to use for newly missing values
599614
600615
Returns
601616
-------
@@ -622,4 +637,4 @@ def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
622637
return _auto_combine(datasets, concat_dims=concat_dims, compat=compat,
623638
data_vars=data_vars, coords=coords,
624639
infer_order_from_coords=infer_order_from_coords,
625-
ids=False)
640+
ids=False, fill_value=fill_value)

xarray/tests/test_combine.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
from xarray import DataArray, Dataset, Variable, auto_combine, concat
10+
from xarray.core import dtypes
1011
from xarray.core.combine import (
1112
_auto_combine, _auto_combine_1d, _auto_combine_all_along_first_dim,
1213
_check_shape_tile_ids, _combine_nd, _infer_concat_order_from_positions,
@@ -237,6 +238,20 @@ def test_concat_multiindex(self):
237238
assert expected.equals(actual)
238239
assert isinstance(actual.x.to_index(), pd.MultiIndex)
239240

241+
@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
242+
def test_concat_fill_value(self, fill_value):
243+
datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}),
244+
Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})]
245+
if fill_value == dtypes.NA:
246+
# if we supply the default, we expect the missing value for a
247+
# float array
248+
fill_value = np.nan
249+
expected = Dataset({'a': (('t', 'x'),
250+
[[fill_value, 2, 3], [1, 2, fill_value]])},
251+
{'x': [0, 1, 2]})
252+
actual = concat(datasets, dim='t', fill_value=fill_value)
253+
assert_identical(actual, expected)
254+
240255

241256
class TestConcatDataArray:
242257
def test_concat(self):
@@ -306,6 +321,19 @@ def test_concat_lazy(self):
306321
assert combined.shape == (2, 3, 3)
307322
assert combined.dims == ('z', 'x', 'y')
308323

324+
@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
325+
def test_concat_fill_value(self, fill_value):
326+
foo = DataArray([1, 2], coords=[('x', [1, 2])])
327+
bar = DataArray([1, 2], coords=[('x', [1, 3])])
328+
if fill_value == dtypes.NA:
329+
# if we supply the default, we expect the missing value for a
330+
# float array
331+
fill_value = np.nan
332+
expected = DataArray([[1, 2, fill_value], [1, fill_value, 2]],
333+
dims=['y', 'x'], coords={'x': [1, 2, 3]})
334+
actual = concat((foo, bar), dim='y', fill_value=fill_value)
335+
assert_identical(actual, expected)
336+
309337

310338
class TestAutoCombine:
311339

@@ -417,6 +445,20 @@ def test_auto_combine_no_concat(self):
417445
{'baz': [100]})
418446
assert_identical(expected, actual)
419447

448+
@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
449+
def test_auto_combine_fill_value(self, fill_value):
450+
datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}),
451+
Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})]
452+
if fill_value == dtypes.NA:
453+
# if we supply the default, we expect the missing value for a
454+
# float array
455+
fill_value = np.nan
456+
expected = Dataset({'a': (('t', 'x'),
457+
[[fill_value, 2, 3], [1, 2, fill_value]])},
458+
{'x': [0, 1, 2]})
459+
actual = auto_combine(datasets, concat_dim='t', fill_value=fill_value)
460+
assert_identical(expected, actual)
461+
420462

421463
def assert_combined_tile_ids_equal(dict1, dict2):
422464
assert len(dict1) == len(dict2)

0 commit comments

Comments
 (0)