[tune] Asynchronous saves (#6912)
* Support asynchronous saves

* Fix merge issues

* Add test, fix existing tests

* More informative warning

* Lint, remove print statements

* Address comments, add checkpoint.is_resolved fn

* Add more detailed comments
ujvl authored Feb 9, 2020
1 parent 0648bd2 commit 98a07fe
Showing 10 changed files with 254 additions and 128 deletions.
9 changes: 3 additions & 6 deletions python/ray/function_manager.py
@@ -267,8 +267,7 @@ def _load_function_from_local(self, job_id, function_descriptor):
))
self._num_task_executions[job_id][function_id] = 0
except Exception:
logger.exception(
"Failed to load function {}.".format(function_name))
logger.exception("Failed to load function %s.", function_name)
raise Exception(
"Function {} failed to be loaded from local code.".format(
function_descriptor))
@@ -428,8 +427,7 @@ def _load_actor_from_local(self, job_id, function_descriptor):
else:
return actor_class
except Exception:
logger.exception(
"Failed to load actor_class %s.".format(class_name))
logger.exception("Failed to load actor_class %s.", class_name)
raise Exception(
"Actor {} failed to be imported from local code.".format(
class_name))
@@ -475,8 +473,7 @@ def _load_actor_class_from_gcs(self, job_id, function_descriptor):
with self.lock:
actor_class = pickle.loads(pickled_class)
except Exception:
logger.exception(
"Failed to load actor class %s.".format(class_name))
logger.exception("Failed to load actor class %s.", class_name)
# The actor class failed to be unpickled, create a fake actor
# class instead (just to produce error messages and to prevent
# the driver from hanging).
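These logging changes switch from eager str.format() calls to the logging module's lazy %-style arguments; in two of the three call sites the old code combined a "%s" placeholder with .format(), which never substitutes, so the name was logged literally as "%s". A minimal standalone sketch of the difference, using a made-up class name:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class_name = "MyActor"  # hypothetical name, for illustration only
try:
    raise ValueError("simulated import failure")
except Exception:
    # Old pattern: str.format() leaves the "%s" placeholder untouched,
    # so the record reads literally "Failed to load actor_class %s."
    logger.exception("Failed to load actor_class %s.".format(class_name))
    # Fixed pattern: the logging module substitutes the argument lazily,
    # only when the record is actually emitted.
    logger.exception("Failed to load actor_class %s.", class_name)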
15 changes: 14 additions & 1 deletion python/ray/tune/checkpoint_manager.py
@@ -13,7 +13,8 @@ class Checkpoint:
Attributes:
storage (str): Storage type.
value (str): If storage==MEMORY, it is a Python object.
If storage==PERSISTENT, it is a path to persistent storage.
If storage==PERSISTENT, it is a path to persistent storage,
or a future that will be resolved to such a path.
"""

MEMORY = "memory"
@@ -29,6 +30,18 @@ def from_object(value=None):
"""Creates a checkpoint from a Python object."""
return Checkpoint(Checkpoint.MEMORY, value)

@property
def is_ready(self):
"""Returns whether the checkpoint is ready to be used for restoration.
A PERSISTENT checkpoint is considered ready once its value is resolved
to an actual path. MEMORY checkpoints are always considered ready since
they are transient.
"""
if self.storage == Checkpoint.PERSISTENT:
return isinstance(self.value, str)
return self.storage == Checkpoint.MEMORY


class QueueItem:
def __init__(self, priority, value):
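For reference, a trimmed-down sketch of how the new is_ready property behaves. The class below mirrors only the fields used in this diff, and the checkpoint path is invented for the example:

class Checkpoint:
    """Minimal mirror of Tune's Checkpoint, just enough for is_ready."""

    MEMORY = "memory"
    PERSISTENT = "persistent"

    def __init__(self, storage, value, result=None):
        self.storage = storage
        self.value = value
        self.result = result

    @property
    def is_ready(self):
        # A PERSISTENT checkpoint starts out holding a future and only
        # becomes usable once that future has resolved to a path string.
        if self.storage == Checkpoint.PERSISTENT:
            return isinstance(self.value, str)
        return self.storage == Checkpoint.MEMORY


pending = Checkpoint(Checkpoint.PERSISTENT, object())  # stand-in for a Ray future
assert not pending.is_ready
pending.value = "/tmp/trial_0/checkpoint_10"  # resolved later by the trial runner
assert pending.is_ready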
25 changes: 6 additions & 19 deletions python/ray/tune/ray_trial_executor.py
@@ -554,7 +554,7 @@ def on_step_begin(self, trial_runner):
self._update_avail_resources()

def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
"""Saves the trial's state to a checkpoint.
"""Saves the trial's state to a checkpoint asynchronously.
Args:
trial (Trial): The trial to be saved.
@@ -567,29 +567,16 @@ def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
Checkpoint object, or None if an Exception occurs.
"""
result = result or trial.last_result

with self._change_working_directory(trial):
if storage == Checkpoint.MEMORY:
value = trial.runner.save_to_object.remote()
checkpoint = Checkpoint(storage, value, result)
else:
with warn_if_slow("save_checkpoint_to_storage"):
# TODO(ujvl): Make this asynchronous.
value = ray.get(trial.runner.save.remote())
checkpoint = Checkpoint(storage, value, result)
with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
try:
trial.on_checkpoint(checkpoint)
except Exception:
logger.exception("Trial %s: Error handling checkpoint %s",
trial, checkpoint.value)
return None
if profile.too_slow and trial.sync_on_checkpoint:
logger.warning(
"Consider turning off forced head-worker trial checkpoint "
"syncs by setting sync_on_checkpoint=False. Note that this "
"might result in faulty trial restoration for some worker "
"failure modes.")
else:
value = trial.runner.save.remote()
checkpoint = Checkpoint(storage, value, result)
trial.saving_to = checkpoint
self._running[value] = trial
return checkpoint

def restore(self, trial, checkpoint=None):
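With this change the executor submits trial.runner.save.remote() and returns immediately; the checkpoint's value is a Ray future that the trial runner resolves on a later step before calling trial.on_checkpoint. A rough standalone sketch of that hand-off, assuming a local Ray installation and using a toy trainable with a made-up checkpoint path rather than Tune's real classes:

import ray

ray.init(num_cpus=1)


@ray.remote
class ToyTrainable:
    def save(self):
        # In Tune this writes a checkpoint to disk and returns its path;
        # the path here is fabricated for the example.
        return "/tmp/toy_trial/checkpoint_1"


runner = ToyTrainable.remote()

# The save is dispatched without blocking: the "checkpoint value" is an
# object ID, not a path, so Checkpoint.is_ready would be False here.
value = runner.save.remote()
assert not isinstance(value, str)

# Later, the trial runner's event loop sees the finished save among its
# running futures, resolves it, and only then treats the checkpoint as
# persisted (roughly what happens before trial.on_checkpoint is called).
assert ray.get(value) == "/tmp/toy_trial/checkpoint_1"

ray.shutdown()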
81 changes: 47 additions & 34 deletions python/ray/tune/tests/test_cluster.py
@@ -146,14 +146,15 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
trial = Trial("__fake", **kwargs)
runner.add_trial(trial)

runner.step() # run 1
runner.step() # Start trial
assert trial.status == Trial.RUNNING
cluster.remove_node(node)
cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
assert ray.cluster_resources()["CPU"] == 1

for i in range(3):
# Process result (x2), process save, process result.
for _ in range(4):
runner.step()
assert trial.status == Trial.TERMINATED

@@ -237,39 +238,45 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
# Test recovery of trial that hasn't been checkpointed
t = Trial(trainable_id, **kwargs)
runner.add_trial(t)
runner.step() # start
runner.step() # 1 result
runner.step() # Start trial
runner.step() # Process result
assert t.last_result
node2 = cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
# TODO(ujvl): Node failure does not propagate until a step after it
# actually should. This is possibly a problem with `Cluster`.
runner.step()
runner.step() # Recovery step

# TODO(rliaw): This assertion is not critical but will not pass
# because checkpoint handling is messy and should be refactored
# rather than hotfixed.
# assert t.last_result is None, "Trial result not restored correctly."
for i in range(4):

# Process result (x2), process save, process result (x2), process save
for _ in range(6):
runner.step()

assert t.status == Trial.TERMINATED
assert t.status == Trial.TERMINATED, runner.debug_string()

# Test recovery of trial that has been checkpointed
t2 = Trial(trainable_id, **kwargs)
runner.add_trial(t2)
runner.step() # start
runner.step() # 1 result
runner.step() # 2 result and checkpoint
# Start trial, process result (x2), process save
for _ in range(4):
runner.step()
assert t2.has_checkpoint()
node3 = cluster.add_node(num_cpus=1)
cluster.remove_node(node2)
cluster.wait_for_nodes()
runner.step() # 3 result + start and fail 4 result
runner.step() # Recovery step
runner.step() # Process recovery
runner.step() # result
runner.step() # Process result 3 + start and fail 4 result
runner.step() # Dispatch restore
runner.step() # Process restore
runner.step() # Process result 5
if t2.status != Trial.TERMINATED:
runner.step()
runner.step() # Process result 6, dispatch save
runner.step() # Process save
assert t2.status == Trial.TERMINATED, runner.debug_string()

# Test recovery of trial that won't be checkpointed
@@ -282,8 +289,8 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
}
t3 = Trial(trainable_id, **kwargs)
runner.add_trial(t3)
runner.step() # start
runner.step() # 1 result
runner.step() # Start trial
runner.step() # Process result 1
cluster.add_node(num_cpus=1)
cluster.remove_node(node3)
cluster.wait_for_nodes()
@@ -318,13 +325,16 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
for t in trials:
runner.add_trial(t)

runner.step() # start
runner.step() # 1 result
runner.step() # Start trial
runner.step() # Process result, dispatch save
runner.step() # Process save

cluster.remove_node(node)
cluster.wait_for_nodes()
runner.step()
assert all(t.status == Trial.PENDING for t in trials)
runner.step() # Process result, dispatch save
runner.step() # Process save (detect error), requeue trial
assert all(
t.status == Trial.PENDING for t in trials), runner.debug_string()

with pytest.raises(TuneError):
runner.step()
@@ -374,19 +384,21 @@ def mock_find_dir_fn(checkpoint_path):
# Test recovery of trial that has been checkpointed
t1 = Trial(trainable_id, **kwargs)
runner.add_trial(t1)
runner.step() # start
runner.step() # 1 result
runner.step() # 2 result and checkpoint

# Start trial, process result (x2), process save
for _ in range(4):
runner.step()
assert t1.has_checkpoint()

cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1.checkpoint.value))

runner.step() # collect result 3, kick off + fail result 4
runner.step() # Recovery step
runner.step() # Process Recovery + step 4
for i in range(3):
runner.step() # Collect result 3, kick off + fail result 4
runner.step() # Dispatch restore
runner.step() # Process restore + step 4
for _ in range(3):
if t1.status != Trial.TERMINATED:
runner.step()
assert t1.status == Trial.TERMINATED, runner.debug_string()
@@ -414,9 +426,9 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
for t in trials:
runner.add_trial(t)

runner.step() # start
runner.step() # start2
runner.step() # step
# Start trial (x2), process result, process save
for _ in range(4):
runner.step()
assert all(t.status == Trial.RUNNING for t in runner.get_trials())
runner.checkpoint()

Expand All @@ -425,11 +437,12 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):

cluster = _start_new_cluster()
runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath)
runner.step() # start
runner.step() # process restore
runner.step() # start2
# Start trial, process restore, process result, process save
for _ in range(4):
runner.step()

for i in range(3):
# Start trial 2, process result, process save, process result, process save
for i in range(5):
runner.step()

with pytest.raises(TuneError):
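The step counts in these tests grow because a checkpoint now spans two event-loop iterations: one runner.step() dispatches the remote save and a later one processes its result. A hypothetical helper (not part of this PR) that tests could use instead of hard-coding the counts:

def step_until(runner, trial, status, max_steps=10):
    """Drive the TrialRunner event loop until `trial` reaches `status`,
    rather than counting how many steps the asynchronous save pipeline
    happens to take. Illustration only; the tests above do not use it."""
    for _ in range(max_steps):
        if trial.status == status:
            return
        runner.step()
    raise AssertionError(runner.debug_string())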
27 changes: 26 additions & 1 deletion python/ray/tune/tests/test_ray_trial_executor.py
@@ -30,11 +30,25 @@ def testStartStop(self):
self.assertEqual(1, len(running))
self.trial_executor.stop_trial(trial)

def testAsyncSave(self):
"""Tests that saved checkpoint value not immediately set."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(checkpoint, trial.saving_to)
self.assertEqual(trial.checkpoint.value, None)
self.process_trial_save(trial)
self.assertEqual(checkpoint, trial.checkpoint)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)

def testSaveRestore(self):
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.process_trial_save(trial)
self.trial_executor.restore(trial)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
@@ -59,6 +73,8 @@ def testSavePauseResumeRestore(self):
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(Trial.RUNNING, trial.status)
self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)
# Process save result (simulates trial runner)
self.process_trial_save(trial)
# Pause
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
@@ -125,11 +141,20 @@ def reset_config(self, config):
self.assertEqual(trial.experiment_tag, "modified_mock")
self.assertEqual(Trial.RUNNING, trial.status)

def generate_trials(self, spec, name):
@staticmethod
def generate_trials(spec, name):
suggester = BasicVariantGenerator()
suggester.add_configurations({name: spec})
return suggester.next_trials()

@staticmethod
def process_trial_save(trial):
"""Simulates trial runner save."""
checkpoint = trial.saving_to
checkpoint_value = ray.get(checkpoint.value)
checkpoint.value = checkpoint_value
trial.on_checkpoint(checkpoint)


class RayExecutorQueueTest(unittest.TestCase):
def setUp(self):