[train/tune] Remove save_to_object/restore_from_object #38757

Merged
merged 4 commits on Aug 23, 2023
Changes from 1 commit
1 change: 0 additions & 1 deletion python/ray/tune/examples/pbt_convnet_example.py
@@ -72,7 +72,6 @@ def reset_config(self, new_config):

# check if PytorchTrainble will save/restore correctly before execution
validate_save_restore(PytorchTrainable)
validate_save_restore(PytorchTrainable, use_object_store=True)

# __pbt_begin__
scheduler = PopulationBasedTraining(
12 changes: 7 additions & 5 deletions python/ray/tune/execution/tune_controller.py
@@ -1919,9 +1919,11 @@ def _schedule_trial_save(
return

if storage == CheckpointStorage.MEMORY:
# This is now technically a persistent checkpoint, but
# we don't resolve it. Instead, we register it directly.
future = self._schedule_trial_task(
trial=trial,
method_name="save_to_object",
method_name="save",
on_result=None,
on_error=self._trial_task_failure,
_return_future=True,
@@ -2080,7 +2082,7 @@ def _schedule_trial_restore(self, trial: Trial) -> bool:
kwargs = {}

if checkpoint.storage_mode == CheckpointStorage.MEMORY:
method_name = "restore_from_object"
method_name = "restore"
args = (checkpoint.dir_or_data,)
elif (
trial.uses_cloud_checkpointing
@@ -2099,10 +2101,10 @@ def _schedule_trial_restore(self, trial: Trial) -> bool:
}
elif trial.sync_on_checkpoint:
checkpoint_path = TrainableUtil.find_checkpoint_dir(checkpoint.dir_or_data)
obj = Checkpoint.from_directory(checkpoint_path).to_bytes()
checkpoint = Checkpoint.from_directory(checkpoint_path)

method_name = "restore_from_object"
args = (obj,)
method_name = "restore"
args = (checkpoint,)
else:
raise _AbortTrialExecution(
"Pass in `sync_on_checkpoint=True` for driver-based trial"
101 changes: 9 additions & 92 deletions python/ray/tune/tests/test_function_api.py
@@ -97,35 +97,7 @@ def train(config, checkpoint_dir=None):
new_trainable.stop()
assert result[TRAINING_ITERATION] == 10

def testCheckpointReuseObject(self):
"""Test that repeated save/restore never reuses same checkpoint dir."""

def train(config, checkpoint_dir=None):
if checkpoint_dir:
count = sum(
"checkpoint-" in path for path in os.listdir(checkpoint_dir)
)
assert count == 1, os.listdir(checkpoint_dir)

for step in range(20):
with tune.checkpoint_dir(step=step) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint-{}".format(step))
open(path, "a").close()
tune.report(test=step)

wrapped = wrap_function(train)
checkpoint = None
for i in range(5):
new_trainable = wrapped(logger_creator=self.logger_creator)
if checkpoint:
new_trainable.restore_from_object(checkpoint)
for i in range(2):
result = new_trainable.train()
checkpoint = new_trainable.save_to_object()
new_trainable.stop()
self.assertTrue(result[TRAINING_ITERATION] == 10)

def testCheckpointReuseObjectWithoutTraining(self):
def testCheckpointReuseWithoutTraining(self):
"""Test that repeated save/restore never reuses same checkpoint dir."""

def train(config, checkpoint_dir=None):
@@ -145,15 +117,15 @@ def train(config, checkpoint_dir=None):
new_trainable = wrapped(logger_creator=self.logger_creator)
for i in range(2):
result = new_trainable.train()
checkpoint = new_trainable.save_to_object()
checkpoint = new_trainable.save()
new_trainable.stop()

new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore_from_object(checkpoint)
new_trainable2.restore(checkpoint)
new_trainable2.stop()

new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore_from_object(checkpoint)
new_trainable2.restore(checkpoint)
result = new_trainable2.train()
new_trainable2.stop()
self.assertTrue(result[TRAINING_ITERATION] == 3)
@@ -203,23 +175,6 @@ def train(config, checkpoint_dir=None):
new_trainable.stop()
self.assertTrue(result[TRAINING_ITERATION] == 1)

def testMultipleNullMemoryCheckpoints(self):
def train(config, checkpoint_dir=None):
assert not checkpoint_dir
for step in range(10):
tune.report(test=step)

wrapped = wrap_function(train)
checkpoint = None
for i in range(5):
new_trainable = wrapped(logger_creator=self.logger_creator)
if checkpoint:
new_trainable.restore_from_object(checkpoint)
result = new_trainable.train()
checkpoint = new_trainable.save_to_object()
new_trainable.stop()
assert result[TRAINING_ITERATION] == 1

def testFunctionNoCheckpointing(self):
def train(config, checkpoint_dir=None):
if checkpoint_dir:
@@ -259,8 +214,8 @@ def train(config, checkpoint_dir=None):

new_trainable = wrapped(logger_creator=self.logger_creator)
new_trainable.train()
checkpoint_obj = new_trainable.save_to_object()
new_trainable.restore_from_object(checkpoint_obj)
checkpoint_obj = new_trainable.save()
new_trainable.restore(checkpoint_obj)
checkpoint = new_trainable.save()

new_trainable.stop()
@@ -287,16 +242,14 @@ def train(config, checkpoint_dir=None):
new_trainable = wrapped(logger_creator=self.logger_creator)
new_trainable.train()
new_trainable.train()
checkpoint_obj = new_trainable.save_to_object()
checkpoint_obj = new_trainable.save()
new_trainable.stop()

new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore_from_object(checkpoint_obj)
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
checkpoint_obj = new_trainable2.save_to_object()
new_trainable2.restore(checkpoint_obj)
checkpoint_obj = new_trainable2.save()
new_trainable2.train()
result = new_trainable2.train()
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
new_trainable2.stop()
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
assert result[TRAINING_ITERATION] == 4
@@ -602,42 +555,6 @@ def train(config):
self.assertEqual(trial_2.last_result["m"], 8 + 9)


def test_restore_from_object_delete(tmp_path):
"""Test that temporary checkpoint directories are deleted after restoring.

`FunctionTrainable.restore_from_object` creates a temporary checkpoint directory.
This directory is kept around as we don't control how the user interacts with
the checkpoint - they might load it several times, or no time at all.

Once a new checkpoint is tracked in the status reporter, there is no need to keep
the temporary object around anymore. This test asserts that the temporary
checkpoint directories are then deleted.
"""
# Create 2 checkpoints
cp_1 = TrainableUtil.make_checkpoint_dir(str(tmp_path), index=1, override=True)
cp_2 = TrainableUtil.make_checkpoint_dir(str(tmp_path), index=2, override=True)

# Instantiate function trainable
trainable = FunctionTrainable()
trainable._logdir = str(tmp_path)
trainable._status_reporter.set_checkpoint(cp_1)

# Save to object and restore. This will create a temporary checkpoint directory.
cp_obj = trainable.save_to_object()
trainable.restore_from_object(cp_obj)

# Assert there is at least one `checkpoint_tmpxxxxx` directory in the logdir
assert any(path.name.startswith("checkpoint_tmp") for path in tmp_path.iterdir())

# Track a new checkpoint. This should delete the temporary checkpoint directory.
trainable._status_reporter.set_checkpoint(cp_2)

# Directory should have been deleted
assert not any(
path.name.startswith("checkpoint_tmp") for path in tmp_path.iterdir()
)


if __name__ == "__main__":
import pytest

66 changes: 0 additions & 66 deletions python/ray/tune/tests/test_trainable.py
@@ -129,25 +129,6 @@ class trainables.
ray.get(restoring_future)


@pytest.mark.parametrize("return_type", ["object", "root", "subdir", "checkpoint"])
def test_save_load_checkpoint_object_class(ray_start_2_cpus, return_type):
"""Assert that restoring from a Trainable.save_to_object() future works with
class trainables.

Needs Ray cluster so we get actual futures.
"""
trainable = ray.remote(SavingTrainable).remote(return_type=return_type)

saving_future = trainable.save_to_object.remote()

# Check for errors
ray.get(saving_future)

restoring_future = trainable.restore_from_object.remote(saving_future)

ray.get(restoring_future)


@pytest.mark.parametrize(
"fn_trainable", [function_trainable_dict, function_trainable_directory]
)
@@ -171,53 +152,6 @@ def test_save_load_checkpoint_path_fn(ray_start_2_cpus, fn_trainable):
ray.get(restoring_future)


@pytest.mark.parametrize(
"fn_trainable", [function_trainable_dict, function_trainable_directory]
)
def test_save_load_checkpoint_object_fn(ray_start_2_cpus, fn_trainable):
"""Assert that restoring from a Trainable.save_to_object() future works with
function trainables.

Needs Ray cluster so we get actual futures.
"""
trainable_cls = wrap_function(fn_trainable)
trainable = ray.remote(trainable_cls).remote()
ray.get(trainable.train.remote())

saving_future = trainable.save_to_object.remote()

# Check for errors
ray.get(saving_future)

restoring_future = trainable.restore_from_object.remote(saving_future)

ray.get(restoring_future)


def test_checkpoint_object_no_sync(tmpdir):
"""Asserts that save_to_object() and restore_from_object() do not sync up/down"""
trainable = SavingTrainable(
"object", remote_checkpoint_dir="memory:///test/location"
)

# Save checkpoint
trainable.save()

check_dir = tmpdir / "check_save"
download_from_uri(uri="memory:///test/location", local_path=str(check_dir))
assert os.listdir(str(check_dir)) == ["checkpoint_000000"]

# Save to object
obj = trainable.save_to_object()

check_dir = tmpdir / "check_save_obj"
download_from_uri(uri="memory:///test/location", local_path=str(check_dir))
assert os.listdir(str(check_dir)) == ["checkpoint_000000"]

# Restore from object
trainable.restore_from_object(obj)


@pytest.mark.parametrize("hanging", [True, False])
def test_sync_timeout(tmpdir, monkeypatch, hanging):
monkeypatch.setenv("TUNE_CHECKPOINT_CLOUD_RETRY_WAIT_TIME_S", "0")
4 changes: 0 additions & 4 deletions python/ray/tune/tests/test_tune_restore.py
@@ -539,23 +539,19 @@ def testPBTKeras(self):

cifar10.load_data()
validate_save_restore(Cifar10Model)
validate_save_restore(Cifar10Model, use_object_store=True)

def testPyTorchMNIST(self):
from ray.tune.examples.mnist_pytorch_trainable import TrainMNIST
from torchvision import datasets

datasets.MNIST("~/data", train=True, download=True)
validate_save_restore(TrainMNIST)
validate_save_restore(TrainMNIST, use_object_store=True)

def testHyperbandExample(self):
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)

def testAsyncHyperbandExample(self):
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)


class AutoInitTest(unittest.TestCase):
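
With the `use_object_store=True` variant gone from the tests above, a save/restore sanity check reduces to a single `validate_save_restore` call. A short sketch under assumptions: `MockTrainable` here is a hypothetical stand-in, and the `ray.tune.utils` import path is assumed for this Ray version:

```python
# Sketch only: validate_save_restore now takes just the Trainable class; the
# use_object_store=True call no longer exists. Roughly, it trains a remote
# copy of the Trainable, saves, restores into a second copy, and checks that
# the training iteration continues where it left off.
import json
import os

from ray import tune
from ray.tune.utils import validate_save_restore


class MockTrainable(tune.Trainable):
    def setup(self, config):
        self.counter = 0

    def step(self):
        self.counter += 1
        return {"counter": self.counter}

    def save_checkpoint(self, checkpoint_dir):
        with open(os.path.join(checkpoint_dir, "mock.json"), "w") as f:
            json.dump({"counter": self.counter}, f)
        return checkpoint_dir

    def load_checkpoint(self, checkpoint_dir):
        with open(os.path.join(checkpoint_dir, "mock.json")) as f:
            self.counter = json.load(f)["counter"]


if __name__ == "__main__":
    validate_save_restore(MockTrainable)
```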
1 change: 0 additions & 1 deletion python/ray/tune/tests/test_tune_save_restore.py
@@ -174,7 +174,6 @@ def load_checkpoint(self, checkpoint_dir):
return checkpoint_dir

validate_save_restore(MockTrainable)
validate_save_restore(MockTrainable, use_object_store=True)


if __name__ == "__main__":
14 changes: 0 additions & 14 deletions python/ray/tune/trainable/function_trainable.py
@@ -587,11 +587,6 @@ def _create_checkpoint_dir(
) -> Optional[str]:
return None

def save_to_object(self):
checkpoint_path = self.save()
checkpoint = Checkpoint.from_directory(checkpoint_path)
return checkpoint.to_bytes()

def load_checkpoint(self, checkpoint):
if _use_storage_context():
checkpoint_result = checkpoint
@@ -619,15 +614,6 @@ def _restore_from_checkpoint_obj(self, checkpoint: Checkpoint):
checkpoint.to_directory(self.temp_checkpoint_dir)
self.restore(self.temp_checkpoint_dir)

def restore_from_object(self, obj):
self.temp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
self.logdir
)
checkpoint = Checkpoint.from_bytes(obj)
checkpoint.to_directory(self.temp_checkpoint_dir)

self.restore(self.temp_checkpoint_dir)

def cleanup(self):
if _use_storage_context():
session = get_session()
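
For callers that depended on the removed `save_to_object`/`restore_from_object` helpers, the same directory-to-bytes round trip can be rebuilt from the public `Checkpoint` calls the helpers wrapped (visible in the removed code above). A hedged sketch, with the `Checkpoint` import path assumed for this Ray version and a fabricated directory standing in for the output of `Trainable.save()`:

```python
# Sketch of the byte-string round trip the removed helpers performed, using
# the same Checkpoint API they wrapped internally (from_directory/to_bytes on
# save, from_bytes/to_directory on restore). Import path assumed for the
# AIR-era Checkpoint class.
import os
import tempfile

from ray.air.checkpoint import Checkpoint

# Pretend this directory came from `trainable.save()`; here we fabricate it.
src_dir = tempfile.mkdtemp(prefix="checkpoint_src")
with open(os.path.join(src_dir, "weights.txt"), "w") as f:
    f.write("42")

# save_to_object() equivalent: pack the checkpoint directory into bytes.
blob = Checkpoint.from_directory(src_dir).to_bytes()

# restore_from_object() equivalent: unpack the bytes into a fresh directory,
# which can then be handed to `Trainable.restore()`.
dst_dir = tempfile.mkdtemp(prefix="checkpoint_tmp")
Checkpoint.from_bytes(blob).to_directory(dst_dir)

with open(os.path.join(dst_dir, "weights.txt")) as f:
    assert f.read() == "42"
```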