Add mypy to pre-commit and fix all current typing issues #414

Merged
26 commits merged on Apr 29, 2023

Commits (26)
2abccc6
Remove _RequireAttrsABCMeta metaclass and replace with simple check
basnijholt Apr 8, 2023
556dded
add mypy
basnijholt Apr 28, 2023
31a3632
Fix some typing issues
basnijholt Apr 28, 2023
1f08169
Fix some typing issues
basnijholt Apr 28, 2023
d4cdf3d
Fix some typing issues
basnijholt Apr 29, 2023
41893ac
fix all Runner type issues
basnijholt Apr 29, 2023
b2b1a9a
fix all DataSaver type issues
basnijholt Apr 29, 2023
f53e49a
fix all IntegratorLearner type issues
basnijholt Apr 29, 2023
1429edf
fix all SequenceLearner type issues
basnijholt Apr 29, 2023
f6a10cb
some fixes
basnijholt Apr 29, 2023
9142230
some fixes
basnijholt Apr 29, 2023
6046fd5
some fixes
basnijholt Apr 29, 2023
27f4154
Fix multiple issues
basnijholt Apr 29, 2023
e9b84aa
Fix all mypy issues
basnijholt Apr 29, 2023
d276547
Make data a dict
basnijholt Apr 29, 2023
d55250d
Merge remote-tracking branch 'origin/main' into remove-_RequireAttrsA…
basnijholt Apr 29, 2023
49b295a
make BaseLearner a ABC
basnijholt Apr 29, 2023
218f5e0
Merge branch 'remove-_RequireAttrsABCMeta' into add-mypy
basnijholt Apr 29, 2023
ae733d1
remove BaseLearner._check_required_attributes()
basnijholt Apr 29, 2023
d1664bd
Merge remote-tracking branch 'origin/main' into add-mypy
basnijholt Apr 29, 2023
05a25d6
remove unused deps
basnijholt Apr 29, 2023
8765ee8
Wrap in TYPE_CHECKING
basnijholt Apr 29, 2023
319575a
pin ipython
basnijholt Apr 29, 2023
466d5c3
pin ipython
basnijholt Apr 29, 2023
e8d2f6b
Add NotImplemented methods
basnijholt Apr 29, 2023
802e239
remove unused import
basnijholt Apr 29, 2023
Changes from all commits
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
@@ -25,3 +25,8 @@ repos:
- id: nbqa
args: ["ruff", "--fix", "--ignore=E402,B018,F704"]
additional_dependencies: [jupytext, ruff]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.2.0"
hooks:
- id: mypy
exclude: ipynb_filter.py|docs/source/conf.py
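With this hook in place, pre-commit runs mypy over the repository's Python files (excluding ipynb_filter.py and docs/source/conf.py, as configured above). Assuming pre-commit is installed locally, the same check can be run on demand with "pre-commit run mypy --all-files".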
4 changes: 2 additions & 2 deletions adaptive/__init__.py
@@ -54,5 +54,5 @@
__all__.append("SKOptLearner")

# to avoid confusion with `notebook_extension` and `__version__`
del _version # noqa: F821
del notebook_integration # noqa: F821
del _version # type: ignore[name-defined] # noqa: F821
del notebook_integration # type: ignore[name-defined] # noqa: F821
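The trailing # type: ignore[name-defined] comments scope the suppression to a single mypy error code, so any other error on the same line would still be reported. A minimal, hypothetical sketch of the technique (not code from this repository):

    x: int = "oops"  # type: ignore[assignment]  # silences only the [assignment] error on this line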
4 changes: 3 additions & 1 deletion adaptive/_version.py
@@ -1,5 +1,7 @@
# This file is part of 'miniver': https://github.com/jbweston/miniver
#
from __future__ import annotations

import os
import subprocess
from collections import namedtuple
@@ -10,7 +12,7 @@
Version = namedtuple("Version", ("release", "dev", "labels"))

# No public API
__all__ = []
__all__: list[str] = []

package_root = os.path.dirname(os.path.realpath(__file__))
package_name = os.path.basename(package_root)
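Annotating the empty __all__ tells mypy the intended element type, which it cannot infer from a bare []; the from __future__ import annotations line added above also keeps the list[str] syntax from being evaluated on older Python versions where builtin generics are not subscriptable at runtime. A short sketch of the pattern (illustrative names only, not from this repository):

    from __future__ import annotations

    __all__: list[str] = []      # without the annotation, mypy may ask for one ("Need type annotation")
    _cache: dict[str, int] = {}  # the same fix applies to any empty container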
9 changes: 3 additions & 6 deletions adaptive/learner/average_learner.py
@@ -1,16 +1,14 @@
from __future__ import annotations

from math import sqrt
from numbers import Integral as Int
from numbers import Real
from typing import Callable

import cloudpickle
import numpy as np

from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews
from adaptive.types import Float
from adaptive.types import Float, Int, Real
from adaptive.utils import (
assign_defaults,
cache_latest,
@@ -75,7 +73,6 @@ def __init__(
self.min_npoints = max(min_npoints, 2)
self.sum_f: Real = 0.0
self.sum_f_sq: Real = 0.0
self._check_required_attributes()

def new(self) -> AverageLearner:
"""Create a copy of `~adaptive.AverageLearner` without the data."""
@@ -89,7 +86,7 @@ def to_numpy(self):
"""Data as NumPy array of size (npoints, 2) with seeds and values."""
return np.array(sorted(self.data.items()))

def to_dataframe(
def to_dataframe( # type: ignore[override]
self,
with_default_function_args: bool = True,
function_prefix: str = "function.",
@@ -129,7 +126,7 @@ def to_dataframe(
assign_defaults(self.function, df, function_prefix)
return df

def load_dataframe(
def load_dataframe( # type: ignore[override]
self,
df: pandas.DataFrame,
with_default_function_args: bool = True,
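Several methods above gain # type: ignore[override] because their signatures differ from the BaseLearner methods they override, which mypy reports as an incompatible override. A minimal, hypothetical sketch of why that happens (not the actual adaptive classes):

    class Base:
        def to_dataframe(self, with_args: bool = True) -> None:
            return None

    class Child(Base):
        # Changing the positional parameter layout of an inherited method makes
        # mypy report [override]; the scoped ignore accepts the new signature.
        def to_dataframe(  # type: ignore[override]
            self, mean: bool = False, with_args: bool = True
        ) -> None:
            return None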
38 changes: 18 additions & 20 deletions adaptive/learner/average_learner1D.py
@@ -5,8 +5,6 @@
from collections import defaultdict
from copy import deepcopy
from math import hypot
from numbers import Integral as Int
from numbers import Real
from typing import Callable, DefaultDict, Iterable, List, Sequence, Tuple

import numpy as np
@@ -16,6 +14,7 @@

from adaptive.learner.learner1D import Learner1D, _get_intervals
from adaptive.notebook_integration import ensure_holoviews
from adaptive.types import Int, Real
from adaptive.utils import assign_defaults, partial_function_from_dataframe

try:
@@ -99,7 +98,7 @@ def __init__(
if min_samples > max_samples:
raise ValueError("max_samples should be larger than min_samples.")

super().__init__(function, bounds, loss_per_interval)
super().__init__(function, bounds, loss_per_interval) # type: ignore[arg-type]

self.delta = delta
self.alpha = alpha
@@ -110,7 +109,7 @@

# Contains all samples f(x) for each
# point x in the form {x0: {0: f_0(x0), 1: f_1(x0), ...}, ...}
self._data_samples = SortedDict()
self._data_samples: SortedDict[float, dict[int, Real]] = SortedDict()
# Contains the number of samples taken
# at each point x in the form {x0: n0, x1: n1, ...}
self._number_samples = SortedDict()
@@ -124,15 +123,14 @@ def __init__(
# form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
self._distances: dict[Real, float] = decreasing_dict()
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
self.rescaled_error: dict[Real, float] = decreasing_dict()
self._check_required_attributes()
self.rescaled_error: ItemSortedDict[Real, float] = decreasing_dict()

def new(self) -> AverageLearner1D:
"""Create a copy of `~adaptive.AverageLearner1D` without the data."""
return AverageLearner1D(
self.function,
self.bounds,
self.loss_per_interval,
self.loss_per_interval, # type: ignore[arg-type]
self.delta,
self.alpha,
self.neighbor_sampling,
@@ -164,7 +162,7 @@ def to_numpy(self, mean: bool = False) -> np.ndarray:
]
)

def to_dataframe(
def to_dataframe( # type: ignore[override]
self,
mean: bool = False,
with_default_function_args: bool = True,
@@ -202,10 +200,10 @@ def to_dataframe(
if not with_pandas:
raise ImportError("pandas is not installed.")
if mean:
data = sorted(self.data.items())
data: list[tuple[Real, Real]] = sorted(self.data.items())
columns = [x_name, y_name]
else:
data = [
data: list[tuple[int, Real, Real]] = [ # type: ignore[no-redef]
(seed, x, y)
for x, seed_y in sorted(self._data_samples.items())
for seed, y in sorted(seed_y.items())
@@ -218,7 +216,7 @@
assign_defaults(self.function, df, function_prefix)
return df

def load_dataframe(
def load_dataframe( # type: ignore[override]
self,
df: pandas.DataFrame,
with_default_function_args: bool = True,
@@ -258,7 +256,7 @@ def load_dataframe(
self.function, df, function_prefix
)

def ask(self, n: int, tell_pending: bool = True) -> tuple[Points, list[float]]:
def ask(self, n: int, tell_pending: bool = True) -> tuple[Points, list[float]]: # type: ignore[override]
"""Return 'n' points that are expected to maximally reduce the loss."""
# If some point is undersampled, resample it
if len(self._undersampled_points):
@@ -311,18 +309,18 @@ def _ask_for_new_point(self, n: int) -> tuple[Points, list[float]]:
new point, since in general n << min_samples and this point will need
to be resampled many more times"""
points, (loss_improvement,) = self._ask_points_without_adding(1)
points = [(seed, x) for seed, x in zip(range(n), n * points)]
seed_points = [(seed, x) for seed, x in zip(range(n), n * points)]
loss_improvements = [loss_improvement / n] * n
return points, loss_improvements
return seed_points, loss_improvements # type: ignore[return-value]

def tell_pending(self, seed_x: Point) -> None:
def tell_pending(self, seed_x: Point) -> None: # type: ignore[override]
_, x = seed_x
self.pending_points.add(seed_x)
if x not in self.data:
self._update_neighbors(x, self.neighbors_combined)
self._update_losses(x, real=False)

def tell(self, seed_x: Point, y: Real) -> None:
def tell(self, seed_x: Point, y: Real) -> None: # type: ignore[override]
seed, x = seed_x
if y is None:
raise TypeError(
@@ -493,7 +491,7 @@ def _calc_error_in_mean(self, ys: Iterable[Real], y_avg: Real, n: int) -> float:
t_student = scipy.stats.t.ppf(1 - self.alpha, df=n - 1)
return t_student * (variance_in_mean / n) ** 0.5

def tell_many(
def tell_many( # type: ignore[override]
self, xs: Points | np.ndarray, ys: Sequence[Real] | np.ndarray
) -> None:
# Check that all x are within the bounds
@@ -578,10 +576,10 @@ def tell_many_at_point(self, x: Real, seed_y_mapping: dict[int, Real]) -> None:
self._update_interpolated_loss_in_interval(*interval)
self._oldscale = deepcopy(self._scale)

def _get_data(self) -> dict[Real, dict[Int, Real]]:
def _get_data(self) -> dict[Real, dict[Int, Real]]: # type: ignore[override]
return self._data_samples

def _set_data(self, data: dict[Real, dict[Int, Real]]) -> None:
def _set_data(self, data: dict[Real, dict[Int, Real]]) -> None: # type: ignore[override]
if data:
for x, samples in data.items():
self.tell_many_at_point(x, samples)
@@ -616,7 +614,7 @@ def plot(self):
return p.redim(x={"range": plot_bounds})


def decreasing_dict() -> dict:
def decreasing_dict() -> ItemSortedDict:
"""This initialization orders the dictionary from large to small values"""

def sorting_rule(key, value):
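Two recurring fixes in this file are worth noting: where a variable genuinely holds a different type in each branch (data in to_dataframe above), the second annotation is marked # type: ignore[no-redef]; where a rename is cleaner (points becoming seed_points in _ask_for_new_point), the conflict is avoided altogether. A small, hypothetical sketch of the no-redef pattern (not code from this repository):

    def summarize(mean: bool) -> list:
        if mean:
            rows: list[tuple[float, float]] = [(0.0, 1.0)]
        else:
            # Annotating the same name with a different type triggers mypy's
            # [no-redef] error; the scoped ignore keeps the branch-dependent shape.
            rows: list[tuple[int, float, float]] = [(0, 0.0, 1.0)]  # type: ignore[no-redef]
        return rows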