Skip to content

Commit 68d0e0d

Browse files
authored
annotate concat (#4346)
* annotate concat
* fix annotation: _from_temp_dataset
* whats-new entry
* Update xarray/core/concat.py
* revert faulty change
1 parent d9ebcaf commit 68d0e0d

File tree

3 files changed

+81
-28
lines changed

3 files changed

+81
-28
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ Internal Changes
8585
By `Guido Imperiale <https://github.com/crusaderky>`_
8686
- Only load resource files when running inside a Jupyter Notebook
8787
(:issue:`4294`) By `Guido Imperiale <https://github.com/crusaderky>`_
88+ - Enable type checking for :py:func:`concat` (:issue:`4238`)
89+   By `Mathias Hauser <https://github.com/mathause>`_.
8890

8991

9092
.. _whats-new.0.16.0:

xarray/core/concat.py

Lines changed: 78 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
from typing import (
2+
TYPE_CHECKING,
3+
Dict,
4+
Hashable,
5+
Iterable,
6+
List,
7+
Optional,
8+
Set,
9+
Tuple,
10+
Union,
11+
overload,
12+
)
13+
114
import pandas as pd
215

316
from . import dtypes, utils
@@ -7,6 +20,40 @@
720
from .variable import IndexVariable, Variable, as_variable
821
from .variable import concat as concat_vars
922

23+
if TYPE_CHECKING:
24+
from .dataarray import DataArray
25+
from .dataset import Dataset
26+
27+
28+
@overload
29+
def concat(
30+
objs: Iterable["Dataset"],
31+
dim: Union[str, "DataArray", pd.Index],
32+
data_vars: Union[str, List[str]] = "all",
33+
coords: Union[str, List[str]] = "different",
34+
compat: str = "equals",
35+
positions: Optional[Iterable[int]] = None,
36+
fill_value: object = dtypes.NA,
37+
join: str = "outer",
38+
combine_attrs: str = "override",
39+
) -> "Dataset":
40+
...
41+
42+
43+
@overload
44+
def concat(
45+
objs: Iterable["DataArray"],
46+
dim: Union[str, "DataArray", pd.Index],
47+
data_vars: Union[str, List[str]] = "all",
48+
coords: Union[str, List[str]] = "different",
49+
compat: str = "equals",
50+
positions: Optional[Iterable[int]] = None,
51+
fill_value: object = dtypes.NA,
52+
join: str = "outer",
53+
combine_attrs: str = "override",
54+
) -> "DataArray":
55+
...
56+
1057

1158
def concat(
1259
objs,
@@ -285,13 +332,15 @@ def process_subset_opt(opt, subset):
285332

286333

287334
# determine dimensional coordinate names and a dict mapping name to DataArray
288-
def _parse_datasets(datasets):
335+
def _parse_datasets(
336+
datasets: Iterable["Dataset"],
337+
) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, int], Set[Hashable], Set[Hashable]]:
289338

290-
dims = set()
291-
all_coord_names = set()
292-
data_vars = set() # list of data_vars
293-
dim_coords = {} # maps dim name to variable
294-
dims_sizes = {} # shared dimension sizes to expand variables
339+
dims: Set[Hashable] = set()
340+
all_coord_names: Set[Hashable] = set()
341+
data_vars: Set[Hashable] = set() # list of data_vars
342+
dim_coords: Dict[Hashable, Variable] = {} # maps dim name to variable
343+
dims_sizes: Dict[Hashable, int] = {} # shared dimension sizes to expand variables
295344

296345
for ds in datasets:
297346
dims_sizes.update(ds.dims)
@@ -307,16 +356,16 @@ def _parse_datasets(datasets):
307356

308357

309358
def _dataset_concat(
310-
datasets,
311-
dim,
312-
data_vars,
313-
coords,
314-
compat,
315-
positions,
316-
fill_value=dtypes.NA,
317-
join="outer",
318-
combine_attrs="override",
319-
):
359+
datasets: List["Dataset"],
360+
dim: Union[str, "DataArray", pd.Index],
361+
data_vars: Union[str, List[str]],
362+
coords: Union[str, List[str]],
363+
compat: str,
364+
positions: Optional[Iterable[int]],
365+
fill_value: object = dtypes.NA,
366+
join: str = "outer",
367+
combine_attrs: str = "override",
368+
) -> "Dataset":
320369
"""
321370
Concatenate a sequence of datasets along a new or existing dimension
322371
"""
@@ -356,7 +405,9 @@ def _dataset_concat(
356405

357406
result_vars = {}
358407
if variables_to_merge:
359-
to_merge = {var: [] for var in variables_to_merge}
408+
to_merge: Dict[Hashable, List[Variable]] = {
409+
var: [] for var in variables_to_merge
410+
}
360411

361412
for ds in datasets:
362413
for var in variables_to_merge:
@@ -427,16 +478,16 @@ def ensure_common_dims(vars):
427478

428479

429480
def _dataarray_concat(
430-
arrays,
431-
dim,
432-
data_vars,
433-
coords,
434-
compat,
435-
positions,
436-
fill_value=dtypes.NA,
437-
join="outer",
438-
combine_attrs="override",
439-
):
481+
arrays: Iterable["DataArray"],
482+
dim: Union[str, "DataArray", pd.Index],
483+
data_vars: Union[str, List[str]],
484+
coords: Union[str, List[str]],
485+
compat: str,
486+
positions: Optional[Iterable[int]],
487+
fill_value: object = dtypes.NA,
488+
join: str = "outer",
489+
combine_attrs: str = "override",
490+
) -> "DataArray":
440491
arrays = list(arrays)
441492

442493
if data_vars != "all":

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def _to_temp_dataset(self) -> Dataset:
422422
return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False)
423423

424424
def _from_temp_dataset(
425-
self, dataset: Dataset, name: Hashable = _default
425+
self, dataset: Dataset, name: Union[Hashable, None, Default] = _default
426426
) -> "DataArray":
427427
variable = dataset._variables.pop(_THIS_ARRAY)
428428
coords = dataset._variables

0 commit comments

Comments (0)