Skip to content

Commit

Permalink
[train] New persistence mode: Remove some legacy air.Checkpoint dependencies (ray-project#39049)
Browse files Browse the repository at this point in the history

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Jim Thompson <jimthompson5802@gmail.com>
  • Loading branch information
justinvyu authored and jimthompson5802 committed Sep 12, 2023
1 parent a0a4448 commit 0480c13
Show file tree
Hide file tree
Showing 15 changed files with 161 additions and 308 deletions.
23 changes: 18 additions & 5 deletions doc/source/tune/doc_code/fault_tolerance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# flake8: noqa

# __ft_initial_run_start__
import json
import os
import tempfile

from ray import train, tune
from ray.train import Checkpoint
Expand All @@ -10,15 +12,24 @@
def trainable(config):
# Checkpoint loading
checkpoint = train.get_checkpoint()
start = 1 if not checkpoint else checkpoint.to_dict()["epoch"] + 1
start = 1
if checkpoint:
with checkpoint.as_directory() as checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint.json"), "r") as f:
state = json.load(f)
start = state["epoch"] + 1

for epoch in range(start, config["num_epochs"]):
# Do some training...

# Checkpoint saving
train.report(
{"epoch": epoch}, checkpoint=Checkpoint.from_dict({"epoch": epoch})
)
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
with open(os.path.join(temp_checkpoint_dir, "checkpoint.json"), "w") as f:
json.dump({"epoch": epoch}, f)
train.report(
{"epoch": epoch},
checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
)


tuner = tune.Tuner(
Expand All @@ -29,9 +40,11 @@ def trainable(config):
name="tune_fault_tolerance_guide",
),
)
tuner.fit()
result_grid = tuner.fit()
# __ft_initial_run_end__

assert not result_grid.errors

# __ft_restored_run_start__
tuner = tune.Tuner.restore(
os.path.expanduser("~/ray_results/tune_fault_tolerance_guide"),
Expand Down
Loading

0 comments on commit 0480c13

Please sign in to comment.