Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune][minor] Fixes #1383

Merged
merged 4 commits into from
Jan 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion python/ray/tune/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections import namedtuple
from ray.tune import TuneError
from ray.tune.logger import NoopLogger, UnifiedLogger
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
from ray.tune.registry import _default_registry, get_registry, TRAINABLE_CLASS


Expand Down Expand Up @@ -285,6 +285,14 @@ def restore_from_obj(self, obj):
print("Error restoring runner:", traceback.format_exc())
self.status = Trial.ERROR

def update_last_result(self, result, terminate=False):
    """Print `result`, store it as `last_result`, and pass it to the logger.

    Args:
        result: The latest training result for this trial. It must provide
            `_replace` (presumably a namedtuple such as TrainingResult --
            confirm against the caller).
        terminate: When true, the stored/logged result is a copy of
            `result` with `done=True` instead of the object passed in.
    """
    # Mark the result as final before it is printed, stored, or logged.
    latest = result._replace(done=True) if terminate else result

    print("TrainingResult for {}:".format(self))
    formatted = pretty_print(latest).replace("\n", "\n ")
    print(" {}".format(formatted))

    self.last_result = latest
    self.result_logger.on_result(self.last_result)

def _setup_runner(self):
self.status = Trial.RUNNING
trainable_cls = get_registry().get(
Expand Down
39 changes: 18 additions & 21 deletions python/ray/tune/trial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import traceback

from ray.tune import TuneError
from ray.tune.result import pretty_print
from ray.tune.trial import Trial, Resources
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler

Expand Down Expand Up @@ -157,35 +156,33 @@ def _launch_trial(self):
# have been lost

def _process_events(self):
[result_id], _ = ray.wait(list(self._running.keys()))
trial = self._running[result_id]
del self._running[result_id]
[result_id], _ = ray.wait(list(self._running))
trial = self._running.pop(result_id)
try:
result = ray.get(result_id)
trial.result_logger.on_result(result)
print("TrainingResult for {}:".format(trial))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
trial.last_result = result
self._total_time += result.time_this_iter_s

if trial.should_stop(result):
self._scheduler_alg.on_trial_complete(self, trial, result)
self._stop_trial(trial)
decision = TrialScheduler.STOP
else:
decision = self._scheduler_alg.on_trial_result(
self, trial, result)
if decision == TrialScheduler.CONTINUE:
if trial.should_checkpoint():
# TODO(rliaw): This is a blocking call
trial.checkpoint()
self._running[trial.train_remote()] = trial
elif decision == TrialScheduler.PAUSE:
self._pause_trial(trial)
elif decision == TrialScheduler.STOP:
self._stop_trial(trial)
else:
assert False, "Invalid scheduling decision: {}".format(
decision)
trial.update_last_result(
result, terminate=(decision == TrialScheduler.STOP))

if decision == TrialScheduler.CONTINUE:
if trial.should_checkpoint():
# TODO(rliaw): This is a blocking call
trial.checkpoint()
self._running[trial.train_remote()] = trial
elif decision == TrialScheduler.PAUSE:
self._pause_trial(trial)
elif decision == TrialScheduler.STOP:
self._stop_trial(trial)
else:
assert False, "Invalid scheduling decision: {}".format(
decision)
except Exception:
print("Error processing event:", traceback.format_exc())
if trial.status == Trial.RUNNING:
Expand Down
18 changes: 18 additions & 0 deletions test/trial_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,24 @@ def testCheckpointing(self):
self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
self.addCleanup(os.remove, path)

def testResultDone(self):
    """Tests that last_result is marked `done` after trial is complete.

    The trial stops at `training_iteration == 2`: after the first step it
    is expected to be RUNNING, after the second `done` must not be True,
    and after the third `done` must be True.
    """
    ray.init(num_cpus=1, num_gpus=1)
    runner = TrialRunner()
    trial_options = dict(
        stopping_criterion={"training_iteration": 2},
        resources=Resources(cpu=1, gpu=1),
    )
    runner.add_trial(Trial("__fake", **trial_options))
    all_trials = runner.get_trials()

    # First step: the trial should have been launched.
    runner.step()
    self.assertEqual(all_trials[0].status, Trial.RUNNING)

    # Second step: a result has arrived, but it is not the final one.
    runner.step()
    self.assertNotEqual(all_trials[0].last_result.done, True)

    # Third step: the stopping criterion is met and the result is final.
    runner.step()
    self.assertEqual(all_trials[0].last_result.done, True)

def testPauseThenResume(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
Expand Down