1
+ from __future__ import annotations
2
+
1
3
import abc
2
4
from contextlib import suppress
5
+ from typing import Any , Callable
3
6
4
7
import cloudpickle
5
8
6
9
from adaptive .utils import _RequireAttrsABCMeta , load , save
7
10
8
11
9
- def uses_nth_neighbors (n : int ):
12
+ def uses_nth_neighbors (n : int ) -> Callable [[ int ], Callable [[ BaseLearner ], float ]] :
10
13
"""Decorator to specify how many neighboring intervals the loss function uses.
11
14
12
15
Wraps loss functions to indicate that they expect intervals together
@@ -53,7 +56,9 @@ def uses_nth_neighbors(n: int):
53
56
... return loss
54
57
"""
55
58
56
- def _wrapped (loss_per_interval ):
59
+ def _wrapped (
60
+ loss_per_interval : Callable [[BaseLearner ], float ]
61
+ ) -> Callable [[BaseLearner ], float ]:
57
62
loss_per_interval .nth_neighbors = n
58
63
return loss_per_interval
59
64
@@ -82,10 +87,15 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
82
87
"""
83
88
84
89
data : dict
85
- npoints : int
86
90
pending_points : set
91
+ function : Callable
92
+
93
+ @property
94
+ @abc .abstractmethod
95
+ def npoints (self ) -> int :
96
+ """Number of learned points."""
87
97
88
- def tell (self , x , y ) :
98
+ def tell (self , x : Any , y : Any ) -> None :
89
99
"""Tell the learner about a single value.
90
100
91
101
Parameters
@@ -95,7 +105,7 @@ def tell(self, x, y):
95
105
"""
96
106
self .tell_many ([x ], [y ])
97
107
98
- def tell_many (self , xs , ys ) :
108
+ def tell_many (self , xs : Any , ys : Any ) -> None :
99
109
"""Tell the learner about some values.
100
110
101
111
Parameters
@@ -107,16 +117,16 @@ def tell_many(self, xs, ys):
107
117
self .tell (x , y )
108
118
109
119
@abc .abstractmethod
110
- def tell_pending (self , x ) :
120
+ def tell_pending (self , x : Any ) -> None :
111
121
"""Tell the learner that 'x' has been requested such
112
122
that it's not suggested again."""
113
123
114
124
@abc .abstractmethod
115
- def remove_unfinished (self ):
125
+ def remove_unfinished (self ) -> None :
116
126
"""Remove uncomputed data from the learner."""
117
127
118
128
@abc .abstractmethod
119
- def loss (self , real = True ):
129
+ def loss (self , real : bool = True ) -> float :
120
130
"""Return the loss for the current state of the learner.
121
131
122
132
Parameters
@@ -128,7 +138,7 @@ def loss(self, real=True):
128
138
"""
129
139
130
140
@abc .abstractmethod
131
- def ask (self , n , tell_pending = True ):
141
+ def ask (self , n : int , tell_pending : bool = True ) -> tuple [ list [ Any ], list [ float ]] :
132
142
"""Choose the next 'n' points to evaluate.
133
143
134
144
Parameters
@@ -142,19 +152,19 @@ def ask(self, n, tell_pending=True):
142
152
"""
143
153
144
154
@abc .abstractmethod
145
- def _get_data (self ):
155
+ def _get_data (self ) -> Any :
146
156
pass
147
157
148
158
@abc .abstractmethod
149
- def _set_data (self ):
159
+ def _set_data (self , data : Any ):
150
160
pass
151
161
152
162
@abc .abstractmethod
153
163
def new (self ):
154
164
"""Return a new learner with the same function and parameters."""
155
165
pass
156
166
157
- def copy_from (self , other ) :
167
+ def copy_from (self , other : BaseLearner ) -> None :
158
168
"""Copy over the data from another learner.
159
169
160
170
Parameters
@@ -164,7 +174,7 @@ def copy_from(self, other):
164
174
"""
165
175
self ._set_data (other ._get_data ())
166
176
167
- def save (self , fname , compress = True ):
177
+ def save (self , fname : str , compress : bool = True ) -> None :
168
178
"""Save the data of the learner into a pickle file.
169
179
170
180
Parameters
@@ -178,7 +188,7 @@ def save(self, fname, compress=True):
178
188
data = self ._get_data ()
179
189
save (fname , data , compress )
180
190
181
- def load (self , fname , compress = True ):
191
+ def load (self , fname : str , compress : bool = True ) -> None :
182
192
"""Load the data of a learner from a pickle file.
183
193
184
194
Parameters
@@ -193,8 +203,8 @@ def load(self, fname, compress=True):
193
203
data = load (fname , compress )
194
204
self ._set_data (data )
195
205
196
- def __getstate__ (self ):
206
+ def __getstate__ (self ) -> bytes :
197
207
return cloudpickle .dumps (self .__dict__ )
198
208
199
- def __setstate__ (self , state ) :
209
+ def __setstate__ (self , state : bytes ) -> None :
200
210
self .__dict__ = cloudpickle .loads (state )
0 commit comments