[train] Storage refactor: Support PBT and BOHB #38736

Merged: 51 commits (Aug 25, 2023)
Changes from 1 commit
Commits (51)
fa992fc - Adjust save_checkpoint API (Aug 17, 2023)
f0f0f41 - more (Aug 17, 2023)
77be920 - fix test (Aug 17, 2023)
c2c6073 - Merge remote-tracking branch 'upstream/master' into tune/storage-pbt (Aug 17, 2023)
711eefa - Update typehints (Aug 17, 2023)
8bcc82e - Merge remote-tracking branch 'upstream/master' into tune/storage-pbt (Aug 17, 2023)
80ec41e - Merge branch 'master' into tune/storage-pbt (Aug 21, 2023)
964247e - undo pause logic (Aug 21, 2023)
33e896f - Merge branch 'master' into tune/pbt-bohb-pause (Aug 22, 2023)
863ec03 - resolve future (Aug 22, 2023)
a2eb589 - Pausing (Aug 22, 2023)
9af362e - skip memory test (Aug 22, 2023)
0a98a16 - typo (Aug 22, 2023)
789752b - Overwrite trial restore path (Aug 22, 2023)
965f3db - Merge branch 'master' into tune/pbt-bohb-pause (Aug 22, 2023)
fa89632 - default 0 (Aug 22, 2023)
190df4f - [train/tune] Remove save_to_object/restore_from_object (Aug 22, 2023)
138f92d - Fixes (Aug 22, 2023)
b674dd2 - avoid variable name conflict (Aug 22, 2023)
b0a1e57 - Merge remote-tracking branch 'upstream/master' into tune/remove-save-… (Aug 23, 2023)
e6ac302 - fix last test (Aug 23, 2023)
55f1b84 - Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 23, 2023)
4b624c5 - Merge branch 'tune/remove-save-restore-obj' into tune/pbt-bohb-pause (Aug 23, 2023)
8c87077 - fix last test (Aug 23, 2023)
11966c0 - Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 23, 2023)
40819b0 - bohb unpause (Aug 23, 2023)
031ea23 - pbt tests for storage (Aug 23, 2023)
209ff6a - fix checkpoint test (Aug 23, 2023)
d6839b6 - more fixes (Aug 23, 2023)
8484c5a - Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 23, 2023)
101d053 - Fix hashing (Aug 23, 2023)
0625416 - exclude pbt_transformers (Aug 23, 2023)
14f2d42 - default 0 (Aug 23, 2023)
04f7c66 - fix examples (Aug 23, 2023)
75cdcbd - fix some tests (Aug 23, 2023)
0b2ee3f - Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 23, 2023)
359aaad - review (Aug 23, 2023)
2515831 - Remove changes to old codepath (Aug 23, 2023)
7769bae - Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 24, 2023)
9fc4cc6 - remove empty pipeline (Aug 24, 2023)
6cbece0 - Cache decision in pause (Aug 24, 2023)
9ae6dfd - Exploit (Aug 24, 2023)
9572d03 - Fix trial.checkpoint (Aug 24, 2023)
77b4ae9 - fix tests (Aug 24, 2023)
c3bf12b - review (Aug 24, 2023)
80e8a6d - Revert (Aug 24, 2023)
d490db5 - Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 24, 2023)
25927f7 - Merge branch 'master' into tune/pbt-bohb-pause (krfricke, Aug 24, 2023)
f697157 - Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 25, 2023)
af8e44c - Update build files, resolve merge logic conflict (Aug 25, 2023)
3ba628e - Merge remote-tracking branch 'origin/tune/pbt-bohb-pause' into tune/p… (Aug 25, 2023)
[train/tune] Remove save_to_object/restore_from_object
Signed-off-by: Kai Fricke <kai@anyscale.com>
Kai Fricke committed Aug 22, 2023
commit 190df4f3e02ff21e914a03be1ed2489213ac1620
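
In short, this commit removes the in-memory checkpoint API (`save_to_object` / `restore_from_object`) and routes everything through the regular `save()` / `restore()` path. A hedged before/after sketch; `MyTrainable` is a hypothetical stand-in, not code from this PR, but the calls mirror the pattern the updated tests use:

# Illustrative sketch only: `MyTrainable` is hypothetical; the commented-out
# calls show the old API, the live calls show the single remaining path.
from ray.tune import Trainable


class MyTrainable(Trainable):
    def step(self):
        return {"score": self.iteration}

    def save_checkpoint(self, checkpoint_dir):
        # Returning a dict lets Tune persist the state into checkpoint_dir.
        return {"iteration": self.iteration}

    def load_checkpoint(self, checkpoint):
        pass


trainable = MyTrainable()
trainable.train()

# Before this commit:
#   obj = trainable.save_to_object()        # serialized in-memory checkpoint
#   trainable.restore_from_object(obj)
# After this commit, save()/restore() is the single code path:
checkpoint = trainable.save()
trainable.restore(checkpoint)
trainable.stop()
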
1 change: 0 additions & 1 deletion python/ray/tune/examples/pbt_convnet_example.py
@@ -72,7 +72,6 @@ def reset_config(self, new_config):

# check if PytorchTrainble will save/restore correctly before execution
validate_save_restore(PytorchTrainable)
validate_save_restore(PytorchTrainable, use_object_store=True)

# __pbt_begin__
scheduler = PopulationBasedTraining(
12 changes: 7 additions & 5 deletions python/ray/tune/execution/tune_controller.py
@@ -1919,9 +1919,11 @@ def _schedule_trial_save(
return

if storage == CheckpointStorage.MEMORY:
# This is now technically a persistent checkpoint, but
# we don't resolve it. Instead, we register it directly.
future = self._schedule_trial_task(
trial=trial,
method_name="save_to_object",
method_name="save",
on_result=None,
on_error=self._trial_task_failure,
_return_future=True,
@@ -2080,7 +2082,7 @@ def _schedule_trial_restore(self, trial: Trial) -> bool:
kwargs = {}

if checkpoint.storage_mode == CheckpointStorage.MEMORY:
method_name = "restore_from_object"
method_name = "restore"
args = (checkpoint.dir_or_data,)
elif (
trial.uses_cloud_checkpointing
@@ -2099,10 +2101,10 @@ def _schedule_trial_restore(self, trial: Trial) -> bool:
}
elif trial.sync_on_checkpoint:
checkpoint_path = TrainableUtil.find_checkpoint_dir(checkpoint.dir_or_data)
obj = Checkpoint.from_directory(checkpoint_path).to_bytes()
checkpoint = Checkpoint.from_directory(checkpoint_path)

method_name = "restore_from_object"
args = (obj,)
method_name = "restore"
args = (checkpoint,)
else:
raise _AbortTrialExecution(
"Pass in `sync_on_checkpoint=True` for driver-based trial"
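
The controller's restore branch above now builds a `Checkpoint` from the trial's checkpoint directory and passes it to `restore()`, rather than shipping `to_bytes()` output to `restore_from_object()`. A rough, self-contained sketch of that hand-off, assuming the `ray.air` `Checkpoint` class; the directory contents are made up, and the commented-out `restore()` call stands in for the scheduled trial task:

import os
import tempfile

from ray.air import Checkpoint

# Hypothetical stand-in for a trial's persisted checkpoint directory.
checkpoint_dir = tempfile.mkdtemp()
with open(os.path.join(checkpoint_dir, "state.pkl"), "wb") as f:
    f.write(b"trial state")

# Before: obj = Checkpoint.from_directory(checkpoint_dir).to_bytes()
#         trainable.restore_from_object(obj)
# Now: the Checkpoint object itself is passed along.
checkpoint = Checkpoint.from_directory(checkpoint_dir)
# trainable.restore(checkpoint)  # scheduled on the trial actor by the controller
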
101 changes: 9 additions & 92 deletions python/ray/tune/tests/test_function_api.py
@@ -97,35 +97,7 @@ def train(config, checkpoint_dir=None):
new_trainable.stop()
assert result[TRAINING_ITERATION] == 10

def testCheckpointReuseObject(self):
"""Test that repeated save/restore never reuses same checkpoint dir."""

def train(config, checkpoint_dir=None):
if checkpoint_dir:
count = sum(
"checkpoint-" in path for path in os.listdir(checkpoint_dir)
)
assert count == 1, os.listdir(checkpoint_dir)

for step in range(20):
with tune.checkpoint_dir(step=step) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint-{}".format(step))
open(path, "a").close()
tune.report(test=step)

wrapped = wrap_function(train)
checkpoint = None
for i in range(5):
new_trainable = wrapped(logger_creator=self.logger_creator)
if checkpoint:
new_trainable.restore_from_object(checkpoint)
for i in range(2):
result = new_trainable.train()
checkpoint = new_trainable.save_to_object()
new_trainable.stop()
self.assertTrue(result[TRAINING_ITERATION] == 10)

def testCheckpointReuseObjectWithoutTraining(self):
def testCheckpointReuseWithoutTraining(self):
"""Test that repeated save/restore never reuses same checkpoint dir."""

def train(config, checkpoint_dir=None):
@@ -145,15 +117,15 @@ def train(config, checkpoint_dir=None):
new_trainable = wrapped(logger_creator=self.logger_creator)
for i in range(2):
result = new_trainable.train()
checkpoint = new_trainable.save_to_object()
checkpoint = new_trainable.save()
new_trainable.stop()

new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore_from_object(checkpoint)
new_trainable2.restore(checkpoint)
new_trainable2.stop()

new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore_from_object(checkpoint)
new_trainable2.restore(checkpoint)
result = new_trainable2.train()
new_trainable2.stop()
self.assertTrue(result[TRAINING_ITERATION] == 3)
@@ -203,23 +175,6 @@ def train(config, checkpoint_dir=None):
new_trainable.stop()
self.assertTrue(result[TRAINING_ITERATION] == 1)

def testMultipleNullMemoryCheckpoints(self):
def train(config, checkpoint_dir=None):
assert not checkpoint_dir
for step in range(10):
tune.report(test=step)

wrapped = wrap_function(train)
checkpoint = None
for i in range(5):
new_trainable = wrapped(logger_creator=self.logger_creator)
if checkpoint:
new_trainable.restore_from_object(checkpoint)
result = new_trainable.train()
checkpoint = new_trainable.save_to_object()
new_trainable.stop()
assert result[TRAINING_ITERATION] == 1

def testFunctionNoCheckpointing(self):
def train(config, checkpoint_dir=None):
if checkpoint_dir:
@@ -259,8 +214,8 @@ def train(config, checkpoint_dir=None):

new_trainable = wrapped(logger_creator=self.logger_creator)
new_trainable.train()
checkpoint_obj = new_trainable.save_to_object()
new_trainable.restore_from_object(checkpoint_obj)
checkpoint_obj = new_trainable.save()
new_trainable.restore(checkpoint_obj)
checkpoint = new_trainable.save()

new_trainable.stop()
@@ -287,16 +242,14 @@ def train(config, checkpoint_dir=None):
new_trainable = wrapped(logger_creator=self.logger_creator)
new_trainable.train()
new_trainable.train()
checkpoint_obj = new_trainable.save_to_object()
checkpoint_obj = new_trainable.save()
new_trainable.stop()

new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore_from_object(checkpoint_obj)
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
checkpoint_obj = new_trainable2.save_to_object()
new_trainable2.restore(checkpoint_obj)
checkpoint_obj = new_trainable2.save()
new_trainable2.train()
result = new_trainable2.train()
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
new_trainable2.stop()
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
assert result[TRAINING_ITERATION] == 4
@@ -602,42 +555,6 @@ def train(config):
self.assertEqual(trial_2.last_result["m"], 8 + 9)


def test_restore_from_object_delete(tmp_path):
"""Test that temporary checkpoint directories are deleted after restoring.

`FunctionTrainable.restore_from_object` creates a temporary checkpoint directory.
This directory is kept around as we don't control how the user interacts with
the checkpoint - they might load it several times, or no time at all.

Once a new checkpoint is tracked in the status reporter, there is no need to keep
the temporary object around anymore. This test asserts that the temporary
checkpoint directories are then deleted.
"""
# Create 2 checkpoints
cp_1 = TrainableUtil.make_checkpoint_dir(str(tmp_path), index=1, override=True)
cp_2 = TrainableUtil.make_checkpoint_dir(str(tmp_path), index=2, override=True)

# Instantiate function trainable
trainable = FunctionTrainable()
trainable._logdir = str(tmp_path)
trainable._status_reporter.set_checkpoint(cp_1)

# Save to object and restore. This will create a temporary checkpoint directory.
cp_obj = trainable.save_to_object()
trainable.restore_from_object(cp_obj)

# Assert there is at least one `checkpoint_tmpxxxxx` directory in the logdir
assert any(path.name.startswith("checkpoint_tmp") for path in tmp_path.iterdir())

# Track a new checkpoint. This should delete the temporary checkpoint directory.
trainable._status_reporter.set_checkpoint(cp_2)

# Directory should have been deleted
assert not any(
path.name.startswith("checkpoint_tmp") for path in tmp_path.iterdir()
)


if __name__ == "__main__":
import pytest

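
The removed tests above exercised `save_to_object()` / `restore_from_object()` round-trips; the surviving ones (e.g. `testCheckpointReuseWithoutTraining`) do the same with `save()` / `restore()`. A condensed sketch of that remaining pattern; it assumes `wrap_function` is importable from `ray.tune.trainable.function_trainable` (the module edited later in this diff) and uses the same legacy function API these tests use:

import os

from ray import tune
from ray.tune.trainable.function_trainable import wrap_function


def train(config, checkpoint_dir=None):
    for step in range(10):
        with tune.checkpoint_dir(step=step) as checkpoint_dir:
            # Write a marker file into the managed checkpoint directory.
            open(os.path.join(checkpoint_dir, "checkpoint-{}".format(step)), "a").close()
        tune.report(test=step)


wrapped = wrap_function(train)

trainable = wrapped()
for _ in range(2):
    trainable.train()
checkpoint = trainable.save()       # was: save_to_object()
trainable.stop()

restored = wrapped()
restored.restore(checkpoint)        # was: restore_from_object(checkpoint)
restored.train()
restored.stop()
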
66 changes: 0 additions & 66 deletions python/ray/tune/tests/test_trainable.py
@@ -129,25 +129,6 @@ class trainables.
ray.get(restoring_future)


@pytest.mark.parametrize("return_type", ["object", "root", "subdir", "checkpoint"])
def test_save_load_checkpoint_object_class(ray_start_2_cpus, return_type):
"""Assert that restoring from a Trainable.save_to_object() future works with
class trainables.

Needs Ray cluster so we get actual futures.
"""
trainable = ray.remote(SavingTrainable).remote(return_type=return_type)

saving_future = trainable.save_to_object.remote()

# Check for errors
ray.get(saving_future)

restoring_future = trainable.restore_from_object.remote(saving_future)

ray.get(restoring_future)


@pytest.mark.parametrize(
"fn_trainable", [function_trainable_dict, function_trainable_directory]
)
@@ -171,53 +152,6 @@ def test_save_load_checkpoint_path_fn(ray_start_2_cpus, fn_trainable):
ray.get(restoring_future)


@pytest.mark.parametrize(
"fn_trainable", [function_trainable_dict, function_trainable_directory]
)
def test_save_load_checkpoint_object_fn(ray_start_2_cpus, fn_trainable):
"""Assert that restoring from a Trainable.save_to_object() future works with
function trainables.

Needs Ray cluster so we get actual futures.
"""
trainable_cls = wrap_function(fn_trainable)
trainable = ray.remote(trainable_cls).remote()
ray.get(trainable.train.remote())

saving_future = trainable.save_to_object.remote()

# Check for errors
ray.get(saving_future)

restoring_future = trainable.restore_from_object.remote(saving_future)

ray.get(restoring_future)


def test_checkpoint_object_no_sync(tmpdir):
"""Asserts that save_to_object() and restore_from_object() do not sync up/down"""
trainable = SavingTrainable(
"object", remote_checkpoint_dir="memory:///test/location"
)

# Save checkpoint
trainable.save()

check_dir = tmpdir / "check_save"
download_from_uri(uri="memory:///test/location", local_path=str(check_dir))
assert os.listdir(str(check_dir)) == ["checkpoint_000000"]

# Save to object
obj = trainable.save_to_object()

check_dir = tmpdir / "check_save_obj"
download_from_uri(uri="memory:///test/location", local_path=str(check_dir))
assert os.listdir(str(check_dir)) == ["checkpoint_000000"]

# Restore from object
trainable.restore_from_object(obj)


@pytest.mark.parametrize("hanging", [True, False])
def test_sync_timeout(tmpdir, monkeypatch, hanging):
monkeypatch.setenv("TUNE_CHECKPOINT_CLOUD_RETRY_WAIT_TIME_S", "0")
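
With the object-store variants deleted, these tests keep only the path-based round-trip through actor method futures (as in `test_save_load_checkpoint_path_class` above). A hedged sketch of that remaining pattern; `TinyTrainable` is a hypothetical stand-in for this file's `SavingTrainable` helper:

import ray
from ray.tune import Trainable


class TinyTrainable(Trainable):
    def step(self):
        return {"metric": self.iteration}

    def save_checkpoint(self, checkpoint_dir):
        return {"state": self.iteration}

    def load_checkpoint(self, checkpoint):
        pass


ray.init(num_cpus=2)

remote_trainable = ray.remote(TinyTrainable).remote()

saving_future = remote_trainable.save.remote()
ray.get(saving_future)  # surface any save-side errors, as the tests do

# Ray resolves the save future into the checkpoint path on the actor.
restoring_future = remote_trainable.restore.remote(saving_future)
ray.get(restoring_future)

ray.shutdown()
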
4 changes: 0 additions & 4 deletions python/ray/tune/tests/test_tune_restore.py
@@ -539,23 +539,19 @@ def testPBTKeras(self):

cifar10.load_data()
validate_save_restore(Cifar10Model)
validate_save_restore(Cifar10Model, use_object_store=True)

def testPyTorchMNIST(self):
from ray.tune.examples.mnist_pytorch_trainable import TrainMNIST
from torchvision import datasets

datasets.MNIST("~/data", train=True, download=True)
validate_save_restore(TrainMNIST)
validate_save_restore(TrainMNIST, use_object_store=True)

def testHyperbandExample(self):
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)

def testAsyncHyperbandExample(self):
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)


class AutoInitTest(unittest.TestCase):
1 change: 0 additions & 1 deletion python/ray/tune/tests/test_tune_save_restore.py
@@ -174,7 +174,6 @@ def load_checkpoint(self, checkpoint_dir):
return checkpoint_dir

validate_save_restore(MockTrainable)
validate_save_restore(MockTrainable, use_object_store=True)


if __name__ == "__main__":
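
The `use_object_store=True` second pass is dropped from every `validate_save_restore` call (here, in test_tune_restore.py, and in pbt_convnet_example.py above), since there is no longer a separate object-store code path to validate. A minimal sketch, assuming `validate_save_restore` is still exposed from `ray.tune.utils`; `MockTrainable` is a hypothetical stand-in for the trainables these tests validate:

import ray
from ray.tune import Trainable
from ray.tune.utils import validate_save_restore


class MockTrainable(Trainable):
    def step(self):
        return {"loss": 1.0 / (1 + self.iteration)}

    def save_checkpoint(self, checkpoint_dir):
        return {"iteration": self.iteration}

    def load_checkpoint(self, checkpoint):
        pass


ray.init(num_cpus=2)
# Previously this was followed by a second call with use_object_store=True.
validate_save_restore(MockTrainable)
ray.shutdown()
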
14 changes: 0 additions & 14 deletions python/ray/tune/trainable/function_trainable.py
@@ -587,11 +587,6 @@ def _create_checkpoint_dir(
) -> Optional[str]:
return None

def save_to_object(self):
checkpoint_path = self.save()
checkpoint = Checkpoint.from_directory(checkpoint_path)
return checkpoint.to_bytes()

def load_checkpoint(self, checkpoint):
if _use_storage_context():
checkpoint_result = checkpoint
@@ -619,15 +614,6 @@ def _restore_from_checkpoint_obj(self, checkpoint: Checkpoint):
checkpoint.to_directory(self.temp_checkpoint_dir)
self.restore(self.temp_checkpoint_dir)

def restore_from_object(self, obj):
self.temp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
self.logdir
)
checkpoint = Checkpoint.from_bytes(obj)
checkpoint.to_directory(self.temp_checkpoint_dir)

self.restore(self.temp_checkpoint_dir)

def cleanup(self):
if _use_storage_context():
session = get_session()
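
The deleted `save_to_object()` / `restore_from_object()` helpers were thin wrappers around a `Checkpoint` byte round-trip: serialize a saved checkpoint directory to bytes, and later unpack those bytes into a temporary directory before calling `restore()` on it. A hedged sketch of the equivalent round-trip using the public `Checkpoint` API directly (not code from this PR; paths and file names are made up):

import os
import tempfile

from ray.air import Checkpoint

# What save_to_object() did: save() to a directory, wrap it, serialize to bytes.
source_dir = tempfile.mkdtemp()
with open(os.path.join(source_dir, "model.bin"), "wb") as f:
    f.write(b"weights")
obj = Checkpoint.from_directory(source_dir).to_bytes()

# What restore_from_object() did: unpack the bytes into a temporary directory
# and call restore() on that directory.
temp_checkpoint_dir = tempfile.mkdtemp()
Checkpoint.from_bytes(obj).to_directory(temp_checkpoint_dir)
# trainable.restore(temp_checkpoint_dir)
print(os.listdir(temp_checkpoint_dir))  # the original files are unpacked here
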