[train/tune] Separate storage checkpoint index bookkeeping (ray-project#39927)

Checkpoint IDs are incremented in three different places: the `Trainable` (for class trainables), the `session` (for function trainables), and the `Trial` (on the driver). These are currently kept in sync only implicitly. In the future, we may want to synchronize driver and trainable state via other means, or derive the checkpoint directory name from other metrics. In preparation for this, we separate the checkpoint ID mutation into a subfunction that can be overridden (or, in a follow-up, provided externally or otherwise customized).
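As a sketch of where this is headed (not part of this commit), a subclass could override the new hook to track a reported metric alongside the index; the subclass and attribute names below are hypothetical:

from typing import Dict

from ray.train._internal.storage import StorageContext


class MetricAwareStorageContext(StorageContext):
    def _increase_checkpoint_index(self, metrics: Dict):
        # Keep the default bookkeeping: bump the index by one.
        super()._increase_checkpoint_index(metrics)
        # Stash a metric that a follow-up could use when naming checkpoint
        # directories ("training_iteration" is a standard Tune result key).
        self.last_reported_iteration = metrics.get("training_iteration")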

Signed-off-by: Kai Fricke <kai@anyscale.com>
krfricke authored Sep 29, 2023
1 parent 6f99b2c commit 3135323
Showing 5 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/ray/train/_internal/session.py
@@ -558,7 +558,7 @@ def _report_training_result(self, training_result: _TrainingResult) -> None:
 
         # NOTE: This is where the coordinator AND workers increment their
         # checkpoint index.
-        self.storage.current_checkpoint_index += 1
+        self.storage._increase_checkpoint_index(training_result.metrics)
 
         # Add result to a thread-safe queue.
         self.result_queue.put(training_result, block=True)
5 changes: 5 additions & 0 deletions python/ray/train/_internal/storage.py
@@ -511,6 +511,11 @@ def _check_validation_file(self):
                 "to the configured storage path."
             )
 
+    def _increase_checkpoint_index(self, metrics: Dict):
+        # By default, increase the index by 1. This can be overridden to
+        # customize checkpoint directory naming.
+        self.current_checkpoint_index += 1
+
     def persist_current_checkpoint(self, checkpoint: "Checkpoint") -> "Checkpoint":
         """Persists a given checkpoint to the current checkpoint path on the filesystem.
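For illustration (not from this diff): the default hook ignores its metrics argument and just bumps the counter. A minimal check, bypassing the real constructor, whose arguments are not shown in this diff:

from ray.train._internal.storage import StorageContext

# Illustration only: skip __init__ so we don't need real storage arguments.
ctx = StorageContext.__new__(StorageContext)
ctx.current_checkpoint_index = 0
ctx._increase_checkpoint_index({"training_iteration": 5})
assert ctx.current_checkpoint_index == 1  # metrics are ignored by default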
3 changes: 2 additions & 1 deletion python/ray/tune/experiment/experiment.py
@@ -133,6 +133,7 @@ class Experiment:
 
     # Keys that will be present in `public_spec` dict.
     PUBLIC_KEYS = {"stop", "num_samples", "time_budget_s"}
+    _storage_context_cls = StorageContext
 
     def __init__(
         self,
@@ -201,7 +202,7 @@ def __init__(
         if not name:
             name = StorageContext.get_experiment_dir_name(run)
 
-        self.storage = StorageContext(
+        self.storage = self._storage_context_cls(
             storage_path=storage_path,
             storage_filesystem=storage_filesystem,
             sync_config=sync_config,
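A hypothetical end-to-end wiring of both hooks added in this commit (class names and the indexing policy are made up):

from ray.train._internal.storage import StorageContext
from ray.tune.experiment import Experiment


class IterationIndexedStorageContext(StorageContext):
    def _increase_checkpoint_index(self, metrics: dict):
        # Made-up policy: use the reported training iteration as the index,
        # falling back to the default +1 behavior.
        self.current_checkpoint_index = metrics.get(
            "training_iteration", self.current_checkpoint_index + 1
        )


class MyExperiment(Experiment):
    # The class attribute added above lets subclasses swap in a custom
    # StorageContext without overriding __init__.
    _storage_context_cls = IterationIndexedStorageContext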
2 changes: 1 addition & 1 deletion python/ray/tune/experiment/trial.py
@@ -1104,7 +1104,7 @@ def on_checkpoint(self, checkpoint: Union[_TrackedCheckpoint, _TrainingResult]):
             # Increment the checkpoint index to keep the checkpoint index in sync.
             # This index will get restored when the trial is restored and will
             # be passed to the Trainable as the starting checkpoint index.
-            self.storage.current_checkpoint_index += 1
+            self.storage._increase_checkpoint_index(checkpoint_result.metrics)
         else:
             self.run_metadata.checkpoint_manager.on_checkpoint(checkpoint)
         self.invalidate_json_state()
2 changes: 1 addition & 1 deletion python/ray/tune/trainable/trainable.py
@@ -552,7 +552,7 @@ def save(
         # The checkpoint index needs to be incremented.
         # NOTE: This is no longer using "iteration" as the folder indexing
         # to be consistent with fn trainables.
-        self._storage.current_checkpoint_index += 1
+        self._storage._increase_checkpoint_index(metrics)
 
         checkpoint_result = _TrainingResult(
             checkpoint=persisted_checkpoint, metrics=metrics
