[train] New persistence mode: Remove some legacy air.Checkpoint dependencies #39049

Merged · 14 commits · Aug 31, 2023
23 changes: 18 additions & 5 deletions doc/source/tune/doc_code/fault_tolerance.py
@@ -1,7 +1,9 @@
 # flake8: noqa
 
 # __ft_initial_run_start__
+import json
 import os
+import tempfile
 
 from ray import train, tune
 from ray.train import Checkpoint
@@ -10,15 +12,24 @@
 def trainable(config):
     # Checkpoint loading
     checkpoint = train.get_checkpoint()
-    start = 1 if not checkpoint else checkpoint.to_dict()["epoch"] + 1
+    start = 1
+    if checkpoint:
+        with checkpoint.as_directory() as checkpoint_dir:
+            with open(os.path.join(checkpoint_dir, "checkpoint.json"), "r") as f:
+                state = json.load(f)
+            start = state["epoch"] + 1
 
     for epoch in range(start, config["num_epochs"]):
         # Do some training...
 
         # Checkpoint saving
-        train.report(
-            {"epoch": epoch}, checkpoint=Checkpoint.from_dict({"epoch": epoch})
-        )
+        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+            with open(os.path.join(temp_checkpoint_dir, "checkpoint.json"), "w") as f:
+                json.dump({"epoch": epoch}, f)
+            train.report(
+                {"epoch": epoch},
+                checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
+            )
 
 
 tuner = tune.Tuner(
@@ -29,9 +40,11 @@ def trainable(config):
         name="tune_fault_tolerance_guide",
     ),
 )
-tuner.fit()
+result_grid = tuner.fit()
 # __ft_initial_run_end__
 
+assert not result_grid.errors
+
 # __ft_restored_run_start__
 tuner = tune.Tuner.restore(
     os.path.expanduser("~/ray_results/tune_fault_tolerance_guide"),
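
For context on the migration this diff performs: instead of packing state into a dict with the legacy air.Checkpoint.from_dict / to_dict calls, the updated example writes ordinary files into a directory and wraps that directory with ray.train.Checkpoint. The following is a minimal round-trip sketch of that directory-based API outside of Tune, not code from the PR; the checkpoint.json file name mirrors the doc example above, and the epoch value is illustrative.

import json
import os
import tempfile

from ray.train import Checkpoint

with tempfile.TemporaryDirectory() as tmp_dir:
    # Save: write training state as plain files, then wrap the directory.
    with open(os.path.join(tmp_dir, "checkpoint.json"), "w") as f:
        json.dump({"epoch": 3}, f)  # illustrative state
    checkpoint = Checkpoint.from_directory(tmp_dir)

    # Load: materialize the checkpoint as a local directory and read the files back.
    with checkpoint.as_directory() as checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "checkpoint.json")) as f:
            state = json.load(f)

    assert state["epoch"] == 3

In the Tune example, the same pattern is driven by train.report(..., checkpoint=...) on save and train.get_checkpoint() on restore, and keeping checkpoints as directories of plain files is what the new persistence mode persists to the run's storage location, with no dict-serialization step in between.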