Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune][wip] Component Checkpointing #3709

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions python/ray/tune/schedulers/median_stopping_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,23 @@ def on_trial_result(self, trial_runner, trial, result):
value by step `t` is strictly worse than the median of the running
averages of all completed trials' objectives reported up to step `t`.
"""

if trial in self._stopped_trials:
trial_id = trial.trial_id
if trial_id in self._stopped_trials:
assert not self._hard_stop
return TrialScheduler.CONTINUE # fall back to FIFO

time = result[self._time_attr]
self._results[trial].append(result)
self._results[trial_id].append(result)
median_result = self._get_median_result(time)
best_result = self._best_result(trial)
best_result = self._best_result(trial_id)
if self._verbose:
logger.info("Trial {} best res={} vs median res={} at t={}".format(
trial, best_result, median_result, time))
if best_result < median_result and time > self._grace_period:
if self._verbose:
logger.info("MedianStoppingRule: "
"early stopping {}".format(trial))
self._stopped_trials.add(trial)
self._stopped_trials.add(trial_id)
if self._hard_stop:
return TrialScheduler.STOP
else:
Expand All @@ -85,36 +85,36 @@ def on_trial_result(self, trial_runner, trial, result):
return TrialScheduler.CONTINUE

def on_trial_complete(self, trial_runner, trial, result):
self._results[trial].append(result)
self._completed_trials.add(trial)
self._results[trial.trial_id].append(result)
self._completed_trials.add(trial.trial_id)

def on_trial_remove(self, trial_runner, trial):
"""Marks trial as completed if it is paused and has previously ran."""
if trial.status is Trial.PAUSED and trial in self._results:
self._completed_trials.add(trial)
if trial.status is Trial.PAUSED and trial.trial_id in self._results:
self._completed_trials.add(trial.trial_id)

def debug_string(self):
return "Using MedianStoppingRule: num_stopped={}.".format(
len(self._stopped_trials))

def _get_median_result(self, time):
scores = []
for trial in self._completed_trials:
scores.append(self._running_result(trial, time))
for trial_id in self._completed_trials:
scores.append(self._running_result(trial_id, time))
if len(scores) >= self._min_samples_required:
return np.median(scores)
else:
return float('-inf')

def _running_result(self, trial, t_max=float('inf')):
results = self._results[trial]
def _running_result(self, trial_id, t_max=float('inf')):
results = self._results[trial_id]
# TODO(ekl) we could do interpolation to be more precise, but for now
# assume len(results) is large and the time diffs are roughly equal
return np.mean([
r[self._reward_attr] for r in results
if r[self._time_attr] <= t_max
])

def _best_result(self, trial):
results = self._results[trial]
def _best_result(self, trial_id):
results = self._results[trial_id]
return max(r[self._reward_attr] for r in results)
14 changes: 14 additions & 0 deletions python/ray/tune/schedulers/trial_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from __future__ import division
from __future__ import print_function

import copy

from ray.tune.trial import Trial


Expand Down Expand Up @@ -64,6 +66,12 @@ def debug_string(self):

raise NotImplementedError

def __getstate__(self):
raise NotImplementedError

def __setstate__(self, state):
raise NotImplementedError


class FIFOScheduler(TrialScheduler):
"""Simple scheduler that just runs trials in submission order."""
Expand Down Expand Up @@ -96,3 +104,9 @@ def choose_trial_to_run(self, trial_runner):

def debug_string(self):
return "Using FIFO scheduling algorithm."

def __getstate__(self):
return copy.deepcopy(self.__dict__)

def __setstate__(self, state):
self.__dict__.update(state)