Skip to content

Commit 9f249cf

Browse files
committed
BalancingLearner: add a "cycle" strategy, sampling the learners one by one
1 parent 01e26ef commit 9f249cf

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Iterable
55
from contextlib import suppress
66
from functools import partial
7+
import itertools
78
from operator import itemgetter
89

910
import numpy as np
@@ -53,7 +54,8 @@ class BalancingLearner(BaseLearner):
5354
strategy : 'loss_improvements' (default), 'loss', or 'npoints'
5455
The points that the `BalancingLearner` choses can be either based on:
5556
the best 'loss_improvements', the smallest total 'loss' of the
56-
child learners, or the number of points per learner, using 'npoints'.
57+
child learners, the number of points per learner, using 'npoints',
58+
or by cycling through the learners one by one using 'cycle'.
5759
One can dynamically change the strategy while the simulation is
5860
running by changing the ``learner.strategy`` attribute.
5961
@@ -90,10 +92,11 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
9092

9193
@property
9294
def strategy(self):
93-
"""Can be either 'loss_improvements' (default), 'loss', or 'npoints'
94-
The points that the `BalancingLearner` choses can be either based on:
95-
the best 'loss_improvements', the smallest total 'loss' of the
96-
child learners, or the number of points per learner, using 'npoints'.
95+
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or
96+
'cycle'. The points that the `BalancingLearner` choses can be either
97+
based on: the best 'loss_improvements', the smallest total 'loss' of
98+
the child learners, the number of points per learner, using 'npoints',
99+
or by going through all learners one by one using 'cycle'.
97100
One can dynamically change the strategy while the simulation is
98101
running by changing the ``learner.strategy`` attribute."""
99102
return self._strategy
@@ -107,10 +110,12 @@ def strategy(self, strategy):
107110
self._ask_and_tell = self._ask_and_tell_based_on_loss
108111
elif strategy == "npoints":
109112
self._ask_and_tell = self._ask_and_tell_based_on_npoints
113+
elif strategy == "cycle":
114+
self._ask_and_tell = self._ask_and_tell_based_on_cycle
110115
else:
111116
raise ValueError(
112-
'Only strategy="loss_improvements", strategy="loss", or'
113-
' strategy="npoints" is implemented.'
117+
'Only strategy="loss_improvements", strategy="loss",'
118+
' strategy="npoints", or strategy="cycle" is implemented.'
114119
)
115120

116121
def _ask_and_tell_based_on_loss_improvements(self, n):
@@ -173,6 +178,20 @@ def _ask_and_tell_based_on_npoints(self, n):
173178
points, loss_improvements = map(list, zip(*selected))
174179
return points, loss_improvements
175180

181+
def _ask_and_tell_based_on_cycle(self, n):
182+
if not hasattr(self, "_cycle"):
183+
self._cycle = itertools.cycle(range(len(self.learners)))
184+
185+
points, loss_improvements = [], []
186+
for _ in range(n):
187+
index = next(self._cycle)
188+
point, loss_improvement = self.learners[index].ask(n=1)
189+
points.append((index, point[0]))
190+
loss_improvements.append(loss_improvement[0])
191+
self.tell_pending((index, point[0]))
192+
193+
return points, loss_improvements
194+
176195
def ask(self, n, tell_pending=True):
177196
"""Chose points for learners."""
178197
if n == 0:

0 commit comments

Comments
 (0)