Skip to content

[tune] Hyperband Max Iter Fix #1620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions python/ray/tune/hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ class HyperBandScheduler(FIFOScheduler):
procedures will use this attribute.
max_t (int): max time units per trial. Trials will be stopped after
max_t time units (determined by time_attr) have passed.
The HyperBand scheduler automatically tries to determine a
reasonable number of brackets based on this. The scheduler will
terminate trials after this time has passed.
The scheduler will terminate trials after this time has passed.
Note that this is different from the semantics of `max_t` as
mentioned in the original HyperBand paper.
"""

def __init__(
Expand All @@ -73,6 +73,7 @@ def __init__(
FIFOScheduler.__init__(self)
self._eta = 3
self._s_max_1 = 5
self._max_t_attr = max_t
# bracket max trials
self._get_n0 = lambda s: int(
np.ceil(self._s_max_1/(s+1) * self._eta**s))
Expand Down Expand Up @@ -117,7 +118,7 @@ def on_trial_add(self, trial_runner, trial):
retry = False
cur_bracket = Bracket(
self._time_attr, self._get_n0(s), self._get_r0(s),
self._eta, s)
self._max_t_attr, self._eta, s)
cur_band.append(cur_bracket)
self._state["bracket"] = cur_bracket

Expand Down Expand Up @@ -257,13 +258,14 @@ class Bracket():

Also keeps track of progress to ensure good scheduling.
"""
def __init__(self, time_attr, max_trials, init_t_attr, eta, s):
def __init__(self, time_attr, max_trials, init_t_attr, max_t_attr, eta, s):
self._live_trials = {} # maps trial -> current result
self._all_trials = []
self._time_attr = time_attr # attribute to

self._n = self._n0 = max_trials
self._r = self._r0 = init_t_attr
self._max_t_attr = max_t_attr
self._cumul_r = self._r0

self._eta = eta
Expand Down Expand Up @@ -314,8 +316,9 @@ def successive_halving(self, reward_attr):
self._halves -= 1
self._n /= self._eta
self._n = int(np.ceil(self._n))

self._r *= self._eta
self._r = int((self._r))
self._r = int(min(self._r, self._max_t_attr - self._cumul_r))
self._cumul_r += self._r
sorted_trials = sorted(
self._live_trials,
Expand Down Expand Up @@ -364,6 +367,8 @@ def completion_percentage(self):

This will not always finish with 100 since dead trials
are dropped."""
if self.finished():
return 1.0
return self._completed_progress / self._total_work

def _get_result_time(self, result):
Expand All @@ -373,18 +378,19 @@ def _get_result_time(self, result):

def _calculate_total_work(self, n, r, s):
work = 0
cumulative_r = r
for i in range(s+1):
work += int(n) * int(r)
n /= self._eta
n = int(np.ceil(n))
r *= self._eta
r = int(r)
r = int(min(r, self._max_t_attr - cumulative_r))
return work

def __repr__(self):
status = ", ".join([
"Max Size (n)={}".format(self._n),
"Milestone (r)={}".format(self._r),
"Milestone (r)={}".format(self._cumul_r),
"completed={:.1%}".format(self.completion_percentage())
])
counts = collections.Counter([t.status for t in self._all_trials])
Expand Down
8 changes: 4 additions & 4 deletions python/ray/tune/test/trial_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ def schedulerSetup(self, num_trials):

Bracketing is placed as follows:
(5, 81);
(8, 27) -> (3, 81);
(15, 9) -> (5, 27) -> (2, 81);
(34, 3) -> (12, 9) -> (4, 27) -> (2, 81);
(81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);"""
(8, 27) -> (3, 54);
(15, 9) -> (5, 27) -> (2, 45);
(34, 3) -> (12, 9) -> (4, 27) -> (2, 42);
(81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 41);"""
sched = HyperBandScheduler()
for i in range(num_trials):
t = Trial("__fake")
Expand Down