[train] Storage refactor: Support PBT and BOHB #38736

Merged

merged 51 commits on Aug 25, 2023

Changes from 1 commit

Commits (51)
fa992fc  Adjust save_checkpoint API (Aug 17, 2023)
f0f0f41  more (Aug 17, 2023)
77be920  fix test (Aug 17, 2023)
c2c6073  Merge remote-tracking branch 'upstream/master' into tune/storage-pbt (Aug 17, 2023)
711eefa  Update typehints (Aug 17, 2023)
8bcc82e  Merge remote-tracking branch 'upstream/master' into tune/storage-pbt (Aug 17, 2023)
80ec41e  Merge branch 'master' into tune/storage-pbt (Aug 21, 2023)
964247e  undo pause logic (Aug 21, 2023)
33e896f  Merge branch 'master' into tune/pbt-bohb-pause (Aug 22, 2023)
863ec03  resolve future (Aug 22, 2023)
a2eb589  Pausing (Aug 22, 2023)
9af362e  skip memory test (Aug 22, 2023)
0a98a16  typo (Aug 22, 2023)
789752b  Overwrite trial restore path (Aug 22, 2023)
965f3db  Merge branch 'master' into tune/pbt-bohb-pause (Aug 22, 2023)
fa89632  default 0 (Aug 22, 2023)
190df4f  [train/tune] Remove save_to_object/restore_from_object (Aug 22, 2023)
138f92d  Fixes (Aug 22, 2023)
b674dd2  avoid variable name conflict (Aug 22, 2023)
b0a1e57  Merge remote-tracking branch 'upstream/master' into tune/remove-save-… (Aug 23, 2023)
e6ac302  fix last test (Aug 23, 2023)
55f1b84  Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 23, 2023)
4b624c5  Merge branch 'tune/remove-save-restore-obj' into tune/pbt-bohb-pause (Aug 23, 2023)
8c87077  fix last test (Aug 23, 2023)
11966c0  Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 23, 2023)
40819b0  bohb unpause (Aug 23, 2023)
031ea23  pbt tests for storage (Aug 23, 2023)
209ff6a  fix checkpoint test (Aug 23, 2023)
d6839b6  more fixes (Aug 23, 2023)
8484c5a  Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 23, 2023)
101d053  Fix hashing (Aug 23, 2023)
0625416  exclude pbt_transformers (Aug 23, 2023)
14f2d42  default 0 (Aug 23, 2023)
04f7c66  fix examples (Aug 23, 2023)
75cdcbd  fix some tests (Aug 23, 2023)
0b2ee3f  Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 23, 2023)
359aaad  review (Aug 23, 2023)
2515831  Remove changes to old codepath (Aug 23, 2023)
7769bae  Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 24, 2023)
9fc4cc6  remove empty pipeline (Aug 24, 2023)
6cbece0  Cache decision in pause (Aug 24, 2023)
9ae6dfd  Exploit (Aug 24, 2023)
9572d03  Fix trial.checkpoint (Aug 24, 2023)
77b4ae9  fix tests (Aug 24, 2023)
c3bf12b  review (Aug 24, 2023)
80e8a6d  Revert (Aug 24, 2023)
d490db5  Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 24, 2023)
25927f7  Merge branch 'master' into tune/pbt-bohb-pause (krfricke, Aug 24, 2023)
f697157  Merge remote-tracking branch 'upstream/master' into tune/pbt-bohb-pause (Aug 25, 2023)
af8e44c  Update build files, resolve merge logic conflict (Aug 25, 2023)
3ba628e  Merge remote-tracking branch 'origin/tune/pbt-bohb-pause' into tune/p… (Aug 25, 2023)
undo pause logic
Signed-off-by: Kai Fricke <kai@anyscale.com>
Kai Fricke committed Aug 21, 2023
commit 964247e6f7a7e6ac8d3739af6e5b4bea6f833333
10 changes: 5 additions & 5 deletions python/ray/tune/BUILD
@@ -359,31 +359,31 @@ py_test(
     size = "large",
     srcs = ["tests/test_trial_scheduler.py"],
     deps = [":tune_lib"],
-    tags = ["team:ml", "exclusive", "medium_instance", "no_new_storage"],
+    tags = ["team:ml", "exclusive", "medium_instance"],
 )

 py_test(
     name = "test_trial_scheduler_pbt",
     size = "large",
     srcs = ["tests/test_trial_scheduler_pbt.py"],
     deps = [":tune_lib"],
-    tags = ["team:ml", "exclusive", "medium_instance", "no_new_storage"],
+    tags = ["team:ml", "exclusive", "medium_instance"],
 )

 py_test(
     name = "test_trial_scheduler_resource_changing",
     size = "small",
     srcs = ["tests/test_trial_scheduler_resource_changing.py"],
     deps = [":tune_lib"],
-    tags = ["team:ml", "exclusive", "no_new_storage"],
+    tags = ["team:ml", "exclusive"],
 )

 py_test(
     name = "test_tune_restore_warm_start",
     size = "large",
     srcs = ["tests/test_tune_restore_warm_start.py"],
     deps = [":tune_lib"],
-    tags = ["team:ml", "exclusive", "no_new_storage"],
+    tags = ["team:ml", "exclusive"],
 )

 py_test(
@@ -565,7 +565,7 @@ py_test(
     size = "medium",
     srcs = ["examples/bohb_example.py"],
     deps = [":tune_lib"],
-    tags = ["team:ml", "exclusive", "example", "no_new_storage"]
+    tags = ["team:ml", "exclusive", "example"]
 )

 py_test(
65 changes: 21 additions & 44 deletions python/ray/tune/execution/tune_controller.py
@@ -1518,19 +1518,18 @@ def _process_trial_failure(
             exception: Exception prior to invoking this method.
         """
         self._has_errored = True
-        if trial.status == Trial.RUNNING:
-            if trial.should_recover():
-                self._try_recover(trial, exc=exception)
-                self._callbacks.on_trial_recover(
-                    iteration=self._iteration, trials=self._trials, trial=trial
-                )
-            else:
-                self._scheduler_alg.on_trial_error(self, trial)
-                self._search_alg.on_trial_complete(trial.trial_id, error=True)
-                self._schedule_trial_stop(trial, exception=exception)
-                self._callbacks.on_trial_error(
-                    iteration=self._iteration, trials=self._trials, trial=trial
-                )
+        if trial.status == Trial.RUNNING and trial.should_recover():
+            self._try_recover(trial, exc=exception)
+            self._callbacks.on_trial_recover(
+                iteration=self._iteration, trials=self._trials, trial=trial
+            )
+        elif trial.status in {Trial.RUNNING, Trial.PENDING}:
Contributor Author: Trials now may be PENDING and failing - e.g. when a save resolves late

Contributor: This is if a trial is paused, schedules a save, then gets unpaused (set to pending) and then the save fails?

Contributor Author: That's right!
+            self._scheduler_alg.on_trial_error(self, trial)
+            self._search_alg.on_trial_complete(trial.trial_id, error=True)
+            self._schedule_trial_stop(trial, exception=exception)
+            self._callbacks.on_trial_error(
+                iteration=self._iteration, trials=self._trials, trial=trial
+            )

     def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = None):
         if trial.status == Trial.ERROR:
@@ -1619,38 +1618,13 @@ def _schedule_graceful_trial_stop(self, trial: Trial):
         self._schedule_trial_stop(trial)

     def _schedule_trial_pause(self, trial: Trial, should_checkpoint: bool = True):
-        if _use_storage_context():
-            if trial not in self._trial_to_actor:
-                logger.debug(
-                    f"Trial PAUSE requested for trial {trial} but trial is already "
-                    f"stopping. Ignoring."
-                )
-                return
-
-            if should_checkpoint:
-                # We need to wait for the save to finish before stopping the trial.
-                def stop_after_save_result(*args, **kwargs):
-                    self._on_saving_result(*args, **kwargs)
-                    self._schedule_trial_stop(trial)
-                    self._set_trial_status(trial, Trial.PAUSED)
-
-                # NOTE: Ensure that the trial is PAUSED while it's saving a checkpoint.
-                self._set_trial_status(trial, Trial.PAUSED)
-                self._schedule_trial_task(
-                    trial=trial,
-                    method_name="save",
-                    on_result=stop_after_save_result,
-                    on_error=self._trial_task_failure,
-                )
-                trial.temporary_state.saving_to = True
-            else:
-                self._schedule_trial_stop(trial)
-                self._set_trial_status(trial, Trial.PAUSED)
-
-            return

         if should_checkpoint:
-            self._schedule_trial_save(trial, storage=CheckpointStorage.MEMORY)
+            self._schedule_trial_save(
+                trial,
+                storage=CheckpointStorage.PERSISTENT
+                if _use_storage_context()
+                else CheckpointStorage.MEMORY,
Contributor: no need to touch the old code
+            )
         self._schedule_trial_stop(trial)
         self._set_trial_status(trial, Trial.PAUSED)

@@ -1981,6 +1955,9 @@ def _process_trial_save(

         try:
             if _use_storage_context() and isinstance(checkpoint_value, _TrainingResult):
+                if not checkpoint_value.checkpoint:
+                    logger.debug(f"Got empty checkpoint for trial {trial}")
+                    return
                 try:
                     self._callbacks.on_checkpoint(
                         iteration=self._iteration,
4 changes: 2 additions & 2 deletions python/ray/tune/schedulers/pbt.py
@@ -635,7 +635,7 @@ def _checkpoint_or_exploit(
                 state.last_checkpoint = trial.checkpoint
             else:
                 state.last_checkpoint = tune_controller._schedule_trial_save(
-                    trial, CheckpointStorage.MEMORY, result=state.last_result
+                    trial, CheckpointStorage.PERSISTENT, result=state.last_result
                 )
             self._num_checkpoints += 1
         else:
@@ -1089,7 +1089,7 @@ def on_trial_result(
         )

         checkpoint = tune_controller._schedule_trial_save(
-            trial, CheckpointStorage.MEMORY, result=result
+            trial, CheckpointStorage.PERSISTENT, result=result
         )

         new_tag = _make_experiment_tag(self.experiment_tag, new_config, new_config)
1 change: 0 additions & 1 deletion python/ray/tune/trainable/function_trainable.py
@@ -551,7 +551,6 @@ def save_checkpoint(self, checkpoint_dir: str = ""):
         if _use_storage_context():
             # TRAIN -> SAVE remote calls get processed sequentially,
             # so `_last_training_result.checkpoint` holds onto the latest ckpt.
-            assert self._last_training_result.checkpoint

Contributor: should we change this to

    assert self._last_training_result is None or self._last_training_result.checkpoint
    # and maybe tell the user to file an issue if they hit this

             return self._last_training_result

         checkpoint = self._status_reporter.get_checkpoint()
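To make the reviewer's suggestion above concrete, here is a minimal, runnable sketch of the proposed guard, pulled out into a standalone helper. The TrainingResult dataclass and the function name are illustrative stand-ins for this note, not Ray's actual types or the code merged in this PR.

# Editorial sketch of the suggested assertion, under assumed names.
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingResult:
    checkpoint: Optional[str] = None


def latest_result_or_fail(
    last_training_result: Optional[TrainingResult],
) -> Optional[TrainingResult]:
    # Allow "no result yet", but if a result exists it must carry a checkpoint,
    # and tell the user to file an issue if it does not.
    assert (
        last_training_result is None or last_training_result.checkpoint
    ), (
        "Latest training result has no checkpoint attached. "
        "This is likely a bug; please file an issue."
    )
    return last_training_result


print(latest_result_or_fail(None))  # None: nothing trained yet, allowed
print(latest_result_or_fail(TrainingResult(checkpoint="ckpt_000001")))  # passes
# latest_result_or_fail(TrainingResult())  # would raise AssertionError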