Skip to content

Commit

Permalink
Fix some mypy issues (pydata#6531)
Browse files Browse the repository at this point in the history
* Fix some mypy issues

Unfortunately these have crept back in. I'll add a workflow job, since the pre-commit is not covering everything on its own

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add mypy workflow

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
max-sixty and pre-commit-ci[bot] committed Apr 28, 2022
1 parent 73a9932 commit 1fb9050
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 57 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/ci-additional.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,41 @@ jobs:
run: |
python -m pytest --doctest-modules xarray --ignore xarray/tests
mypy:
name: Mypy
runs-on: "ubuntu-latest"
if: needs.detect-ci-trigger.outputs.triggered == 'false'
defaults:
run:
shell: bash -l {0}

steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0 # Fetch all history for all branches and tags.
- uses: conda-incubator/setup-miniconda@v2
with:
channels: conda-forge
channel-priority: strict
mamba-version: "*"
activate-environment: xarray-tests
auto-update-conda: false
python-version: "3.9"

- name: Install conda dependencies
run: |
mamba env update -f ci/requirements/environment.yml
- name: Install xarray
run: |
python -m pip install --no-deps -e .
- name: Version info
run: |
conda info -a
conda list
python xarray/util/print_versions.py
- name: Run mypy
run: mypy

min-version-policy:
name: Minimum Version Policy
runs-on: "ubuntu-latest"
Expand Down
5 changes: 3 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,8 @@ def chunk(
chunks = {}

if isinstance(chunks, (float, str, int)):
chunks = dict.fromkeys(self.dims, chunks)
# ignoring type; unclear why it won't accept a Literal into the value.
chunks = dict.fromkeys(self.dims, chunks) # type: ignore
elif isinstance(chunks, (tuple, list)):
chunks = dict(zip(self.dims, chunks))
else:
Expand Down Expand Up @@ -4735,7 +4736,7 @@ def curvefit(

def drop_duplicates(
self,
dim: Hashable | Iterable[Hashable] | ...,
dim: Hashable | Iterable[Hashable],
keep: Literal["first", "last"] | Literal[False] = "first",
):
"""Returns a new DataArray with duplicate dimension values removed.
Expand Down
6 changes: 4 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7981,7 +7981,7 @@ def _wrapper(Y, *coords_, **kwargs):

def drop_duplicates(
self,
dim: Hashable | Iterable[Hashable] | ...,
dim: Hashable | Iterable[Hashable],
keep: Literal["first", "last"] | Literal[False] = "first",
):
"""Returns a new Dataset with duplicate dimension values removed.
Expand All @@ -8005,9 +8005,11 @@ def drop_duplicates(
DataArray.drop_duplicates
"""
if isinstance(dim, str):
dims = (dim,)
dims: Iterable = (dim,)
elif dim is ...:
dims = self.dims
elif not isinstance(dim, Iterable):
dims = [dim]
else:
dims = dim

Expand Down
76 changes: 33 additions & 43 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import enum
import functools
import operator
Expand All @@ -6,19 +8,7 @@
from dataclasses import dataclass, field
from datetime import timedelta
from html import escape
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Hashable,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, Mapping

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -59,12 +49,12 @@ class IndexSelResult:
"""

dim_indexers: Dict[Any, Any]
indexes: Dict[Any, "Index"] = field(default_factory=dict)
variables: Dict[Any, "Variable"] = field(default_factory=dict)
drop_coords: List[Hashable] = field(default_factory=list)
drop_indexes: List[Hashable] = field(default_factory=list)
rename_dims: Dict[Any, Hashable] = field(default_factory=dict)
dim_indexers: dict[Any, Any]
indexes: dict[Any, Index] = field(default_factory=dict)
variables: dict[Any, Variable] = field(default_factory=dict)
drop_coords: list[Hashable] = field(default_factory=list)
drop_indexes: list[Hashable] = field(default_factory=list)
rename_dims: dict[Any, Hashable] = field(default_factory=dict)

def as_tuple(self):
"""Unlike ``dataclasses.astuple``, return a shallow copy.
Expand All @@ -82,7 +72,7 @@ def as_tuple(self):
)


def merge_sel_results(results: List[IndexSelResult]) -> IndexSelResult:
def merge_sel_results(results: list[IndexSelResult]) -> IndexSelResult:
all_dims_count = Counter([dim for res in results for dim in res.dim_indexers])
duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1}

Expand Down Expand Up @@ -124,13 +114,13 @@ def group_indexers_by_index(
obj: T_Xarray,
indexers: Mapping[Any, Any],
options: Mapping[str, Any],
) -> List[Tuple["Index", Dict[Any, Any]]]:
) -> list[tuple[Index, dict[Any, Any]]]:
"""Returns a list of unique indexes and their corresponding indexers."""
unique_indexes = {}
grouped_indexers: Mapping[Union[int, None], Dict] = defaultdict(dict)
grouped_indexers: Mapping[int | None, dict] = defaultdict(dict)

for key, label in indexers.items():
index: "Index" = obj.xindexes.get(key, None)
index: Index = obj.xindexes.get(key, None)

if index is not None:
index_id = id(index)
Expand Down Expand Up @@ -787,7 +777,7 @@ class IndexingSupport(enum.Enum):

def explicit_indexing_adapter(
key: ExplicitIndexer,
shape: Tuple[int, ...],
shape: tuple[int, ...],
indexing_support: IndexingSupport,
raw_indexing_method: Callable,
) -> Any:
Expand Down Expand Up @@ -821,8 +811,8 @@ def explicit_indexing_adapter(


def decompose_indexer(
indexer: ExplicitIndexer, shape: Tuple[int, ...], indexing_support: IndexingSupport
) -> Tuple[ExplicitIndexer, ExplicitIndexer]:
indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport
) -> tuple[ExplicitIndexer, ExplicitIndexer]:
if isinstance(indexer, VectorizedIndexer):
return _decompose_vectorized_indexer(indexer, shape, indexing_support)
if isinstance(indexer, (BasicIndexer, OuterIndexer)):
Expand All @@ -848,9 +838,9 @@ def _decompose_slice(key, size):

def _decompose_vectorized_indexer(
indexer: VectorizedIndexer,
shape: Tuple[int, ...],
shape: tuple[int, ...],
indexing_support: IndexingSupport,
) -> Tuple[ExplicitIndexer, ExplicitIndexer]:
) -> tuple[ExplicitIndexer, ExplicitIndexer]:
"""
Decompose vectorized indexer to the successive two indexers, where the
first indexer will be used to index backend arrays, while the second one
Expand Down Expand Up @@ -929,10 +919,10 @@ def _decompose_vectorized_indexer(


def _decompose_outer_indexer(
indexer: Union[BasicIndexer, OuterIndexer],
shape: Tuple[int, ...],
indexer: BasicIndexer | OuterIndexer,
shape: tuple[int, ...],
indexing_support: IndexingSupport,
) -> Tuple[ExplicitIndexer, ExplicitIndexer]:
) -> tuple[ExplicitIndexer, ExplicitIndexer]:
"""
Decompose outer indexer to the successive two indexers, where the
first indexer will be used to index backend arrays, while the second one
Expand Down Expand Up @@ -973,7 +963,7 @@ def _decompose_outer_indexer(
return indexer, BasicIndexer(())
assert isinstance(indexer, (OuterIndexer, BasicIndexer))

backend_indexer: List[Any] = []
backend_indexer: list[Any] = []
np_indexer = []
# make indexer positive
pos_indexer: list[np.ndarray | int | np.number] = []
Expand Down Expand Up @@ -1395,7 +1385,7 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
return np.asarray(array.values, dtype=dtype)

@property
def shape(self) -> Tuple[int]:
def shape(self) -> tuple[int]:
return (len(self.array),)

def _convert_scalar(self, item):
Expand All @@ -1420,13 +1410,13 @@ def _convert_scalar(self, item):

def __getitem__(
self, indexer
) -> Union[
"PandasIndexingAdapter",
NumpyIndexingAdapter,
np.ndarray,
np.datetime64,
np.timedelta64,
]:
) -> (
PandasIndexingAdapter
| NumpyIndexingAdapter
| np.ndarray
| np.datetime64
| np.timedelta64
):
key = indexer.tuple
if isinstance(key, tuple) and len(key) == 1:
# unpack key so it can index a pandas.Index object (pandas.Index
Expand All @@ -1449,7 +1439,7 @@ def transpose(self, order) -> pd.Index:
def __repr__(self) -> str:
return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})"

def copy(self, deep: bool = True) -> "PandasIndexingAdapter":
def copy(self, deep: bool = True) -> PandasIndexingAdapter:
# Not the same as just writing `self.array.copy(deep=deep)`, as
# shallow copies of the underlying numpy.ndarrays become deep ones
# upon pickling
Expand All @@ -1476,7 +1466,7 @@ def __init__(
self,
array: pd.MultiIndex,
dtype: DTypeLike = None,
level: Optional[str] = None,
level: str | None = None,
):
super().__init__(array, dtype)
self.level = level
Expand Down Expand Up @@ -1535,7 +1525,7 @@ def _repr_html_(self) -> str:
array_repr = short_numpy_repr(self._get_array_subset())
return f"<pre>{escape(array_repr)}</pre>"

def copy(self, deep: bool = True) -> "PandasMultiIndexingAdapter":
def copy(self, deep: bool = True) -> PandasMultiIndexingAdapter:
# see PandasIndexingAdapter.copy
array = self.array.copy(deep=True) if deep else self.array
return type(self)(array, self._dtype, self.level)
3 changes: 2 additions & 1 deletion xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def remove_incompatible_items(
del first_dict[k]


def is_dict_like(value: Any) -> bool:
# It's probably OK to give this as a TypeGuard; though it's not perfectly robust.
def is_dict_like(value: Any) -> TypeGuard[dict]:
return hasattr(value, "keys") and hasattr(value, "__getitem__")


Expand Down
10 changes: 4 additions & 6 deletions xarray/tests/test_coding_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,20 +1007,18 @@ def test_decode_ambiguous_time_warns(calendar) -> None:
units = "days since 1-1-1"
expected = num2date(dates, units, calendar=calendar, only_use_cftime_datetimes=True)

exp_warn_type = SerializationWarning if is_standard_calendar else None

with pytest.warns(exp_warn_type) as record:
result = decode_cf_datetime(dates, units, calendar=calendar)

if is_standard_calendar:
with pytest.warns(SerializationWarning) as record:
result = decode_cf_datetime(dates, units, calendar=calendar)
relevant_warnings = [
r
for r in record.list
if str(r.message).startswith("Ambiguous reference date string: 1-1-1")
]
assert len(relevant_warnings) == 1
else:
assert not record
with assert_no_warnings():
result = decode_cf_datetime(dates, units, calendar=calendar)

np.testing.assert_array_equal(result, expected)

Expand Down
6 changes: 3 additions & 3 deletions xarray/tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,10 +480,10 @@ def test_short_numpy_repr() -> None:
assert num_lines < 30

# threshold option (default: 200)
array = np.arange(100)
assert "..." not in formatting.short_numpy_repr(array)
array2 = np.arange(100)
assert "..." not in formatting.short_numpy_repr(array2)
with xr.set_options(display_values_threshold=10):
assert "..." in formatting.short_numpy_repr(array)
assert "..." in formatting.short_numpy_repr(array2)


def test_large_array_repr_length() -> None:
Expand Down

0 comments on commit 1fb9050

Please sign in to comment.