Skip to content

Commit 3f587b4

Browse files
Improved DataArray typing (#6637)
* add .env to gitignore
* improved typing for dataarray
* even more typing
* even further typing
* finish typing dataArray
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix some cast type not imported problems
* fix another cast import bug
* fix some error message regexes
* fix import and typo
* fix wrong case in intp_dimorder test
* type all test_dataarray tests
* fix typing in test_dataarray

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 607a927 commit 3f587b4

File tree

9 files changed

+1187
-921
lines changed

9 files changed

+1187
-921
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
*.py[cod]
22
__pycache__
3+
.env
4+
.venv
35

46
# example caches from Hypothesis
57
.hypothesis/

xarray/core/alignment.py

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
Tuple,
1717
Type,
1818
TypeVar,
19+
cast,
1920
)
2021

2122
import numpy as np
@@ -30,7 +31,7 @@
3031
if TYPE_CHECKING:
3132
from .dataarray import DataArray
3233
from .dataset import Dataset
33-
from .types import JoinOptions
34+
from .types import JoinOptions, T_DataArray, T_DataArrayOrSet, T_Dataset
3435

3536
DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords)
3637

@@ -559,7 +560,7 @@ def align(self) -> None:
559560
def align(
560561
*objects: DataAlignable,
561562
join: JoinOptions = "inner",
562-
copy=True,
563+
copy: bool = True,
563564
indexes=None,
564565
exclude=frozenset(),
565566
fill_value=dtypes.NA,
@@ -592,7 +593,7 @@ def align(
592593
those of the first object with that dimension. Indexes for the same
593594
dimension must have the same size in all objects.
594595
595-
copy : bool, optional
596+
copy : bool, default: True
596597
If ``copy=True``, data in the return values is always copied. If
597598
``copy=False`` and reindexing is unnecessary, or can be performed with
598599
only slice operations, then the output may share memory with the input.
@@ -609,7 +610,7 @@ def align(
609610
610611
Returns
611612
-------
612-
aligned : DataArray or Dataset
613+
aligned : tuple of DataArray or Dataset
613614
Tuple of objects with the same type as `*objects` with aligned
614615
coordinates.
615616
@@ -935,7 +936,9 @@ def _get_broadcast_dims_map_common_coords(args, exclude):
935936
return dims_map, common_coords
936937

937938

938-
def _broadcast_helper(arg, exclude, dims_map, common_coords):
939+
def _broadcast_helper(
940+
arg: T_DataArrayOrSet, exclude, dims_map, common_coords
941+
) -> T_DataArrayOrSet:
939942

940943
from .dataarray import DataArray
941944
from .dataset import Dataset
@@ -950,22 +953,25 @@ def _set_dims(var):
950953

951954
return var.set_dims(var_dims_map)
952955

953-
def _broadcast_array(array):
956+
def _broadcast_array(array: T_DataArray) -> T_DataArray:
954957
data = _set_dims(array.variable)
955958
coords = dict(array.coords)
956959
coords.update(common_coords)
957-
return DataArray(data, coords, data.dims, name=array.name, attrs=array.attrs)
960+
return array.__class__(
961+
data, coords, data.dims, name=array.name, attrs=array.attrs
962+
)
958963

959-
def _broadcast_dataset(ds):
964+
def _broadcast_dataset(ds: T_Dataset) -> T_Dataset:
960965
data_vars = {k: _set_dims(ds.variables[k]) for k in ds.data_vars}
961966
coords = dict(ds.coords)
962967
coords.update(common_coords)
963-
return Dataset(data_vars, coords, ds.attrs)
968+
return ds.__class__(data_vars, coords, ds.attrs)
964969

970+
# remove casts once https://github.com/python/mypy/issues/12800 is resolved
965971
if isinstance(arg, DataArray):
966-
return _broadcast_array(arg)
972+
return cast("T_DataArrayOrSet", _broadcast_array(arg))
967973
elif isinstance(arg, Dataset):
968-
return _broadcast_dataset(arg)
974+
return cast("T_DataArrayOrSet", _broadcast_dataset(arg))
969975
else:
970976
raise ValueError("all input must be Dataset or DataArray objects")
971977

0 commit comments

Comments
 (0)