1
1
from __future__ import annotations
2
2
3
3
import itertools
4
+ import numbers
4
5
from collections import defaultdict
5
6
from collections .abc import Iterable
6
7
from contextlib import suppress
7
8
from functools import partial
8
9
from operator import itemgetter
10
+ from typing import Any , Callable , Dict , Sequence , Tuple , Union
9
11
10
12
import numpy as np
11
13
12
14
from adaptive .learner .base_learner import BaseLearner
13
15
from adaptive .notebook_integration import ensure_holoviews
14
16
from adaptive .utils import cache_latest , named_product , restore
15
17
18
+ try :
19
+ from typing import Literal , TypeAlias
20
+ except ImportError :
21
+ from typing_extensions import Literal , TypeAlias
22
+
16
23
try :
17
24
import pandas
18
25
19
26
with_pandas = True
20
-
21
27
except ModuleNotFoundError :
22
28
with_pandas = False
23
29
24
30
25
- def dispatch (child_functions , arg ) :
31
+ def dispatch (child_functions : list [ Callable ] , arg : Any ) -> Any :
26
32
index , x = arg
27
33
return child_functions [index ](x )
28
34
29
35
36
+ STRATEGY_TYPE : TypeAlias = Literal ["loss_improvements" , "loss" , "npoints" , "cycle" ]
37
+
38
+ CDIMS_TYPE : TypeAlias = Union [
39
+ Sequence [Dict [str , Any ]],
40
+ Tuple [Sequence [str ], Sequence [Tuple [Any , ...]]],
41
+ None ,
42
+ ]
43
+
44
+
30
45
class BalancingLearner (BaseLearner ):
31
46
r"""Choose the optimal points from a set of learners.
32
47
@@ -78,13 +93,19 @@ class BalancingLearner(BaseLearner):
78
93
behave in an undefined way. Change the `strategy` in that case.
79
94
"""
80
95
81
- def __init__ (self , learners , * , cdims = None , strategy = "loss_improvements" ):
96
+ def __init__ (
97
+ self ,
98
+ learners : list [BaseLearner ],
99
+ * ,
100
+ cdims : CDIMS_TYPE = None ,
101
+ strategy : STRATEGY_TYPE = "loss_improvements" ,
102
+ ) -> None :
82
103
self .learners = learners
83
104
84
105
# Naively we would make 'function' a method, but this causes problems
85
106
# when using executors from 'concurrent.futures' because we have to
86
107
# pickle the whole learner.
87
- self .function = partial (dispatch , [l .function for l in self .learners ])
108
+ self .function = partial (dispatch , [l .function for l in self .learners ]) # type: ignore
88
109
89
110
self ._ask_cache = {}
90
111
self ._loss = {}
@@ -96,7 +117,7 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
96
117
"A BalacingLearner can handle only one type" " of learners."
97
118
)
98
119
99
- self .strategy = strategy
120
+ self .strategy : STRATEGY_TYPE = strategy
100
121
101
122
def new (self ) -> BalancingLearner :
102
123
"""Create a new `BalancingLearner` with the same parameters."""
@@ -107,21 +128,21 @@ def new(self) -> BalancingLearner:
107
128
)
108
129
109
130
@property
110
- def data (self ):
131
+ def data (self ) -> dict [ tuple [ int , Any ], Any ] :
111
132
data = {}
112
133
for i , l in enumerate (self .learners ):
113
134
data .update ({(i , p ): v for p , v in l .data .items ()})
114
135
return data
115
136
116
137
@property
117
- def pending_points (self ):
138
+ def pending_points (self ) -> set [ tuple [ int , Any ]] :
118
139
pending_points = set ()
119
140
for i , l in enumerate (self .learners ):
120
141
pending_points .update ({(i , p ) for p in l .pending_points })
121
142
return pending_points
122
143
123
144
@property
124
- def npoints (self ):
145
+ def npoints (self ) -> int :
125
146
return sum (l .npoints for l in self .learners )
126
147
127
148
@property
@@ -134,7 +155,7 @@ def nsamples(self):
134
155
)
135
156
136
157
@property
137
- def strategy (self ):
158
+ def strategy (self ) -> STRATEGY_TYPE :
138
159
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or
139
160
'cycle'. The points that the `BalancingLearner` choses can be either
140
161
based on: the best 'loss_improvements', the smallest total 'loss' of
@@ -145,7 +166,7 @@ def strategy(self):
145
166
return self ._strategy
146
167
147
168
@strategy .setter
148
- def strategy (self , strategy ) :
169
+ def strategy (self , strategy : STRATEGY_TYPE ) -> None :
149
170
self ._strategy = strategy
150
171
if strategy == "loss_improvements" :
151
172
self ._ask_and_tell = self ._ask_and_tell_based_on_loss_improvements
@@ -162,7 +183,9 @@ def strategy(self, strategy):
162
183
' strategy="npoints", or strategy="cycle" is implemented.'
163
184
)
164
185
165
- def _ask_and_tell_based_on_loss_improvements (self , n ):
186
+ def _ask_and_tell_based_on_loss_improvements (
187
+ self , n : int
188
+ ) -> tuple [list [tuple [int , Any ]], list [float ]]:
166
189
selected = [] # tuples ((learner_index, point), loss_improvement)
167
190
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
168
191
for _ in range (n ):
@@ -185,7 +208,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
185
208
points , loss_improvements = map (list , zip (* selected ))
186
209
return points , loss_improvements
187
210
188
- def _ask_and_tell_based_on_loss (self , n ):
211
+ def _ask_and_tell_based_on_loss (
212
+ self , n : int
213
+ ) -> tuple [list [tuple [int , Any ]], list [float ]]:
189
214
selected = [] # tuples ((learner_index, point), loss_improvement)
190
215
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
191
216
for _ in range (n ):
@@ -206,7 +231,9 @@ def _ask_and_tell_based_on_loss(self, n):
206
231
points , loss_improvements = map (list , zip (* selected ))
207
232
return points , loss_improvements
208
233
209
- def _ask_and_tell_based_on_npoints (self , n ):
234
+ def _ask_and_tell_based_on_npoints (
235
+ self , n : numbers .Integral
236
+ ) -> tuple [list [tuple [numbers .Integral , Any ]], list [float ]]:
210
237
selected = [] # tuples ((learner_index, point), loss_improvement)
211
238
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
212
239
for _ in range (n ):
@@ -222,7 +249,9 @@ def _ask_and_tell_based_on_npoints(self, n):
222
249
points , loss_improvements = map (list , zip (* selected ))
223
250
return points , loss_improvements
224
251
225
- def _ask_and_tell_based_on_cycle (self , n ):
252
+ def _ask_and_tell_based_on_cycle (
253
+ self , n : int
254
+ ) -> tuple [list [tuple [numbers .Integral , Any ]], list [float ]]:
226
255
points , loss_improvements = [], []
227
256
for _ in range (n ):
228
257
index = next (self ._cycle )
@@ -233,7 +262,9 @@ def _ask_and_tell_based_on_cycle(self, n):
233
262
234
263
return points , loss_improvements
235
264
236
- def ask (self , n , tell_pending = True ):
265
+ def ask (
266
+ self , n : int , tell_pending : bool = True
267
+ ) -> tuple [list [tuple [numbers .Integral , Any ]], list [float ]]:
237
268
"""Chose points for learners."""
238
269
if n == 0 :
239
270
return [], []
@@ -244,20 +275,20 @@ def ask(self, n, tell_pending=True):
244
275
else :
245
276
return self ._ask_and_tell (n )
246
277
247
- def tell (self , x , y ) :
278
+ def tell (self , x : tuple [ numbers . Integral , Any ], y : Any ) -> None :
248
279
index , x = x
249
280
self ._ask_cache .pop (index , None )
250
281
self ._loss .pop (index , None )
251
282
self ._pending_loss .pop (index , None )
252
283
self .learners [index ].tell (x , y )
253
284
254
- def tell_pending (self , x ) :
285
+ def tell_pending (self , x : tuple [ numbers . Integral , Any ]) -> None :
255
286
index , x = x
256
287
self ._ask_cache .pop (index , None )
257
288
self ._loss .pop (index , None )
258
289
self .learners [index ].tell_pending (x )
259
290
260
- def _losses (self , real = True ):
291
+ def _losses (self , real : bool = True ) -> list [ float ] :
261
292
losses = []
262
293
loss_dict = self ._loss if real else self ._pending_loss
263
294
@@ -269,11 +300,16 @@ def _losses(self, real=True):
269
300
return losses
270
301
271
302
@cache_latest
272
- def loss (self , real = True ):
303
+ def loss (self , real : bool = True ) -> float :
273
304
losses = self ._losses (real )
274
305
return max (losses )
275
306
276
- def plot (self , cdims = None , plotter = None , dynamic = True ):
307
+ def plot (
308
+ self ,
309
+ cdims : CDIMS_TYPE = None ,
310
+ plotter : Callable [[BaseLearner ], Any ] | None = None ,
311
+ dynamic : bool = True ,
312
+ ):
277
313
"""Returns a DynamicMap with sliders.
278
314
279
315
Parameters
@@ -346,13 +382,19 @@ def plot_function(*args):
346
382
vals = {d .name : d .values for d in dm .dimensions () if d .values }
347
383
return hv .HoloMap (dm .select (** vals ))
348
384
349
- def remove_unfinished (self ):
385
+ def remove_unfinished (self ) -> None :
350
386
"""Remove uncomputed data from the learners."""
351
387
for learner in self .learners :
352
388
learner .remove_unfinished ()
353
389
354
390
@classmethod
355
- def from_product (cls , f , learner_type , learner_kwargs , combos ):
391
+ def from_product (
392
+ cls ,
393
+ f ,
394
+ learner_type : BaseLearner ,
395
+ learner_kwargs : dict [str , Any ],
396
+ combos : dict [str , Sequence [Any ]],
397
+ ) -> BalancingLearner :
356
398
"""Create a `BalancingLearner` with learners of all combinations of
357
399
named variables’ values. The `cdims` will be set correctly, so calling
358
400
`learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -448,7 +490,11 @@ def load_dataframe(
448
490
for i , gr in df .groupby (index_name ):
449
491
self .learners [i ].load_dataframe (gr , ** kwargs )
450
492
451
- def save (self , fname , compress = True ):
493
+ def save (
494
+ self ,
495
+ fname : Callable [[BaseLearner ], str ] | Sequence [str ],
496
+ compress : bool = True ,
497
+ ) -> None :
452
498
"""Save the data of the child learners into pickle files
453
499
in a directory.
454
500
@@ -486,7 +532,11 @@ def save(self, fname, compress=True):
486
532
for l in self .learners :
487
533
l .save (fname (l ), compress = compress )
488
534
489
- def load (self , fname , compress = True ):
535
+ def load (
536
+ self ,
537
+ fname : Callable [[BaseLearner ], str ] | Sequence [str ],
538
+ compress : bool = True ,
539
+ ) -> None :
490
540
"""Load the data of the child learners from pickle files
491
541
in a directory.
492
542
@@ -510,20 +560,20 @@ def load(self, fname, compress=True):
510
560
for l in self .learners :
511
561
l .load (fname (l ), compress = compress )
512
562
513
- def _get_data (self ):
563
+ def _get_data (self ) -> list [ Any ] :
514
564
return [l ._get_data () for l in self .learners ]
515
565
516
- def _set_data (self , data ):
566
+ def _set_data (self , data : list [ Any ] ):
517
567
for l , _data in zip (self .learners , data ):
518
568
l ._set_data (_data )
519
569
520
- def __getstate__ (self ):
570
+ def __getstate__ (self ) -> tuple [ list [ BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ] :
521
571
return (
522
572
self .learners ,
523
573
self ._cdims_default ,
524
574
self .strategy ,
525
575
)
526
576
527
- def __setstate__ (self , state ):
577
+ def __setstate__ (self , state : tuple [ list [ BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ] ):
528
578
learners , cdims , strategy = state
529
579
self .__init__ (learners , cdims = cdims , strategy = strategy )
0 commit comments