Skip to content

Commit

Permalink
Merge pull request #83 from ImperialCollegeLondon/feature/mypy_options
Browse files Browse the repository at this point in the history
Strict requirement for type hinting
  • Loading branch information
jacobcook1995 authored Oct 11, 2022
2 parents ffeb911 + bbc458a commit bc45fc5
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 30 deletions.
29 changes: 17 additions & 12 deletions docs/source/development/training/test_mfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from mfm import TimesTable, my_float_multiplier, my_picky_float_multiplier


def test_fm():
def test_fm() -> None:

assert 10 == my_float_multiplier(2, 5)


def test_pfm():
def test_pfm() -> None:

with pytest.raises(ValueError) as err_hndlr:

Expand All @@ -18,7 +18,7 @@ def test_pfm():
assert str(err_hndlr.value) == "Both x and y must be of type float"


def test_pfm_fail():
def test_pfm_fail() -> None:

with pytest.raises(ValueError) as err_hndlr:

Expand All @@ -36,7 +36,7 @@ def test_pfm_fail():
(-1.5, -3.0, 4.5),
],
)
def test_pfm_param_noid(x, y, expected):
def test_pfm_param_noid(x: float, y: float, expected: float) -> None:

assert expected == my_picky_float_multiplier(x, y)

Expand All @@ -51,7 +51,7 @@ def test_pfm_param_noid(x, y, expected):
],
ids=["++", "-+", "+-", "--"],
)
def test_pfm_param_ids(x, y, expected):
def test_pfm_param_ids(x: float, y: float, expected: float) -> None:

assert expected == my_picky_float_multiplier(x, y)

Expand All @@ -66,13 +66,13 @@ def test_pfm_param_ids(x, y, expected):
argvalues=[(3.0,), (-3.0,)],
ids=["+", "-"],
)
def test_pfm_twoparam(x, y):
def test_pfm_twoparam(x: float, y: float) -> None:

assert 4.5 == abs(my_picky_float_multiplier(x, y))


@pytest.fixture
def twoparam_expected():
def twoparam_expected() -> dict[str, float]:

expected = {"+-+": 4.5, "---": 4.5, "--+": -4.5, "+--": -4.5}
return expected
Expand All @@ -88,27 +88,32 @@ def twoparam_expected():
argvalues=[(3.0,), (-3.0,)],
ids=["+", "-"],
)
def test_pfm_twoparam_fixture(request, twoparam_expected, x, y):
def test_pfm_twoparam_fixture(
request: pytest.FixtureRequest,
twoparam_expected: dict[str, float],
x: float,
y: float,
) -> None:

expected = twoparam_expected[request.node.callspec.id]

assert expected == my_picky_float_multiplier(x, y)


@pytest.fixture()
def times_table_instance():
def times_table_instance() -> TimesTable:
return TimesTable(num=7)


def test_times_table_errors(times_table_instance):
def test_times_table_errors(times_table_instance: TimesTable) -> None:

with pytest.raises(TypeError) as err_hndlr:
times_table_instance.table(1.6, 23.9)
times_table_instance.table(1.6, 23.9) # type: ignore

assert str(err_hndlr.value) == "'float' object cannot be interpreted as an integer"


def test_times_table_values(times_table_instance):
def test_times_table_values(times_table_instance: TimesTable) -> None:

value = times_table_instance.table(2, 7)
assert value == [14, 21, 28, 35, 42, 49]
13 changes: 11 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ tomli-w = "^1.0.0"
scipy = "^1.9.0"
jsonschema = "^4.14.0"
Shapely = "^1.8.4"
types-jsonschema = "^4.16.1"

[tool.poetry.dev-dependencies]
pytest = "^7.1.2"
Expand Down
10 changes: 9 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,20 @@ extend-ignore =
docstring-convention = google

[mypy]
ignore_missing_imports = True
ignore_missing_imports = False
strict_optional = False
disallow_untyped_calls = True
disallow_untyped_defs = True
disallow_incomplete_defs = True

[mypy-setup]
ignore_errors = True

[mypy-tests.*]
disallow_untyped_calls = False
disallow_untyped_defs = False
disallow_incomplete_defs = False

[isort]
profile = black
multi_line_output = 3
Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Collection of fixtures to assist the testing scripts."""

import pytest

# An import of LOGGER is required for INFO logging events to be visible to tests
# This can be removed as soon as a script that imports logger is imported
import virtual_rainforest.core.logger # noqa


def log_check(caplog, expected_log):
def log_check(caplog: pytest.LogCaptureFixture, expected_log: tuple[tuple]) -> None:
"""Helper function to check that the captured log is as expected.
Arguments:
Expand Down
8 changes: 5 additions & 3 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
),
],
)
def test_check_dict_leaves(d_a, d_b, overlap):
def test_check_dict_leaves(d_a: dict, d_b: dict, overlap: list) -> None:
"""Checks overlapping dictionary search function."""
assert overlap == config.check_dict_leaves(d_a, d_b, [])

Expand Down Expand Up @@ -203,7 +203,7 @@ def test_find_schema(caplog, config_dict, expected_exception, expected_log_entri
log_check(caplog, expected_log_entries)


def test_construct_combined_schema(caplog):
def test_construct_combined_schema(caplog: pytest.LogCaptureFixture) -> None:
"""Checks errors for bad or missing json schema."""

# Check that construct_combined_schema fails as expected
Expand Down Expand Up @@ -241,6 +241,8 @@ def test_construct_combined_schema(caplog):
def test_final_validation_log(caplog, expected_log_entries):
"""Checks that validation passes as expected and produces the correct output."""

print(type(expected_log_entries))

config.validate_config(["tests/fixtures"], out_file_name="complete_config")

# Remove generated output file
Expand Down Expand Up @@ -321,7 +323,7 @@ def test_register_schema_errors(
with pytest.raises(expected_exception):

@register_schema(schema_name)
def to_be_decorated():
def to_be_decorated() -> dict:
return schema

to_be_decorated()
Expand Down
2 changes: 1 addition & 1 deletion virtual_rainforest/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@register_schema("core")
def schema():
def schema() -> dict:
"""Defines the schema that the core module configuration should conform to."""

schema_file = Path(__file__).parent.resolve() / "core_schema.json"
Expand Down
8 changes: 5 additions & 3 deletions virtual_rainforest/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pathlib import Path
from typing import Callable, Union

import dpath.util
import dpath.util # type: ignore
import jsonschema
import tomli_w

Expand All @@ -37,7 +37,7 @@ def register_schema(module_name: str) -> Callable:
KeyError: If a module schema is missing one of the required keys
"""

def wrap(func: Callable):
def wrap(func: Callable) -> Callable:
if module_name in SCHEMA_REGISTRY:
log_and_raise(
f"The module schema {module_name} is used multiple times, this "
Expand Down Expand Up @@ -75,7 +75,9 @@ def wrap(func: Callable):
COMPLETE_CONFIG: dict = {}


def check_dict_leaves(d1: dict, d2: dict, conflicts: list = [], path: list = []):
def check_dict_leaves(
d1: dict, d2: dict, conflicts: list = [], path: list = []
) -> list:
"""Recursively checks if leaves are repeated between two nested dictionaries.
Args:
Expand Down
12 changes: 6 additions & 6 deletions virtual_rainforest/core/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

import json
import logging
from typing import Callable
from typing import Any, Callable

import numpy as np
from shapely.affinity import scale, translate
from shapely.geometry import Polygon
from shapely.affinity import scale, translate # type: ignore
from shapely.geometry import Polygon # type: ignore

LOGGER = logging.getLogger("virtual_rainforest.core")

Expand Down Expand Up @@ -250,7 +250,7 @@ def __repr__(self) -> str:
f"cell_ny={self.cell_ny})"
)

def dumps(self, dp: int = 2, **kwargs) -> str:
def dumps(self, dp: int = 2, **kwargs: Any) -> str:
"""Export a grid as a GeoJSON string.
The virtual_rainforest.core.Grid object assumes an unspecified projected
Expand All @@ -267,7 +267,7 @@ def dumps(self, dp: int = 2, **kwargs) -> str:
content = self._get_geojson(dp=dp)
return json.dumps(obj=content, **kwargs)

def dump(self, outfile: str, dp: int = 2, **kwargs) -> None:
def dump(self, outfile: str, dp: int = 2, **kwargs: Any) -> None:
"""Export a grid as a GeoJSON file.
The virtual_rainforest.core.Grid object assumes an unspecified projected
Expand All @@ -287,7 +287,7 @@ def dump(self, outfile: str, dp: int = 2, **kwargs) -> None:
with open(outfile, "w") as outf:
json.dump(obj=content, fp=outf, **kwargs)

def _get_geojson(self, dp):
def _get_geojson(self, dp: int) -> dict:
"""Convert the grid to a GeoJSON structured dictionary.
Args:
Expand Down
2 changes: 1 addition & 1 deletion virtual_rainforest/plants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@register_schema("plants")
def schema():
def schema() -> dict:
"""Defines the schema that the plant module configuration should conform to."""

schema_file = Path(__file__).parent.resolve() / "plants_schema.json"
Expand Down

0 comments on commit bc45fc5

Please sign in to comment.