@@ -1,3 +1,16 @@
+ from typing import (
+     TYPE_CHECKING,
+     Dict,
+     Hashable,
+     Iterable,
+     List,
+     Optional,
+     Set,
+     Tuple,
+     Union,
+     overload,
+ )
+
import pandas as pd

from . import dtypes, utils
@@ -7,6 +20,40 @@
from .variable import IndexVariable, Variable, as_variable
from .variable import concat as concat_vars

+ if TYPE_CHECKING:
+     from .dataarray import DataArray
+     from .dataset import Dataset
+
+
+ @overload
+ def concat(
+     objs: Iterable["Dataset"],
+     dim: Union[str, "DataArray", pd.Index],
+     data_vars: Union[str, List[str]] = "all",
+     coords: Union[str, List[str]] = "different",
+     compat: str = "equals",
+     positions: Optional[Iterable[int]] = None,
+     fill_value: object = dtypes.NA,
+     join: str = "outer",
+     combine_attrs: str = "override",
+ ) -> "Dataset":
+     ...
+
+
+ @overload
+ def concat(
+     objs: Iterable["DataArray"],
+     dim: Union[str, "DataArray", pd.Index],
+     data_vars: Union[str, List[str]] = "all",
+     coords: Union[str, List[str]] = "different",
+     compat: str = "equals",
+     positions: Optional[Iterable[int]] = None,
+     fill_value: object = dtypes.NA,
+     join: str = "outer",
+     combine_attrs: str = "override",
+ ) -> "DataArray":
+     ...
+
+
def concat(
    objs,
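For orientation, a brief hedged sketch of how the two `@overload` declarations added above are meant to resolve under a static type checker (e.g. mypy): the element type of `objs` selects the return type. The toy data below is illustrative and not part of this change; it only assumes that the public `xr.concat` is the function annotated here.

    # Sketch only: assumes xarray is installed; xr.concat is re-exported from this module.
    import pandas as pd
    import xarray as xr

    ds = xr.Dataset({"a": ("x", [1, 2])})
    da = xr.DataArray([1, 2], dims="x")

    combined_ds = xr.concat([ds, ds], dim="y")  # checker infers "Dataset"
    combined_da = xr.concat([da, da], dim=pd.Index([0, 1], name="y"))  # infers "DataArray"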
@@ -285,13 +332,15 @@ def process_subset_opt(opt, subset):


# determine dimensional coordinate names and a dict mapping name to DataArray
- def _parse_datasets(datasets):
+ def _parse_datasets(
+     datasets: Iterable["Dataset"],
+ ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, int], Set[Hashable], Set[Hashable]]:

-     dims = set()
-     all_coord_names = set()
-     data_vars = set()  # list of data_vars
-     dim_coords = {}  # maps dim name to variable
-     dims_sizes = {}  # shared dimension sizes to expand variables
+     dims: Set[Hashable] = set()
+     all_coord_names: Set[Hashable] = set()
+     data_vars: Set[Hashable] = set()  # list of data_vars
+     dim_coords: Dict[Hashable, Variable] = {}  # maps dim name to variable
+     dims_sizes: Dict[Hashable, int] = {}  # shared dimension sizes to expand variables

    for ds in datasets:
        dims_sizes.update(ds.dims)
@@ -307,16 +356,16 @@ def _parse_datasets(datasets):


def _dataset_concat(
-     datasets,
-     dim,
-     data_vars,
-     coords,
-     compat,
-     positions,
-     fill_value=dtypes.NA,
-     join="outer",
-     combine_attrs="override",
- ):
+     datasets: List["Dataset"],
+     dim: Union[str, "DataArray", pd.Index],
+     data_vars: Union[str, List[str]],
+     coords: Union[str, List[str]],
+     compat: str,
+     positions: Optional[Iterable[int]],
+     fill_value: object = dtypes.NA,
+     join: str = "outer",
+     combine_attrs: str = "override",
+ ) -> "Dataset":
    """
    Concatenate a sequence of datasets along a new or existing dimension
    """
@@ -356,7 +405,9 @@ def _dataset_concat(

    result_vars = {}
    if variables_to_merge:
-         to_merge = {var: [] for var in variables_to_merge}
+         to_merge: Dict[Hashable, List[Variable]] = {
+             var: [] for var in variables_to_merge
+         }

        for ds in datasets:
            for var in variables_to_merge:
@@ -427,16 +478,16 @@ def ensure_common_dims(vars):


def _dataarray_concat(
-     arrays,
-     dim,
-     data_vars,
-     coords,
-     compat,
-     positions,
-     fill_value=dtypes.NA,
-     join="outer",
-     combine_attrs="override",
- ):
+     arrays: Iterable["DataArray"],
+     dim: Union[str, "DataArray", pd.Index],
+     data_vars: Union[str, List[str]],
+     coords: Union[str, List[str]],
+     compat: str,
+     positions: Optional[Iterable[int]],
+     fill_value: object = dtypes.NA,
+     join: str = "outer",
+     combine_attrs: str = "override",
+ ) -> "DataArray":
    arrays = list(arrays)

    if data_vars != "all":
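As a small hedged illustration of the keyword parameters whose annotations appear in these hunks (`fill_value`, `join`, `combine_attrs`), here is a minimal sketch with toy data, not taken from this change, exercising the public `xr.concat` call:

    # Sketch only: standard xarray concat behavior with the annotated keywords.
    import numpy as np
    import xarray as xr

    a = xr.Dataset({"t": ("x", [1.0, 2.0])}, coords={"x": [0, 1]})
    b = xr.Dataset({"t": ("x", [3.0, 4.0])}, coords={"x": [1, 2]})

    # join="outer" aligns the shared "x" coordinate; fill_value pads the gaps.
    out = xr.concat(
        [a, b], dim="run", join="outer", fill_value=np.nan, combine_attrs="override"
    )
    assert out.sizes["run"] == 2 and out.sizes["x"] == 3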