Skip to content

Commit e7653d5

Browse files
committed
Add type-hints to adaptive/learner/balancing_learner.py (#371)
1 parent 96b4995 commit e7653d5

File tree

1 file changed

+78
-28
lines changed

1 file changed

+78
-28
lines changed

adaptive/learner/balancing_learner.py

+78-28
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,47 @@
11
from __future__ import annotations
22

33
import itertools
4+
import numbers
45
from collections import defaultdict
56
from collections.abc import Iterable
67
from contextlib import suppress
78
from functools import partial
89
from operator import itemgetter
10+
from typing import Any, Callable, Dict, Sequence, Tuple, Union
911

1012
import numpy as np
1113

1214
from adaptive.learner.base_learner import BaseLearner
1315
from adaptive.notebook_integration import ensure_holoviews
1416
from adaptive.utils import cache_latest, named_product, restore
1517

18+
try:
19+
from typing import Literal, TypeAlias
20+
except ImportError:
21+
from typing_extensions import Literal, TypeAlias
22+
1623
try:
1724
import pandas
1825

1926
with_pandas = True
20-
2127
except ModuleNotFoundError:
2228
with_pandas = False
2329

2430

25-
def dispatch(child_functions, arg):
31+
def dispatch(child_functions: list[Callable], arg: Any) -> Any:
2632
index, x = arg
2733
return child_functions[index](x)
2834

2935

36+
STRATEGY_TYPE: TypeAlias = Literal["loss_improvements", "loss", "npoints", "cycle"]
37+
38+
CDIMS_TYPE: TypeAlias = Union[
39+
Sequence[Dict[str, Any]],
40+
Tuple[Sequence[str], Sequence[Tuple[Any, ...]]],
41+
None,
42+
]
43+
44+
3045
class BalancingLearner(BaseLearner):
3146
r"""Choose the optimal points from a set of learners.
3247
@@ -78,13 +93,19 @@ class BalancingLearner(BaseLearner):
7893
behave in an undefined way. Change the `strategy` in that case.
7994
"""
8095

81-
def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
96+
def __init__(
97+
self,
98+
learners: list[BaseLearner],
99+
*,
100+
cdims: CDIMS_TYPE = None,
101+
strategy: STRATEGY_TYPE = "loss_improvements",
102+
) -> None:
82103
self.learners = learners
83104

84105
# Naively we would make 'function' a method, but this causes problems
85106
# when using executors from 'concurrent.futures' because we have to
86107
# pickle the whole learner.
87-
self.function = partial(dispatch, [l.function for l in self.learners])
108+
self.function = partial(dispatch, [l.function for l in self.learners]) # type: ignore
88109

89110
self._ask_cache = {}
90111
self._loss = {}
@@ -96,7 +117,7 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
96117
"A BalacingLearner can handle only one type" " of learners."
97118
)
98119

99-
self.strategy = strategy
120+
self.strategy: STRATEGY_TYPE = strategy
100121

101122
def new(self) -> BalancingLearner:
102123
"""Create a new `BalancingLearner` with the same parameters."""
@@ -107,21 +128,21 @@ def new(self) -> BalancingLearner:
107128
)
108129

109130
@property
110-
def data(self):
131+
def data(self) -> dict[tuple[int, Any], Any]:
111132
data = {}
112133
for i, l in enumerate(self.learners):
113134
data.update({(i, p): v for p, v in l.data.items()})
114135
return data
115136

116137
@property
117-
def pending_points(self):
138+
def pending_points(self) -> set[tuple[int, Any]]:
118139
pending_points = set()
119140
for i, l in enumerate(self.learners):
120141
pending_points.update({(i, p) for p in l.pending_points})
121142
return pending_points
122143

123144
@property
124-
def npoints(self):
145+
def npoints(self) -> int:
125146
return sum(l.npoints for l in self.learners)
126147

127148
@property
@@ -134,7 +155,7 @@ def nsamples(self):
134155
)
135156

136157
@property
137-
def strategy(self):
158+
def strategy(self) -> STRATEGY_TYPE:
138159
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or
139160
'cycle'. The points that the `BalancingLearner` choses can be either
140161
based on: the best 'loss_improvements', the smallest total 'loss' of
@@ -145,7 +166,7 @@ def strategy(self):
145166
return self._strategy
146167

147168
@strategy.setter
148-
def strategy(self, strategy):
169+
def strategy(self, strategy: STRATEGY_TYPE) -> None:
149170
self._strategy = strategy
150171
if strategy == "loss_improvements":
151172
self._ask_and_tell = self._ask_and_tell_based_on_loss_improvements
@@ -162,7 +183,9 @@ def strategy(self, strategy):
162183
' strategy="npoints", or strategy="cycle" is implemented.'
163184
)
164185

165-
def _ask_and_tell_based_on_loss_improvements(self, n):
186+
def _ask_and_tell_based_on_loss_improvements(
187+
self, n: int
188+
) -> tuple[list[tuple[int, Any]], list[float]]:
166189
selected = [] # tuples ((learner_index, point), loss_improvement)
167190
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
168191
for _ in range(n):
@@ -185,7 +208,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
185208
points, loss_improvements = map(list, zip(*selected))
186209
return points, loss_improvements
187210

188-
def _ask_and_tell_based_on_loss(self, n):
211+
def _ask_and_tell_based_on_loss(
212+
self, n: int
213+
) -> tuple[list[tuple[int, Any]], list[float]]:
189214
selected = [] # tuples ((learner_index, point), loss_improvement)
190215
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
191216
for _ in range(n):
@@ -206,7 +231,9 @@ def _ask_and_tell_based_on_loss(self, n):
206231
points, loss_improvements = map(list, zip(*selected))
207232
return points, loss_improvements
208233

209-
def _ask_and_tell_based_on_npoints(self, n):
234+
def _ask_and_tell_based_on_npoints(
235+
self, n: numbers.Integral
236+
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
210237
selected = [] # tuples ((learner_index, point), loss_improvement)
211238
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
212239
for _ in range(n):
@@ -222,7 +249,9 @@ def _ask_and_tell_based_on_npoints(self, n):
222249
points, loss_improvements = map(list, zip(*selected))
223250
return points, loss_improvements
224251

225-
def _ask_and_tell_based_on_cycle(self, n):
252+
def _ask_and_tell_based_on_cycle(
253+
self, n: int
254+
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
226255
points, loss_improvements = [], []
227256
for _ in range(n):
228257
index = next(self._cycle)
@@ -233,7 +262,9 @@ def _ask_and_tell_based_on_cycle(self, n):
233262

234263
return points, loss_improvements
235264

236-
def ask(self, n, tell_pending=True):
265+
def ask(
266+
self, n: int, tell_pending: bool = True
267+
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
237268
"""Chose points for learners."""
238269
if n == 0:
239270
return [], []
@@ -244,20 +275,20 @@ def ask(self, n, tell_pending=True):
244275
else:
245276
return self._ask_and_tell(n)
246277

247-
def tell(self, x, y):
278+
def tell(self, x: tuple[numbers.Integral, Any], y: Any) -> None:
248279
index, x = x
249280
self._ask_cache.pop(index, None)
250281
self._loss.pop(index, None)
251282
self._pending_loss.pop(index, None)
252283
self.learners[index].tell(x, y)
253284

254-
def tell_pending(self, x):
285+
def tell_pending(self, x: tuple[numbers.Integral, Any]) -> None:
255286
index, x = x
256287
self._ask_cache.pop(index, None)
257288
self._loss.pop(index, None)
258289
self.learners[index].tell_pending(x)
259290

260-
def _losses(self, real=True):
291+
def _losses(self, real: bool = True) -> list[float]:
261292
losses = []
262293
loss_dict = self._loss if real else self._pending_loss
263294

@@ -269,11 +300,16 @@ def _losses(self, real=True):
269300
return losses
270301

271302
@cache_latest
272-
def loss(self, real=True):
303+
def loss(self, real: bool = True) -> float:
273304
losses = self._losses(real)
274305
return max(losses)
275306

276-
def plot(self, cdims=None, plotter=None, dynamic=True):
307+
def plot(
308+
self,
309+
cdims: CDIMS_TYPE = None,
310+
plotter: Callable[[BaseLearner], Any] | None = None,
311+
dynamic: bool = True,
312+
):
277313
"""Returns a DynamicMap with sliders.
278314
279315
Parameters
@@ -346,13 +382,19 @@ def plot_function(*args):
346382
vals = {d.name: d.values for d in dm.dimensions() if d.values}
347383
return hv.HoloMap(dm.select(**vals))
348384

349-
def remove_unfinished(self):
385+
def remove_unfinished(self) -> None:
350386
"""Remove uncomputed data from the learners."""
351387
for learner in self.learners:
352388
learner.remove_unfinished()
353389

354390
@classmethod
355-
def from_product(cls, f, learner_type, learner_kwargs, combos):
391+
def from_product(
392+
cls,
393+
f,
394+
learner_type: BaseLearner,
395+
learner_kwargs: dict[str, Any],
396+
combos: dict[str, Sequence[Any]],
397+
) -> BalancingLearner:
356398
"""Create a `BalancingLearner` with learners of all combinations of
357399
named variables’ values. The `cdims` will be set correctly, so calling
358400
`learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -448,7 +490,11 @@ def load_dataframe(
448490
for i, gr in df.groupby(index_name):
449491
self.learners[i].load_dataframe(gr, **kwargs)
450492

451-
def save(self, fname, compress=True):
493+
def save(
494+
self,
495+
fname: Callable[[BaseLearner], str] | Sequence[str],
496+
compress: bool = True,
497+
) -> None:
452498
"""Save the data of the child learners into pickle files
453499
in a directory.
454500
@@ -486,7 +532,11 @@ def save(self, fname, compress=True):
486532
for l in self.learners:
487533
l.save(fname(l), compress=compress)
488534

489-
def load(self, fname, compress=True):
535+
def load(
536+
self,
537+
fname: Callable[[BaseLearner], str] | Sequence[str],
538+
compress: bool = True,
539+
) -> None:
490540
"""Load the data of the child learners from pickle files
491541
in a directory.
492542
@@ -510,20 +560,20 @@ def load(self, fname, compress=True):
510560
for l in self.learners:
511561
l.load(fname(l), compress=compress)
512562

513-
def _get_data(self):
563+
def _get_data(self) -> list[Any]:
514564
return [l._get_data() for l in self.learners]
515565

516-
def _set_data(self, data):
566+
def _set_data(self, data: list[Any]):
517567
for l, _data in zip(self.learners, data):
518568
l._set_data(_data)
519569

520-
def __getstate__(self):
570+
def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
521571
return (
522572
self.learners,
523573
self._cdims_default,
524574
self.strategy,
525575
)
526576

527-
def __setstate__(self, state):
577+
def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
528578
learners, cdims, strategy = state
529579
self.__init__(learners, cdims=cdims, strategy=strategy)

0 commit comments

Comments
 (0)