
Commit eccc192

[train] New persistence mode: Sanity-check release test (ray-project#39354)
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-authored-by: matthewdeng <matt@anyscale.com>
1 parent: dc3d163

9 files changed: +525 −58 lines

python/ray/train/tests/test_new_persistence.py

+87 −27
@@ -1,4 +1,5 @@
 from contextlib import contextmanager
+import logging
 import os
 from pathlib import Path
 import pickle
@@ -13,26 +14,53 @@

 import ray
 from ray import train, tune
+from ray._private.test_utils import simulate_storage
 from ray.air._internal.uri_utils import URI
 from ray.air.constants import EXPR_RESULT_FILE
-from ray.train._internal.storage import _download_from_fs_path, StorageContext
+from ray.train._internal.storage import (
+    _delete_fs_path,
+    _download_from_fs_path,
+    StorageContext,
+)
 from ray.train._checkpoint import Checkpoint
 from ray.train.base_trainer import TrainingFailedError
 from ray.train.constants import RAY_AIR_NEW_PERSISTENCE_MODE
 from ray.train.data_parallel_trainer import DataParallelTrainer
 from ray.tune.trainable.trainable import _DICT_CHECKPOINT_FILE_NAME

-from ray.train.tests.util import mock_s3_bucket_uri

+class TestConstants:
+    NUM_ITERATIONS = 6  # == num_checkpoints == num_artifacts
+    NUM_TRIALS = 2
+    NUM_WORKERS = 3

-_SCORE_KEY = "score"
-NUM_ITERATIONS = 6  # == num_checkpoints == num_artifacts
-NUM_TRIALS = 2
-NUM_WORKERS = 3
+    SCORE_KEY = "score"


 @contextmanager
-def dummy_context_manager():
+def mock_s3_bucket_uri():
+    port = 5002
+    region = "us-west-2"
+    with simulate_storage("s3", port=port, region=region) as s3_uri:
+        import boto3
+
+        s3 = boto3.client(
+            "s3", region_name=region, endpoint_url=f"http://localhost:{port}"
+        )
+        # Bucket name will be autogenerated/unique per test
+        bucket_name = URI(s3_uri).name
+        s3.create_bucket(
+            Bucket=bucket_name,
+            CreateBucketConfiguration={"LocationConstraint": region},
+        )
+        # Disable server HTTP request logging
+        logging.getLogger("werkzeug").setLevel(logging.WARNING)
+        yield URI(s3_uri)
+        logging.getLogger("werkzeug").setLevel(logging.INFO)
+
+
+@contextmanager
+def dummy_context_manager(*args, **kwargs):
     yield "dummy value"

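The relocated mock_s3_bucket_uri helper starts a locally simulated S3 endpoint via simulate_storage, creates a uniquely named bucket, and yields its URI. A minimal sketch (not part of this commit) of how a test could point a run at the mocked bucket; the experiment name, the reuse of train_fn, and the assertion are illustrative:

# Sketch only: drives a small run against the mocked bucket.
from ray import train, tune
from ray.train.tests.test_new_persistence import (
    TestConstants,
    mock_s3_bucket_uri,
    train_fn,
)


def test_storage_path_on_mock_s3():
    with mock_s3_bucket_uri() as storage_uri:
        tuner = tune.Tuner(
            train_fn,
            param_space={"num_iterations": TestConstants.NUM_ITERATIONS},
            run_config=train.RunConfig(
                name="mock_s3_sanity_check",  # hypothetical experiment name
                storage_path=str(storage_uri),  # s3://<generated bucket>
            ),
        )
        results = tuner.fit()
        assert not results.errors
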
@@ -164,16 +192,20 @@ def train_fn(config):

     checkpoint = train.get_checkpoint()
     if checkpoint:
-        with checkpoint.as_directory() as checkpoint_dir:
-            with open(os.path.join(checkpoint_dir, "checkpoint.pkl"), "rb") as f:
-                state = pickle.load(f)
+        custom_restore_fn = config.get("custom_restore_fn")
+        if custom_restore_fn:
+            state = custom_restore_fn(checkpoint)
+        else:
+            with checkpoint.as_directory() as checkpoint_dir:
+                with open(os.path.join(checkpoint_dir, "checkpoint.pkl"), "rb") as f:
+                    state = pickle.load(f)
         print("Loaded back state from checkpoint:", state)
         start = state["iter"] + 1

     for i in range(start, config.get("num_iterations", 5)):
-        time.sleep(0.25)
+        time.sleep(config.get("time_per_iter", 0.25))

-        metrics = {"iter": i, _SCORE_KEY: i}
+        metrics = {"iter": i, TestConstants.SCORE_KEY: i}

         # Save an artifact in the local trial dir.
         rank = train.get_context().get_world_rank()
@@ -199,7 +231,10 @@ def train_fn(config):
             with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f:
                 pickle.dump({"iter": i}, f)

-            train.report(metrics, checkpoint=Checkpoint.from_directory(temp_dir))
+            with config.get("custom_save_fn", dummy_context_manager)(temp_dir):
+                train.report(
+                    metrics, checkpoint=Checkpoint.from_directory(temp_dir)
+                )
             # `train.report` should not have deleted this!
             assert os.path.exists(temp_dir)

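The two new train_loop_config hooks are read straight from config: custom_restore_fn receives the Checkpoint and returns the saved state, while custom_save_fn is a context-manager factory entered around train.report with the local checkpoint directory. A hypothetical pair, shown only to illustrate the call signatures train_fn expects:

# Illustrative helpers; they are not defined anywhere in this diff.
import os
import pickle
from contextlib import contextmanager

from ray.train import Checkpoint


def my_restore_fn(checkpoint: Checkpoint) -> dict:
    # Mirrors the default path: unpack the pickled state written by train_fn.
    with checkpoint.as_directory() as checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "checkpoint.pkl"), "rb") as f:
            return pickle.load(f)


@contextmanager
def my_save_fn(temp_dir: str):
    # Runs around train.report, e.g. to stage extra files into the
    # checkpoint directory before it is uploaded.
    with open(os.path.join(temp_dir, "extra_file.txt"), "w") as f:
        f.write("written by custom_save_fn")
    yield


train_loop_config = {"custom_restore_fn": my_restore_fn, "custom_save_fn": my_save_fn}
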
@@ -260,7 +295,12 @@ def load_checkpoint(self, checkpoint_dict_or_path):
         ).read_text() == "dummy"


-def _resume_from_checkpoint(checkpoint: Checkpoint, expected_state: dict):
+def _resume_from_checkpoint(
+    checkpoint: Checkpoint,
+    expected_state: dict,
+    storage_path: Optional[str] = None,
+    storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
+):
     print(f"\nStarting run with `resume_from_checkpoint`: {checkpoint}\n")

     def assert_fn(config):
@@ -281,7 +321,11 @@ def assert_fn(config):
     trainer = DataParallelTrainer(
         assert_fn,
         scaling_config=train.ScalingConfig(num_workers=2),
-        run_config=train.RunConfig(name="test_resume_from_checkpoint"),
+        run_config=train.RunConfig(
+            name="test_resume_from_checkpoint",
+            storage_path=storage_path,
+            storage_filesystem=storage_filesystem,
+        ),
         resume_from_checkpoint=checkpoint,
     )
     result = trainer.fit()
@@ -291,6 +335,9 @@ def assert_fn(config):
         result.checkpoint.path
     ).name == StorageContext._make_checkpoint_dir_name(0)

+    # Clean up this run's experiment directory immediately after.
+    _delete_fs_path(result.filesystem, Path(result.path).parent.as_posix())
+

 def _assert_storage_contents(
     local_inspect_dir: Path,
@@ -299,7 +346,10 @@
     trainable_name: str,
     test_trainer: bool,
     no_checkpoint_ranks: List[int] = None,
+    constants: type = TestConstants,
 ):
+    no_checkpoint_ranks = no_checkpoint_ranks or []
+
     # Second, inspect the contents of the storage path
     storage_path_ls = list(local_inspect_dir.glob("*"))
     assert len(storage_path_ls) == 1  # Only expect 1 experiment dir
@@ -319,11 +369,13 @@
     assert (
         len(list(exp_dir.glob(f"{trainable_name}*"))) == 1
         if test_trainer
-        else NUM_TRIALS
+        else constants.NUM_TRIALS
     )
     for trial_dir in exp_dir.glob(f"{trainable_name}*"):
         # If set, expect num_to_keep. Otherwise, expect to see all of them.
-        expected_num_checkpoints = checkpoint_config.num_to_keep or NUM_ITERATIONS
+        expected_num_checkpoints = (
+            checkpoint_config.num_to_keep or constants.NUM_ITERATIONS
+        )

         assert len(list(trial_dir.glob("checkpoint_*"))) == expected_num_checkpoints
         checkpoint_idxs = sorted(
@@ -335,7 +387,10 @@
         # Ex: If num_to_keep=2 out of 6 total checkpoints,
         # expect checkpoint_004 and checkpoint_005.
         assert checkpoint_idxs == list(
-            range(NUM_ITERATIONS - expected_num_checkpoints, NUM_ITERATIONS)
+            range(
+                constants.NUM_ITERATIONS - expected_num_checkpoints,
+                constants.NUM_ITERATIONS,
+            )
         )

         for checkpoint_dir in trial_dir.glob("checkpoint_*"):
@@ -353,12 +408,16 @@
                 for checkpoint_shard in checkpoint_dir.glob(
                     "checkpoint_shard-*.pkl"
                 )
-            } == {i for i in range(NUM_WORKERS) if i not in no_checkpoint_ranks}
+            } == {
+                i
+                for i in range(constants.NUM_WORKERS)
+                if i not in no_checkpoint_ranks
+            }

         if test_trainer:
-            expected_num_artifacts = NUM_ITERATIONS * NUM_WORKERS
+            expected_num_artifacts = constants.NUM_ITERATIONS * constants.NUM_WORKERS
         else:
-            expected_num_artifacts = NUM_ITERATIONS
+            expected_num_artifacts = constants.NUM_ITERATIONS
         assert len(list(trial_dir.glob("artifact-*"))) == expected_num_artifacts

         # NOTE: This result file is synced by the driver.
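Threading constants: type = TestConstants through _assert_storage_contents lets callers outside this file reuse the same storage assertions with different sizing, which is presumably how the multi-node release test configured later in this commit (its test file is not shown in this view) drives them. A minimal hypothetical override:

# Hypothetical subclass; the value is illustrative and not taken from the diff.
from ray.train.tests.test_new_persistence import TestConstants


class ReleaseTestConstants(TestConstants):
    NUM_WORKERS = 4  # e.g. one worker per node in a 4-node cluster


# A caller would then pass constants=ReleaseTestConstants (plus that run's own
# no_checkpoint_ranks) into _assert_storage_contents.
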
@@ -419,7 +478,7 @@ def test_tuner(
     tuner = tune.Tuner(
         trainable,
         param_space={
-            "num_iterations": NUM_ITERATIONS,
+            "num_iterations": TestConstants.NUM_ITERATIONS,
             "fail_iters": [2, 4],
             # NOTE: This param is only used in the ClassTrainable.
             "save_checkpoint_as_dict": tune.grid_search([True, False]),
@@ -464,7 +523,7 @@ def test_tuner(
     experiment_fs_path = result_grid.experiment_path
     assert isinstance(result_grid.filesystem, pyarrow.fs.FileSystem), result_grid
     assert experiment_fs_path == os.path.join(storage_fs_path, exp_name)
-    assert len(result_grid) == NUM_TRIALS
+    assert len(result_grid) == TestConstants.NUM_TRIALS
     for result in result_grid:
         trial_fs_path = result.path
         assert isinstance(result.filesystem, pyarrow.fs.FileSystem), result
@@ -489,7 +548,7 @@ def test_tuner(
         train.CheckpointConfig(),
         train.CheckpointConfig(
             num_to_keep=1,
-            checkpoint_score_attribute=_SCORE_KEY,
+            checkpoint_score_attribute=TestConstants.SCORE_KEY,
             checkpoint_score_order="max",
         ),
     ],
@@ -538,14 +597,14 @@ def test_trainer(
         train_fn,
         train_loop_config={
             "in_trainer": True,
-            "num_iterations": NUM_ITERATIONS,
+            "num_iterations": TestConstants.NUM_ITERATIONS,
             "fail_iters": [2, 4],
             # TODO(justinvyu): This should be separated into its own test once
             # CI has been fully migrated.
             # Test that global rank 0 is not required to checkpoint.
             "no_checkpoint_ranks": no_checkpoint_ranks,
         },
-        scaling_config=train.ScalingConfig(num_workers=NUM_WORKERS),
+        scaling_config=train.ScalingConfig(num_workers=TestConstants.NUM_WORKERS),
         run_config=train.RunConfig(
             storage_path=storage_path,
             storage_filesystem=storage_filesystem,
@@ -574,7 +633,8 @@ def test_trainer(
         "RAY_AIR_LOCAL_CACHE_DIR", str(tmp_path / "resume_from_checkpoint")
     )
     _resume_from_checkpoint(
-        result.checkpoint, expected_state={"iter": NUM_ITERATIONS - 1}
+        result.checkpoint,
+        expected_state={"iter": TestConstants.NUM_ITERATIONS - 1},
     )

     local_inspect_dir, storage_fs_path = _get_local_inspect_dir(

python/ray/train/tests/util.py

-26
@@ -1,15 +1,11 @@
 import contextlib
-import logging
 import os
 import tempfile
-from contextlib import contextmanager
 from typing import Any, Dict, Type

 import ray.cloudpickle as ray_pickle
 from ray.train import Checkpoint
 from ray.train._internal.storage import StorageContext
-from ray._private.test_utils import simulate_storage
-from ray.air._internal.uri_utils import URI


 @contextlib.contextmanager
@@ -41,25 +37,3 @@ def mock_storage_context() -> StorageContext:
     storage.storage_local_path = storage_path
     os.makedirs(os.path.join(storage_path, exp_name, trial_name), exist_ok=True)
     return storage
-
-
-@contextmanager
-def mock_s3_bucket_uri():
-    port = 5002
-    region = "us-west-2"
-    with simulate_storage("s3", port=port, region=region) as s3_uri:
-        import boto3
-
-        s3 = boto3.client(
-            "s3", region_name=region, endpoint_url=f"http://localhost:{port}"
-        )
-        # Bucket name will be autogenerated/unique per test
-        bucket_name = URI(s3_uri).name
-        s3.create_bucket(
-            Bucket=bucket_name,
-            CreateBucketConfiguration={"LocationConstraint": region},
-        )
-        # Disable server HTTP request logging
-        logging.getLogger("werkzeug").setLevel(logging.WARNING)
-        yield URI(s3_uri)
-        logging.getLogger("werkzeug").setLevel(logging.INFO)

python/ray/tune/tests/test_experiment_analysis.py

+2 −5
@@ -17,11 +17,8 @@
 from ray.tune.experiment import Trial
 from ray.tune.utils import flatten_dict

-from ray.train.tests.util import (
-    create_dict_checkpoint,
-    load_dict_checkpoint,
-    mock_s3_bucket_uri,
-)
+from ray.train.tests.util import create_dict_checkpoint, load_dict_checkpoint
+from ray.train.tests.test_new_persistence import mock_s3_bucket_uri


 NUM_TRIALS = 3
+7
@@ -0,0 +1,7 @@
+#!/bin/bash
+# This script is used to build an extra layer on top of the base anyscale/ray image
+# to run the train_multinode_persistence test.
+
+set -exo pipefail
+
+pip3 install -U torch fsspec s3fs gcsfs pyarrow>=9.0.0 pytest

release/release_tests.yaml

+31
@@ -3383,6 +3383,37 @@

   alert: default

+
+- name: train_multinode_persistence
+  group: Train tests
+  working_dir: train_tests/e2e
+
+  frequency: nightly
+  team: ml
+
+  cluster:
+    byod:
+      post_build_script: byod_train_persistence_test.sh
+    cluster_compute: compute_aws.yaml
+
+  run:
+    timeout: 3000
+    script: pytest -v test_persistence.py -s
+
+    wait_for_nodes:
+      num_nodes: 4
+
+  variations:
+    - __suffix__: aws
+    - __suffix__: gce
+      env: gce
+      frequency: manual
+      cluster:
+        cluster_compute: compute_gce.yaml
+
+  alert: default
+
+
 ########################
 # Alpa tests
 ########################
+22
@@ -0,0 +1,22 @@
+cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
+region: us-west-2
+
+max_workers: 3
+
+head_node_type:
+  name: head_node
+  instance_type: m5.2xlarge
+
+worker_node_types:
+  - name: worker_node
+    instance_type: m5.2xlarge
+    max_workers: 3
+    min_workers: 3
+    use_spot: false
+
+aws:
+  TagSpecifications:
+    - ResourceType: "instance"
+      Tags:
+        - Key: ttl-hours
+          Value: '24'
+17
@@ -0,0 +1,17 @@
+cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
+region: us-west1
+allowed_azs:
+  - us-west1-b
+
+max_workers: 3
+
+head_node_type:
+  name: head_node
+  instance_type: n2-standard-8
+
+worker_node_types:
+  - name: worker_node
+    instance_type: n2-standard-8
+    max_workers: 3
+    min_workers: 3
+    use_spot: false
