4
4
from collections .abc import Iterable
5
5
from contextlib import suppress
6
6
from functools import partial
7
+ import itertools
7
8
from operator import itemgetter
8
9
9
10
import numpy as np
@@ -53,7 +54,8 @@ class BalancingLearner(BaseLearner):
53
54
strategy : 'loss_improvements' (default), 'loss', or 'npoints'
54
55
The points that the `BalancingLearner` choses can be either based on:
55
56
the best 'loss_improvements', the smallest total 'loss' of the
56
- child learners, or the number of points per learner, using 'npoints'.
57
+ child learners, the number of points per learner, using 'npoints',
58
+ or by cycling through the learners one by one using 'cycle'.
57
59
One can dynamically change the strategy while the simulation is
58
60
running by changing the ``learner.strategy`` attribute.
59
61
@@ -90,10 +92,11 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
90
92
91
93
@property
92
94
def strategy (self ):
93
- """Can be either 'loss_improvements' (default), 'loss', or 'npoints'
94
- The points that the `BalancingLearner` choses can be either based on:
95
- the best 'loss_improvements', the smallest total 'loss' of the
96
- child learners, or the number of points per learner, using 'npoints'.
95
+ """Can be either 'loss_improvements' (default), 'loss', 'npoints', or
96
+ 'cycle'. The points that the `BalancingLearner` choses can be either
97
+ based on: the best 'loss_improvements', the smallest total 'loss' of
98
+ the child learners, the number of points per learner, using 'npoints',
99
+ or by going through all learners one by one using 'cycle'.
97
100
One can dynamically change the strategy while the simulation is
98
101
running by changing the ``learner.strategy`` attribute."""
99
102
return self ._strategy
@@ -107,10 +110,12 @@ def strategy(self, strategy):
107
110
self ._ask_and_tell = self ._ask_and_tell_based_on_loss
108
111
elif strategy == "npoints" :
109
112
self ._ask_and_tell = self ._ask_and_tell_based_on_npoints
113
+ elif strategy == "cycle" :
114
+ self ._ask_and_tell = self ._ask_and_tell_based_on_cycle
110
115
else :
111
116
raise ValueError (
112
- 'Only strategy="loss_improvements", strategy="loss", or '
113
- ' strategy="npoints" is implemented.'
117
+ 'Only strategy="loss_improvements", strategy="loss",'
118
+ ' strategy="npoints", or strategy="cycle" is implemented.'
114
119
)
115
120
116
121
def _ask_and_tell_based_on_loss_improvements (self , n ):
@@ -173,6 +178,20 @@ def _ask_and_tell_based_on_npoints(self, n):
173
178
points , loss_improvements = map (list , zip (* selected ))
174
179
return points , loss_improvements
175
180
181
+ def _ask_and_tell_based_on_cycle (self , n ):
182
+ if not hasattr (self , "_cycle" ):
183
+ self ._cycle = itertools .cycle (range (len (self .learners )))
184
+
185
+ points , loss_improvements = [], []
186
+ for _ in range (n ):
187
+ index = next (self ._cycle )
188
+ point , loss_improvement = self .learners [index ].ask (n = 1 )
189
+ points .append ((index , point [0 ]))
190
+ loss_improvements .append (loss_improvement [0 ])
191
+ self .tell_pending ((index , point [0 ]))
192
+
193
+ return points , loss_improvements
194
+
176
195
def ask (self , n , tell_pending = True ):
177
196
"""Chose points for learners."""
178
197
if n == 0 :
0 commit comments