4
4
from collections .abc import Iterable
5
5
from contextlib import suppress
6
6
from functools import partial
7
+ import itertools
7
8
from operator import itemgetter
8
9
import os .path
9
10
@@ -54,7 +55,8 @@ class BalancingLearner(BaseLearner):
54
55
strategy : 'loss_improvements' (default), 'loss', or 'npoints'
55
56
The points that the `BalancingLearner` choses can be either based on:
56
57
the best 'loss_improvements', the smallest total 'loss' of the
57
- child learners, or the number of points per learner, using 'npoints'.
58
+ child learners, the number of points per learner, using 'npoints',
59
+ or by cycling through the learners one by one using 'cycle'.
58
60
One can dynamically change the strategy while the simulation is
59
61
running by changing the ``learner.strategy`` attribute.
60
62
@@ -90,10 +92,11 @@ def __init__(self, learners, *, cdims=None, strategy='loss_improvements'):
90
92
91
93
@property
92
94
def strategy (self ):
93
- """Can be either 'loss_improvements' (default), 'loss', or 'npoints'
94
- The points that the `BalancingLearner` choses can be either based on:
95
- the best 'loss_improvements', the smallest total 'loss' of the
96
- child learners, or the number of points per learner, using 'npoints'.
95
+ """Can be either 'loss_improvements' (default), 'loss', 'npoints', or
96
+ 'cycle'. The points that the `BalancingLearner` choses can be either
97
+ based on: the best 'loss_improvements', the smallest total 'loss' of
98
+ the child learners, the number of points per learner, using 'npoints',
99
+ or by going through all learners one by one using 'cycle'.
97
100
One can dynamically change the strategy while the simulation is
98
101
running by changing the ``learner.strategy`` attribute."""
99
102
return self ._strategy
@@ -107,10 +110,12 @@ def strategy(self, strategy):
107
110
self ._ask_and_tell = self ._ask_and_tell_based_on_loss
108
111
elif strategy == 'npoints' :
109
112
self ._ask_and_tell = self ._ask_and_tell_based_on_npoints
113
+ elif strategy == 'cycle' :
114
+ self ._ask_and_tell = self ._ask_and_tell_based_on_cycle
110
115
else :
111
116
raise ValueError (
112
- 'Only strategy="loss_improvements", strategy="loss", or '
113
- ' strategy="npoints" is implemented.' )
117
+ 'Only strategy="loss_improvements", strategy="loss",'
118
+ ' strategy="npoints", or strategy="cycle" is implemented.' )
114
119
115
120
def _ask_and_tell_based_on_loss_improvements (self , n ):
116
121
selected = [] # tuples ((learner_index, point), loss_improvement)
@@ -176,6 +181,20 @@ def _ask_and_tell_based_on_npoints(self, n):
176
181
points , loss_improvements = map (list , zip (* selected ))
177
182
return points , loss_improvements
178
183
184
+ def _ask_and_tell_based_on_cycle (self , n ):
185
+ if not hasattr (self , '_cycle' ):
186
+ self ._cycle = itertools .cycle (range (len (self .learners )))
187
+
188
+ points , loss_improvements = [], []
189
+ for _ in range (n ):
190
+ index = next (self ._cycle )
191
+ point , loss_improvement = self .learners [index ].ask (n = 1 )
192
+ points .append ((index , point [0 ]))
193
+ loss_improvements .append (loss_improvement [0 ])
194
+ self .tell_pending ((index , point [0 ]))
195
+
196
+ return points , loss_improvements
197
+
179
198
def ask (self , n , tell_pending = True ):
180
199
"""Chose points for learners."""
181
200
if n == 0 :
0 commit comments