Skip to content

Commit 221a5b3

Browse files
committed
BalancingLearner: add a "cycle" strategy, sampling the learners one by one
1 parent 4889ab5 commit 221a5b3

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Iterable
55
from contextlib import suppress
66
from functools import partial
7+
import itertools
78
from operator import itemgetter
89
import os.path
910

@@ -54,7 +55,8 @@ class BalancingLearner(BaseLearner):
5455
strategy : 'loss_improvements' (default), 'loss', or 'npoints'
5556
The points that the `BalancingLearner` choses can be either based on:
5657
the best 'loss_improvements', the smallest total 'loss' of the
57-
child learners, or the number of points per learner, using 'npoints'.
58+
child learners, the number of points per learner, using 'npoints',
59+
or by cycling through the learners one by one using 'cycle'.
5860
One can dynamically change the strategy while the simulation is
5961
running by changing the ``learner.strategy`` attribute.
6062
@@ -90,10 +92,11 @@ def __init__(self, learners, *, cdims=None, strategy='loss_improvements'):
9092

9193
@property
9294
def strategy(self):
93-
"""Can be either 'loss_improvements' (default), 'loss', or 'npoints'
94-
The points that the `BalancingLearner` choses can be either based on:
95-
the best 'loss_improvements', the smallest total 'loss' of the
96-
child learners, or the number of points per learner, using 'npoints'.
95+
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or
96+
'cycle'. The points that the `BalancingLearner` choses can be either
97+
based on: the best 'loss_improvements', the smallest total 'loss' of
98+
the child learners, the number of points per learner, using 'npoints',
99+
or by going through all learners one by one using 'cycle'.
97100
One can dynamically change the strategy while the simulation is
98101
running by changing the ``learner.strategy`` attribute."""
99102
return self._strategy
@@ -107,10 +110,12 @@ def strategy(self, strategy):
107110
self._ask_and_tell = self._ask_and_tell_based_on_loss
108111
elif strategy == 'npoints':
109112
self._ask_and_tell = self._ask_and_tell_based_on_npoints
113+
elif strategy == 'cycle':
114+
self._ask_and_tell = self._ask_and_tell_based_on_cycle
110115
else:
111116
raise ValueError(
112-
'Only strategy="loss_improvements", strategy="loss", or'
113-
' strategy="npoints" is implemented.')
117+
'Only strategy="loss_improvements", strategy="loss",'
118+
' strategy="npoints", or strategy="cycle" is implemented.')
114119

115120
def _ask_and_tell_based_on_loss_improvements(self, n):
116121
selected = [] # tuples ((learner_index, point), loss_improvement)
@@ -176,6 +181,20 @@ def _ask_and_tell_based_on_npoints(self, n):
176181
points, loss_improvements = map(list, zip(*selected))
177182
return points, loss_improvements
178183

184+
def _ask_and_tell_based_on_cycle(self, n):
185+
if not hasattr(self, '_cycle'):
186+
self._cycle = itertools.cycle(range(len(self.learners)))
187+
188+
points, loss_improvements = [], []
189+
for _ in range(n):
190+
index = next(self._cycle)
191+
point, loss_improvement = self.learners[index].ask(n=1)
192+
points.append((index, point[0]))
193+
loss_improvements.append(loss_improvement[0])
194+
self.tell_pending((index, point[0]))
195+
196+
return points, loss_improvements
197+
179198
def ask(self, n, tell_pending=True):
180199
"""Chose points for learners."""
181200
if n == 0:

0 commit comments

Comments
 (0)