diff --git a/setup.cfg b/setup.cfg index b806c57cb1a..2f64583f62d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -152,9 +152,11 @@ ignore = E501 # line too long - let black worry about that E731 # do not assign a lambda expression, use a def W503 # line break before binary operator -exclude= +exclude = .eggs doc +builtins = + ellipsis [isort] profile = black diff --git a/xarray/core/computation.py b/xarray/core/computation.py index b4a268b5c98..7d95b9bf373 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -40,7 +40,7 @@ from .coordinates import Coordinates from .dataarray import DataArray from .dataset import Dataset - from .types import CombineAttrsOptions, JoinOptions + from .types import CombineAttrsOptions, Ellipsis, JoinOptions _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") @@ -1622,7 +1622,11 @@ def cross( return c -def dot(*arrays, dims=None, **kwargs): +def dot( + *arrays, + dims: str | Iterable[Hashable] | Ellipsis | None = None, + **kwargs: Any, +): """Generalized dot product for xarray objects. Like np.einsum, but provides a simpler interface based on array dimensions. @@ -1711,10 +1715,7 @@ def dot(*arrays, dims=None, **kwargs): if len(arrays) == 0: raise TypeError("At least one array should be given.") - if isinstance(dims, str): - dims = (dims,) - - common_dims = set.intersection(*[set(arr.dims) for arr in arrays]) + common_dims: set[Hashable] = set.intersection(*(set(arr.dims) for arr in arrays)) all_dims = [] for arr in arrays: all_dims += [d for d in arr.dims if d not in all_dims] @@ -1724,21 +1725,25 @@ def dot(*arrays, dims=None, **kwargs): if dims is ...: dims = all_dims + elif isinstance(dims, str): + dims = (dims,) elif dims is None: # find dimensions that occur more than one times - dim_counts = Counter() + dim_counts: Counter = Counter() for arr in arrays: dim_counts.update(arr.dims) dims = tuple(d for d, c in dim_counts.items() if c > 1) - dims = tuple(dims) # make dims a tuple + dot_dims: set[Hashable] = set(dims) # type:ignore[arg-type] # dimensions to be parallelized - broadcast_dims = tuple(d for d in all_dims if d in common_dims and d not in dims) + broadcast_dims = common_dims - dot_dims input_core_dims = [ [d for d in arr.dims if d not in broadcast_dims] for arr in arrays ] - output_core_dims = [tuple(d for d in all_dims if d not in dims + broadcast_dims)] + output_core_dims = [ + [d for d in all_dims if d not in dot_dims and d not in broadcast_dims] + ] # construct einsum subscripts, such as '...abc,...ab->...c' # Note: input_core_dims are always moved to the last position diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 80e72082c0b..ba4827d7569 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -78,6 +78,7 @@ from .types import ( CoarsenBoundaryOptions, DatetimeUnitOptions, + Ellipsis, ErrorOptions, ErrorOptionsWithWarn, InterpOptions, @@ -3772,7 +3773,7 @@ def imag(self: T_DataArray) -> T_DataArray: def dot( self: T_DataArray, other: T_DataArray, - dims: Hashable | Sequence[Hashable] | None = None, + dims: str | Iterable[Hashable] | Ellipsis | None = None, ) -> T_DataArray: """Perform dot product of two DataArrays along their shared dims. @@ -3782,7 +3783,7 @@ def dot( ---------- other : DataArray The other array with which the dot product is performed. - dims : ..., Hashable or sequence of Hashable, optional + dims : ..., str or Iterable of Hashable, optional Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions. If not specified, then all the common dimensions are summed over. @@ -4776,7 +4777,7 @@ def idxmax( # https://github.com/python/mypy/issues/12846 is resolved def argmin( self, - dim: Hashable | Sequence[Hashable] | None = None, + dim: Hashable | Sequence[Hashable] | Ellipsis | None = None, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, @@ -4881,7 +4882,7 @@ def argmin( # https://github.com/python/mypy/issues/12846 is resolved def argmax( self, - dim: Hashable | Sequence[Hashable] = None, + dim: Hashable | Sequence[Hashable] | Ellipsis | None = None, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2844905dabc..5d401b47498 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -107,6 +107,7 @@ CombineAttrsOptions, CompatOptions, DatetimeUnitOptions, + Ellipsis, ErrorOptions, ErrorOptionsWithWarn, InterpOptions, @@ -4269,7 +4270,7 @@ def _get_stack_index( def _stack_once( self: T_Dataset, - dims: Sequence[Hashable], + dims: Sequence[Hashable | Ellipsis], new_dim: Hashable, index_cls: type[Index], create_index: bool | None = True, @@ -4328,10 +4329,10 @@ def _stack_once( def stack( self: T_Dataset, - dimensions: Mapping[Any, Sequence[Hashable]] | None = None, + dimensions: Mapping[Any, Sequence[Hashable | Ellipsis]] | None = None, create_index: bool | None = True, index_cls: type[Index] = PandasMultiIndex, - **dimensions_kwargs: Sequence[Hashable], + **dimensions_kwargs: Sequence[Hashable | Ellipsis], ) -> T_Dataset: """ Stack any number of existing dimensions into a single new dimension. diff --git a/xarray/core/types.py b/xarray/core/types.py index a291ea9d8cb..0534833a357 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -5,7 +5,8 @@ import numpy as np if TYPE_CHECKING: - from .common import DataWithCoords + + from .common import AbstractArray, DataWithCoords from .dataarray import DataArray from .dataset import Dataset from .groupby import DataArrayGroupBy, GroupBy @@ -34,13 +35,19 @@ # from typing_extensions import Self # except ImportError: # Self: Any = None - Self: Any = None + Self = TypeVar("Self") + + Ellipsis = ellipsis + else: Self: Any = None + Ellipsis: Any = None + T_Dataset = TypeVar("T_Dataset", bound="Dataset") T_DataArray = TypeVar("T_DataArray", bound="DataArray") T_Variable = TypeVar("T_Variable", bound="Variable") +T_Array = TypeVar("T_Array", bound="AbstractArray") T_Index = TypeVar("T_Index", bound="Index") T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"]) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index e61db0517ea..8b457139b6a 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -79,6 +79,7 @@ def _get_NON_NUMPY_SUPPORTED_ARRAY_TYPES(): if TYPE_CHECKING: from .types import ( + Ellipsis, ErrorOptionsWithWarn, PadModeOptions, PadReflectOptions, @@ -1529,7 +1530,7 @@ def roll(self, shifts=None, **shifts_kwargs): def transpose( self, - *dims: Hashable, + *dims: Hashable | Ellipsis, missing_dims: ErrorOptionsWithWarn = "raise", ) -> Variable: """Return a new Variable object with transposed dimensions. @@ -2606,7 +2607,7 @@ def _to_numeric(self, offset=None, datetime_unit=None, dtype=float): def _unravel_argminmax( self, argminmax: str, - dim: Hashable | Sequence[Hashable] | None, + dim: Hashable | Sequence[Hashable] | Ellipsis | None, axis: int | None, keep_attrs: bool | None, skipna: bool | None, @@ -2675,7 +2676,7 @@ def _unravel_argminmax( def argmin( self, - dim: Hashable | Sequence[Hashable] = None, + dim: Hashable | Sequence[Hashable] | Ellipsis | None = None, axis: int = None, keep_attrs: bool = None, skipna: bool = None, diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 730cf9eac8f..9d5be6f7126 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -9,7 +9,7 @@ from .computation import apply_ufunc, dot from .npcompat import ArrayLike from .pycompat import is_duck_dask_array -from .types import T_Xarray +from .types import Ellipsis, T_Xarray # Weighted quantile methods are a subset of the numpy supported quantile methods. QUANTILE_METHODS = Literal[ @@ -206,7 +206,7 @@ def _check_dim(self, dim: Hashable | Iterable[Hashable] | None): def _reduce( da: DataArray, weights: DataArray, - dim: Hashable | Iterable[Hashable] | None = None, + dim: str | Iterable[Hashable] | Ellipsis | None = None, skipna: bool | None = None, ) -> DataArray: """reduce using dot; equivalent to (da * weights).sum(dim, skipna) @@ -227,7 +227,7 @@ def _reduce( return dot(da, weights, dims=dim) def _sum_of_weights( - self, da: DataArray, dim: Hashable | Iterable[Hashable] | None = None + self, da: DataArray, dim: str | Iterable[Hashable] | None = None ) -> DataArray: """Calculate the sum of weights, accounting for missing values""" @@ -251,7 +251,7 @@ def _sum_of_weights( def _sum_of_squares( self, da: DataArray, - dim: Hashable | Iterable[Hashable] | None = None, + dim: str | Iterable[Hashable] | None = None, skipna: bool | None = None, ) -> DataArray: """Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s).""" @@ -263,7 +263,7 @@ def _sum_of_squares( def _weighted_sum( self, da: DataArray, - dim: Hashable | Iterable[Hashable] | None = None, + dim: str | Iterable[Hashable] | None = None, skipna: bool | None = None, ) -> DataArray: """Reduce a DataArray by a weighted ``sum`` along some dimension(s).""" @@ -273,7 +273,7 @@ def _weighted_sum( def _weighted_mean( self, da: DataArray, - dim: Hashable | Iterable[Hashable] | None = None, + dim: str | Iterable[Hashable] | None = None, skipna: bool | None = None, ) -> DataArray: """Reduce a DataArray by a weighted ``mean`` along some dimension(s).""" @@ -287,7 +287,7 @@ def _weighted_mean( def _weighted_var( self, da: DataArray, - dim: Hashable | Iterable[Hashable] | None = None, + dim: str | Iterable[Hashable] | None = None, skipna: bool | None = None, ) -> DataArray: """Reduce a DataArray by a weighted ``var`` along some dimension(s).""" @@ -301,7 +301,7 @@ def _weighted_var( def _weighted_std( self, da: DataArray, - dim: Hashable | Iterable[Hashable] | None = None, + dim: str | Iterable[Hashable] | None = None, skipna: bool | None = None, ) -> DataArray: """Reduce a DataArray by a weighted ``std`` along some dimension(s).""" diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index cc40bfb0265..d93adf08474 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1732,7 +1732,7 @@ def apply_truncate_x_x_valid(obj): @pytest.mark.parametrize("use_dask", [True, False]) -def test_dot(use_dask) -> None: +def test_dot(use_dask: bool) -> None: if use_dask: if not has_dask: pytest.skip("test for dask.") @@ -1862,7 +1862,7 @@ def test_dot(use_dask) -> None: @pytest.mark.parametrize("use_dask", [True, False]) -def test_dot_align_coords(use_dask) -> None: +def test_dot_align_coords(use_dask: bool) -> None: # GH 3694 if use_dask: diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 514d177263b..ab6e5763248 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6420,7 +6420,7 @@ def test_deepcopy_obj_array() -> None: assert x0.values[0] is not x1.values[0] -def test_clip(da) -> None: +def test_clip(da: DataArray) -> None: with raise_if_dask_computes(): result = da.clip(min=0.5) assert result.min(...) >= 0.5