Skip to content

Commit b6c689f

Browse files
committed
Add type-hints to adaptive/learner/base_learner.py
1 parent 1b7e84d commit b6c689f

File tree

1 file changed

+26
-16
lines changed

1 file changed

+26
-16
lines changed

adaptive/learner/base_learner.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
from __future__ import annotations
2+
13
import abc
24
from contextlib import suppress
5+
from typing import Any, Callable
36

47
import cloudpickle
58

69
from adaptive.utils import _RequireAttrsABCMeta, load, save
710

811

9-
def uses_nth_neighbors(n: int):
12+
def uses_nth_neighbors(n: int) -> Callable[[int], Callable[[BaseLearner], float]]:
1013
"""Decorator to specify how many neighboring intervals the loss function uses.
1114
1215
Wraps loss functions to indicate that they expect intervals together
@@ -53,7 +56,9 @@ def uses_nth_neighbors(n: int):
5356
... return loss
5457
"""
5558

56-
def _wrapped(loss_per_interval):
59+
def _wrapped(
60+
loss_per_interval: Callable[[BaseLearner], float]
61+
) -> Callable[[BaseLearner], float]:
5762
loss_per_interval.nth_neighbors = n
5863
return loss_per_interval
5964

@@ -82,10 +87,15 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
8287
"""
8388

8489
data: dict
85-
npoints: int
8690
pending_points: set
91+
function: Callable
92+
93+
@property
94+
@abc.abstractmethod
95+
def npoints(self) -> int:
96+
"""Number of learned points."""
8797

88-
def tell(self, x, y):
98+
def tell(self, x: Any, y: Any) -> None:
8999
"""Tell the learner about a single value.
90100
91101
Parameters
@@ -95,7 +105,7 @@ def tell(self, x, y):
95105
"""
96106
self.tell_many([x], [y])
97107

98-
def tell_many(self, xs, ys):
108+
def tell_many(self, xs: Any, ys: Any) -> None:
99109
"""Tell the learner about some values.
100110
101111
Parameters
@@ -107,16 +117,16 @@ def tell_many(self, xs, ys):
107117
self.tell(x, y)
108118

109119
@abc.abstractmethod
110-
def tell_pending(self, x):
120+
def tell_pending(self, x: Any) -> None:
111121
"""Tell the learner that 'x' has been requested such
112122
that it's not suggested again."""
113123

114124
@abc.abstractmethod
115-
def remove_unfinished(self):
125+
def remove_unfinished(self) -> None:
116126
"""Remove uncomputed data from the learner."""
117127

118128
@abc.abstractmethod
119-
def loss(self, real=True):
129+
def loss(self, real: bool = True) -> float:
120130
"""Return the loss for the current state of the learner.
121131
122132
Parameters
@@ -128,7 +138,7 @@ def loss(self, real=True):
128138
"""
129139

130140
@abc.abstractmethod
131-
def ask(self, n, tell_pending=True):
141+
def ask(self, n: int, tell_pending: bool = True) -> tuple[list[Any], list[float]]:
132142
"""Choose the next 'n' points to evaluate.
133143
134144
Parameters
@@ -142,19 +152,19 @@ def ask(self, n, tell_pending=True):
142152
"""
143153

144154
@abc.abstractmethod
145-
def _get_data(self):
155+
def _get_data(self) -> Any:
146156
pass
147157

148158
@abc.abstractmethod
149-
def _set_data(self):
159+
def _set_data(self, data: Any):
150160
pass
151161

152162
@abc.abstractmethod
153163
def new(self):
154164
"""Return a new learner with the same function and parameters."""
155165
pass
156166

157-
def copy_from(self, other):
167+
def copy_from(self, other: BaseLearner) -> None:
158168
"""Copy over the data from another learner.
159169
160170
Parameters
@@ -164,7 +174,7 @@ def copy_from(self, other):
164174
"""
165175
self._set_data(other._get_data())
166176

167-
def save(self, fname, compress=True):
177+
def save(self, fname: str, compress: bool = True) -> None:
168178
"""Save the data of the learner into a pickle file.
169179
170180
Parameters
@@ -178,7 +188,7 @@ def save(self, fname, compress=True):
178188
data = self._get_data()
179189
save(fname, data, compress)
180190

181-
def load(self, fname, compress=True):
191+
def load(self, fname: str, compress: bool = True) -> None:
182192
"""Load the data of a learner from a pickle file.
183193
184194
Parameters
@@ -193,8 +203,8 @@ def load(self, fname, compress=True):
193203
data = load(fname, compress)
194204
self._set_data(data)
195205

196-
def __getstate__(self):
206+
def __getstate__(self) -> bytes:
197207
return cloudpickle.dumps(self.__dict__)
198208

199-
def __setstate__(self, state):
209+
def __setstate__(self, state: bytes) -> None:
200210
self.__dict__ = cloudpickle.loads(state)

0 commit comments

Comments
 (0)