[train/tune] Remove save_to_object/restore_from_object (ray-project#38757)

This PR removes `Trainable.save_to_object` and `Trainable.restore_from_object`.

These methods were technically public (they were part of a public class), but in practice they were only used internally. That's why we believe it's fine to remove them without a full deprecation cycle.

The methods stem from a time when a central storage location was not guaranteed. Back then, we needed to ship checkpoints between nodes, e.g. from the head node to the worker node where a trial is executed, or between worker nodes when a PBT trial exploits another trial.

These assumptions are now superseded by a central storage location: users are required to save checkpoints to a location that all nodes can access, so we only need to pass around a thin metadata wrapper instead of the full serialized checkpoint data. Shipping the full data was always infeasible for large checkpoints anyway.
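
For illustration, a minimal sketch of the difference (not code from this PR): the removed methods wrapped `save()`/`restore()` in a bytes round-trip through `Checkpoint` (see the `function_trainable.py` hunk below), while the new flow only hands around a reference to a shared location. The sketch assumes the `ray.air.Checkpoint` API as it existed at the time of this commit; the file name and S3 URI are placeholders.

```python
import tempfile

from ray.air import Checkpoint

# --- Old pattern: serialize the whole checkpoint and ship the bytes ---
local_dir = tempfile.mkdtemp()
with open(f"{local_dir}/weights.txt", "w") as f:  # stand-in for real checkpoint data
    f.write("...")

blob = Checkpoint.from_directory(local_dir).to_bytes()        # full payload travels
Checkpoint.from_bytes(blob).to_directory(tempfile.mkdtemp())  # rehydrated on the receiving node

# --- New pattern: every node reads/writes the same shared location ---
# Only a lightweight reference (a URI or path) is passed between processes;
# the checkpoint data itself stays in shared storage such as NFS or S3.
shared_uri = "s3://my-bucket/experiment/checkpoint_000001"  # placeholder URI
# Checkpoint.from_directory(local_dir).to_uri(shared_uri)        # writer side
# restored_dir = Checkpoint.from_uri(shared_uri).to_directory()  # reader side
```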

Signed-off-by: Kai Fricke <kai@anyscale.com>
krfricke authored Aug 23, 2023
1 parent 548f810 commit 2d0544b
Showing 12 changed files with 89 additions and 296 deletions.
1 change: 0 additions & 1 deletion python/ray/tune/examples/pbt_convnet_example.py
@@ -72,7 +72,6 @@ def reset_config(self, new_config):

# check if PytorchTrainable will save/restore correctly before execution
validate_save_restore(PytorchTrainable)
validate_save_restore(PytorchTrainable, use_object_store=True)

# __pbt_begin__
scheduler = PopulationBasedTraining(
10 changes: 6 additions & 4 deletions python/ray/tune/execution/tune_controller.py
@@ -1919,9 +1919,11 @@ def _schedule_trial_save(
return

if storage == CheckpointStorage.MEMORY:
# This is now technically a persistent checkpoint, but
# we don't resolve it. Instead, we register it directly.
future = self._schedule_trial_task(
trial=trial,
method_name="save_to_object",
method_name="save",
on_result=None,
on_error=self._trial_task_failure,
_return_future=True,
@@ -2080,7 +2082,7 @@ def _schedule_trial_restore(self, trial: Trial) -> bool:
kwargs = {}

if checkpoint.storage_mode == CheckpointStorage.MEMORY:
method_name = "restore_from_object"
method_name = "restore"
args = (checkpoint.dir_or_data,)
elif (
trial.uses_cloud_checkpointing
@@ -2099,9 +2101,9 @@ def _schedule_trial_restore(self, trial: Trial) -> bool:
}
elif trial.sync_on_checkpoint:
checkpoint_path = TrainableUtil.find_checkpoint_dir(checkpoint.dir_or_data)
obj = Checkpoint.from_directory(checkpoint_path).to_bytes()
obj = Checkpoint.from_directory(checkpoint_path)

method_name = "restore_from_object"
method_name = "restore"
args = (obj,)
else:
raise _AbortTrialExecution(
102 changes: 9 additions & 93 deletions python/ray/tune/tests/test_function_api.py
@@ -18,7 +18,6 @@
with_parameters,
wrap_function,
FuncCheckpointUtil,
FunctionTrainable,
)
from ray.tune.result import DEFAULT_METRIC
from ray.tune.schedulers import ResourceChangingScheduler
@@ -97,35 +96,7 @@ def train(config, checkpoint_dir=None):
new_trainable.stop()
assert result[TRAINING_ITERATION] == 10

def testCheckpointReuseObject(self):
"""Test that repeated save/restore never reuses same checkpoint dir."""

def train(config, checkpoint_dir=None):
if checkpoint_dir:
count = sum(
"checkpoint-" in path for path in os.listdir(checkpoint_dir)
)
assert count == 1, os.listdir(checkpoint_dir)

for step in range(20):
with tune.checkpoint_dir(step=step) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint-{}".format(step))
open(path, "a").close()
tune.report(test=step)

wrapped = wrap_function(train)
checkpoint = None
for i in range(5):
new_trainable = wrapped(logger_creator=self.logger_creator)
if checkpoint:
new_trainable.restore_from_object(checkpoint)
for i in range(2):
result = new_trainable.train()
checkpoint = new_trainable.save_to_object()
new_trainable.stop()
self.assertTrue(result[TRAINING_ITERATION] == 10)

def testCheckpointReuseObjectWithoutTraining(self):
def testCheckpointReuseWithoutTraining(self):
"""Test that repeated save/restore never reuses same checkpoint dir."""

def train(config, checkpoint_dir=None):
@@ -145,15 +116,15 @@ def train(config, checkpoint_dir=None):
new_trainable = wrapped(logger_creator=self.logger_creator)
for i in range(2):
result = new_trainable.train()
checkpoint = new_trainable.save_to_object()
checkpoint = new_trainable.save()
new_trainable.stop()

new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore_from_object(checkpoint)
new_trainable2.restore(checkpoint)
new_trainable2.stop()

new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore_from_object(checkpoint)
new_trainable2.restore(checkpoint)
result = new_trainable2.train()
new_trainable2.stop()
self.assertTrue(result[TRAINING_ITERATION] == 3)
@@ -203,23 +174,6 @@ def train(config, checkpoint_dir=None):
new_trainable.stop()
self.assertTrue(result[TRAINING_ITERATION] == 1)

def testMultipleNullMemoryCheckpoints(self):
def train(config, checkpoint_dir=None):
assert not checkpoint_dir
for step in range(10):
tune.report(test=step)

wrapped = wrap_function(train)
checkpoint = None
for i in range(5):
new_trainable = wrapped(logger_creator=self.logger_creator)
if checkpoint:
new_trainable.restore_from_object(checkpoint)
result = new_trainable.train()
checkpoint = new_trainable.save_to_object()
new_trainable.stop()
assert result[TRAINING_ITERATION] == 1

def testFunctionNoCheckpointing(self):
def train(config, checkpoint_dir=None):
if checkpoint_dir:
@@ -259,8 +213,8 @@ def train(config, checkpoint_dir=None):

new_trainable = wrapped(logger_creator=self.logger_creator)
new_trainable.train()
checkpoint_obj = new_trainable.save_to_object()
new_trainable.restore_from_object(checkpoint_obj)
checkpoint_obj = new_trainable.save()
new_trainable.restore(checkpoint_obj)
checkpoint = new_trainable.save()

new_trainable.stop()
@@ -287,16 +241,14 @@ def train(config, checkpoint_dir=None):
new_trainable = wrapped(logger_creator=self.logger_creator)
new_trainable.train()
new_trainable.train()
checkpoint_obj = new_trainable.save_to_object()
checkpoint_obj = new_trainable.save()
new_trainable.stop()

new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore_from_object(checkpoint_obj)
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
checkpoint_obj = new_trainable2.save_to_object()
new_trainable2.restore(checkpoint_obj)
checkpoint_obj = new_trainable2.save()
new_trainable2.train()
result = new_trainable2.train()
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
new_trainable2.stop()
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
assert result[TRAINING_ITERATION] == 4
@@ -602,42 +554,6 @@ def train(config):
self.assertEqual(trial_2.last_result["m"], 8 + 9)


def test_restore_from_object_delete(tmp_path):
"""Test that temporary checkpoint directories are deleted after restoring.
`FunctionTrainable.restore_from_object` creates a temporary checkpoint directory.
This directory is kept around as we don't control how the user interacts with
the checkpoint - they might load it several times, or no time at all.
Once a new checkpoint is tracked in the status reporter, there is no need to keep
the temporary object around anymore. This test asserts that the temporary
checkpoint directories are then deleted.
"""
# Create 2 checkpoints
cp_1 = TrainableUtil.make_checkpoint_dir(str(tmp_path), index=1, override=True)
cp_2 = TrainableUtil.make_checkpoint_dir(str(tmp_path), index=2, override=True)

# Instantiate function trainable
trainable = FunctionTrainable()
trainable._logdir = str(tmp_path)
trainable._status_reporter.set_checkpoint(cp_1)

# Save to object and restore. This will create a temporary checkpoint directory.
cp_obj = trainable.save_to_object()
trainable.restore_from_object(cp_obj)

# Assert there is at least one `checkpoint_tmpxxxxx` directory in the logdir
assert any(path.name.startswith("checkpoint_tmp") for path in tmp_path.iterdir())

# Track a new checkpoint. This should delete the temporary checkpoint directory.
trainable._status_reporter.set_checkpoint(cp_2)

# Directory should have been deleted
assert not any(
path.name.startswith("checkpoint_tmp") for path in tmp_path.iterdir()
)


if __name__ == "__main__":
import pytest

6 changes: 1 addition & 5 deletions python/ray/tune/tests/test_syncer_callback.py
@@ -539,11 +539,7 @@ def train_fn(config):
config={}, logger_creator=logger_creator
)

ray.get(trainable.save_to_object.remote())

# Temporary directory exists
assert_file(True, tmp_source, "checkpoint_-00001/" + NULL_MARKER)
assert_file(True, tmp_source, "checkpoint_-00001")
ray.get(trainable.save.remote())

# Create some bogus test directories for testing
os.mkdir(os.path.join(tmp_source, "checkpoint_tmp123"))
66 changes: 0 additions & 66 deletions python/ray/tune/tests/test_trainable.py
@@ -129,25 +129,6 @@ class trainables.
ray.get(restoring_future)


@pytest.mark.parametrize("return_type", ["object", "root", "subdir", "checkpoint"])
def test_save_load_checkpoint_object_class(ray_start_2_cpus, return_type):
"""Assert that restoring from a Trainable.save_to_object() future works with
class trainables.
Needs Ray cluster so we get actual futures.
"""
trainable = ray.remote(SavingTrainable).remote(return_type=return_type)

saving_future = trainable.save_to_object.remote()

# Check for errors
ray.get(saving_future)

restoring_future = trainable.restore_from_object.remote(saving_future)

ray.get(restoring_future)


@pytest.mark.parametrize(
"fn_trainable", [function_trainable_dict, function_trainable_directory]
)
@@ -171,53 +152,6 @@ def test_save_load_checkpoint_path_fn(ray_start_2_cpus, fn_trainable):
ray.get(restoring_future)


@pytest.mark.parametrize(
"fn_trainable", [function_trainable_dict, function_trainable_directory]
)
def test_save_load_checkpoint_object_fn(ray_start_2_cpus, fn_trainable):
"""Assert that restoring from a Trainable.save_to_object() future works with
function trainables.
Needs Ray cluster so we get actual futures.
"""
trainable_cls = wrap_function(fn_trainable)
trainable = ray.remote(trainable_cls).remote()
ray.get(trainable.train.remote())

saving_future = trainable.save_to_object.remote()

# Check for errors
ray.get(saving_future)

restoring_future = trainable.restore_from_object.remote(saving_future)

ray.get(restoring_future)


def test_checkpoint_object_no_sync(tmpdir):
"""Asserts that save_to_object() and restore_from_object() do not sync up/down"""
trainable = SavingTrainable(
"object", remote_checkpoint_dir="memory:///test/location"
)

# Save checkpoint
trainable.save()

check_dir = tmpdir / "check_save"
download_from_uri(uri="memory:///test/location", local_path=str(check_dir))
assert os.listdir(str(check_dir)) == ["checkpoint_000000"]

# Save to object
obj = trainable.save_to_object()

check_dir = tmpdir / "check_save_obj"
download_from_uri(uri="memory:///test/location", local_path=str(check_dir))
assert os.listdir(str(check_dir)) == ["checkpoint_000000"]

# Restore from object
trainable.restore_from_object(obj)


@pytest.mark.parametrize("hanging", [True, False])
def test_sync_timeout(tmpdir, monkeypatch, hanging):
monkeypatch.setenv("TUNE_CHECKPOINT_CLOUD_RETRY_WAIT_TIME_S", "0")
4 changes: 0 additions & 4 deletions python/ray/tune/tests/test_tune_restore.py
@@ -539,23 +539,19 @@ def testPBTKeras(self):

cifar10.load_data()
validate_save_restore(Cifar10Model)
validate_save_restore(Cifar10Model, use_object_store=True)

def testPyTorchMNIST(self):
from ray.tune.examples.mnist_pytorch_trainable import TrainMNIST
from torchvision import datasets

datasets.MNIST("~/data", train=True, download=True)
validate_save_restore(TrainMNIST)
validate_save_restore(TrainMNIST, use_object_store=True)

def testHyperbandExample(self):
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)

def testAsyncHyperbandExample(self):
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)


class AutoInitTest(unittest.TestCase):
1 change: 0 additions & 1 deletion python/ray/tune/tests/test_tune_save_restore.py
@@ -174,7 +174,6 @@ def load_checkpoint(self, checkpoint_dir):
return checkpoint_dir

validate_save_restore(MockTrainable)
validate_save_restore(MockTrainable, use_object_store=True)


if __name__ == "__main__":
14 changes: 0 additions & 14 deletions python/ray/tune/trainable/function_trainable.py
@@ -587,11 +587,6 @@ def _create_checkpoint_dir(
) -> Optional[str]:
return None

def save_to_object(self):
checkpoint_path = self.save()
checkpoint = Checkpoint.from_directory(checkpoint_path)
return checkpoint.to_bytes()

def load_checkpoint(self, checkpoint):
if _use_storage_context():
checkpoint_result = checkpoint
@@ -619,15 +614,6 @@ def _restore_from_checkpoint_obj(self, checkpoint: Checkpoint):
checkpoint.to_directory(self.temp_checkpoint_dir)
self.restore(self.temp_checkpoint_dir)

def restore_from_object(self, obj):
self.temp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
self.logdir
)
checkpoint = Checkpoint.from_bytes(obj)
checkpoint.to_directory(self.temp_checkpoint_dir)

self.restore(self.temp_checkpoint_dir)

def cleanup(self):
if _use_storage_context():
session = get_session()