Skip to content

Commit 016964b

Browse files
authored
Make typing strict (#1879)
* Disallow subclassing any * Disallow returning any * Enable strict * Fix literal import * pre-commit fixes * Remove old group imports
1 parent 9ad01f1 commit 016964b

File tree

9 files changed

+29
-26
lines changed

9 files changed

+29
-26
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,5 @@ repos:
3131
- types-setuptools
3232
- pytest
3333
- numpy
34+
- numcodecs
35+
- zstandard

pyproject.toml

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -187,21 +187,7 @@ python_version = "3.10"
187187
ignore_missing_imports = true
188188
namespace_packages = false
189189

190-
warn_unused_configs = true
191-
warn_redundant_casts = true
192-
warn_unused_ignores = true
193-
strict_equality = true
194-
strict_concatenate = true
195-
196-
check_untyped_defs = true
197-
disallow_untyped_decorators = true
198-
disallow_any_generics = true
199-
200-
disallow_incomplete_defs = true
201-
disallow_untyped_calls = true
202-
203-
disallow_untyped_defs = true
204-
no_implicit_reexport = true
190+
strict = true
205191

206192

207193
[[tool.mypy.overrides]]
@@ -238,6 +224,13 @@ module = [
238224
disallow_untyped_defs = false
239225

240226

227+
[[tool.mypy.overrides]]
228+
module = [
229+
"zarr.metadata",
230+
"zarr.store.remote"
231+
]
232+
warn_return_any = false
233+
241234
[tool.pytest.ini_options]
242235
minversion = "7"
243236
testpaths = ["tests"]

src/zarr/codecs/sharding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def is_all_empty(self) -> bool:
101101
return bool(np.array_equiv(self.offsets_and_lengths, MAX_UINT_64))
102102

103103
def get_full_chunk_map(self) -> npt.NDArray[np.bool_]:
104-
return self.offsets_and_lengths[..., 0] != MAX_UINT_64
104+
return np.not_equal(self.offsets_and_lengths[..., 0], MAX_UINT_64)
105105

106106
def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None:
107107
localized_chunk = self._localize_chunk(chunk_coords)

src/zarr/codecs/zstd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ def to_dict(self) -> dict[str, JSON]:
5555

5656
def _compress(self, data: npt.NDArray[Any]) -> bytes:
5757
ctx = ZstdCompressor(level=self.level, write_checksum=self.checksum)
58-
return ctx.compress(data)
58+
return ctx.compress(data.tobytes())
5959

6060
def _decompress(self, data: npt.NDArray[Any]) -> bytes:
6161
ctx = ZstdDecompressor()
62-
return ctx.decompress(data)
62+
return ctx.decompress(data.tobytes())
6363

6464
async def _decode_single(
6565
self,

src/zarr/common.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@
77
from collections.abc import Iterable
88
from dataclasses import dataclass
99
from enum import Enum
10-
from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypeVar, overload
10+
from typing import (
11+
TYPE_CHECKING,
12+
Any,
13+
Literal,
14+
ParamSpec,
15+
TypeVar,
16+
cast,
17+
overload,
18+
)
1119

1220
if TYPE_CHECKING:
1321
from collections.abc import Awaitable, Callable, Iterator
@@ -178,5 +186,5 @@ def parse_fill_value(data: Any) -> Any:
178186

179187
def parse_order(data: Any) -> Literal["C", "F"]:
180188
if data in ("C", "F"):
181-
return data
189+
return cast(Literal["C", "F"], data)
182190
raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.")

src/zarr/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Literal
3+
from typing import Any, Literal, cast
44

55
from donfig import Config
66

@@ -18,6 +18,6 @@
1818

1919
def parse_indexing_order(data: Any) -> Literal["C", "F"]:
2020
if data in ("C", "F"):
21-
return data
21+
return cast(Literal["C", "F"], data)
2222
msg = f"Expected one of ('C', 'F'), got {data} instead."
2323
raise ValueError(msg)

src/zarr/group.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
from collections.abc import Iterator
77
from dataclasses import asdict, dataclass, field, replace
8-
from typing import TYPE_CHECKING, overload
8+
from typing import TYPE_CHECKING, Literal, cast, overload
99

1010
import numpy.typing as npt
1111

@@ -37,7 +37,7 @@
3737

3838
def parse_zarr_format(data: Any) -> ZarrFormat:
3939
if data in (2, 3):
40-
return data
40+
return cast(Literal[2, 3], data)
4141
msg = msg = f"Invalid zarr_format. Expected one 2 or 3. Got {data}."
4242
raise ValueError(msg)
4343

src/zarr/v2/n5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ def compressor_config_to_zarr(compressor_config: Dict[str, Any]) -> Optional[Dic
780780
return zarr_config
781781

782782

783-
class N5ChunkWrapper(Codec):
783+
class N5ChunkWrapper(Codec): # type: ignore[misc]
784784
codec_id = "n5_wrapper"
785785

786786
def __init__(self, dtype, chunk_shape, compressor_config=None, compressor=None):

src/zarr/v2/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def get_type(self):
444444
return type(self.obj).__name__
445445

446446

447-
class TreeTraversal(Traversal):
447+
class TreeTraversal(Traversal): # type: ignore[misc]
448448
def get_children(self, node):
449449
return node.get_children()
450450

0 commit comments

Comments
 (0)