from contextlib import contextmanager
+ import logging
import os
from pathlib import Path
import pickle
...
import ray
from ray import train, tune
+ from ray._private.test_utils import simulate_storage
from ray.air._internal.uri_utils import URI
from ray.air.constants import EXPR_RESULT_FILE
- from ray.train._internal.storage import _download_from_fs_path, StorageContext
+ from ray.train._internal.storage import (
+     _delete_fs_path,
+     _download_from_fs_path,
+     StorageContext,
+ )
from ray.train._checkpoint import Checkpoint
from ray.train.base_trainer import TrainingFailedError
from ray.train.constants import RAY_AIR_NEW_PERSISTENCE_MODE
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.tune.trainable.trainable import _DICT_CHECKPOINT_FILE_NAME

- from ray.train.tests.util import mock_s3_bucket_uri

+ class TestConstants:
+     NUM_ITERATIONS = 6  # == num_checkpoints == num_artifacts
+     NUM_TRIALS = 2
+     NUM_WORKERS = 3

- _SCORE_KEY = "score"
- NUM_ITERATIONS = 6  # == num_checkpoints == num_artifacts
- NUM_TRIALS = 2
- NUM_WORKERS = 3
+     SCORE_KEY = "score"


@contextmanager
- def dummy_context_manager():
+ def mock_s3_bucket_uri():
+     port = 5002
+     region = "us-west-2"
+     with simulate_storage("s3", port=port, region=region) as s3_uri:
+         import boto3
+
+         s3 = boto3.client(
+             "s3", region_name=region, endpoint_url=f"http://localhost:{port}"
+         )
+         # Bucket name will be autogenerated/unique per test
+         bucket_name = URI(s3_uri).name
+         s3.create_bucket(
+             Bucket=bucket_name,
+             CreateBucketConfiguration={"LocationConstraint": region},
+         )
+         # Disable server HTTP request logging
+         logging.getLogger("werkzeug").setLevel(logging.WARNING)
+         yield URI(s3_uri)
+         logging.getLogger("werkzeug").setLevel(logging.INFO)
+
+
+ @contextmanager
+ def dummy_context_manager(*args, **kwargs):
    yield "dummy value"

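The two helpers above replace the previously imported fixture: `mock_s3_bucket_uri` starts a local mock S3 endpoint via `simulate_storage` and yields the bucket URI, while `dummy_context_manager` is the no-op default for the save hook used later. The sketch below is illustrative only and not part of this change; the function name `example_mock_s3_run` and the exact `RunConfig` arguments are assumptions about how a test might combine these pieces.

# Hypothetical usage sketch: run the shared train_fn against the mocked bucket.
def example_mock_s3_run():
    with mock_s3_bucket_uri() as storage_uri:
        trainer = DataParallelTrainer(
            train_fn,
            train_loop_config={"num_iterations": TestConstants.NUM_ITERATIONS},
            scaling_config=train.ScalingConfig(num_workers=TestConstants.NUM_WORKERS),
            run_config=train.RunConfig(
                name="example_mock_s3_run", storage_path=str(storage_uri)
            ),
        )
        trainer.fit()
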
@@ -164,16 +192,20 @@ def train_fn(config):

    checkpoint = train.get_checkpoint()
    if checkpoint:
-        with checkpoint.as_directory() as checkpoint_dir:
-            with open(os.path.join(checkpoint_dir, "checkpoint.pkl"), "rb") as f:
-                state = pickle.load(f)
+        custom_restore_fn = config.get("custom_restore_fn")
+        if custom_restore_fn:
+            state = custom_restore_fn(checkpoint)
+        else:
+            with checkpoint.as_directory() as checkpoint_dir:
+                with open(os.path.join(checkpoint_dir, "checkpoint.pkl"), "rb") as f:
+                    state = pickle.load(f)
        print("Loaded back state from checkpoint:", state)
        start = state["iter"] + 1

    for i in range(start, config.get("num_iterations", 5)):
-        time.sleep(0.25)
+        time.sleep(config.get("time_per_iter", 0.25))

-        metrics = {"iter": i, _SCORE_KEY: i}
+        metrics = {"iter": i, TestConstants.SCORE_KEY: i}

        # Save an artifact in the local trial dir.
        rank = train.get_context().get_world_rank()
@@ -199,7 +231,10 @@ def train_fn(config):
            with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f:
                pickle.dump({"iter": i}, f)

-            train.report(metrics, checkpoint=Checkpoint.from_directory(temp_dir))
+            with config.get("custom_save_fn", dummy_context_manager)(temp_dir):
+                train.report(
+                    metrics, checkpoint=Checkpoint.from_directory(temp_dir)
+                )
            # `train.report` should not have deleted this!
            assert os.path.exists(temp_dir)

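As the `with config.get("custom_save_fn", dummy_context_manager)(temp_dir):` line shows, a `custom_save_fn` is any callable that takes the checkpoint temp directory and returns a context manager wrapping the `train.report` call. A minimal sketch of such a hook, assuming it only needs to observe the directory (the name `example_custom_save_fn` is hypothetical, not part of this change):

# Hypothetical save hook: check the checkpoint dir survives train.report().
@contextmanager
def example_custom_save_fn(temp_dir):
    assert os.path.exists(temp_dir)  # before reporting
    yield
    assert os.path.exists(temp_dir)  # train.report() must not delete it
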
@@ -260,7 +295,12 @@ def load_checkpoint(self, checkpoint_dict_or_path):
        ).read_text() == "dummy"


- def _resume_from_checkpoint(checkpoint: Checkpoint, expected_state: dict):
+ def _resume_from_checkpoint(
+     checkpoint: Checkpoint,
+     expected_state: dict,
+     storage_path: Optional[str] = None,
+     storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
+ ):
    print(f"\nStarting run with `resume_from_checkpoint`: {checkpoint}\n")

    def assert_fn(config):
@@ -281,7 +321,11 @@ def assert_fn(config):
    trainer = DataParallelTrainer(
        assert_fn,
        scaling_config=train.ScalingConfig(num_workers=2),
-        run_config=train.RunConfig(name="test_resume_from_checkpoint"),
+        run_config=train.RunConfig(
+            name="test_resume_from_checkpoint",
+            storage_path=storage_path,
+            storage_filesystem=storage_filesystem,
+        ),
        resume_from_checkpoint=checkpoint,
    )
    result = trainer.fit()
@@ -291,6 +335,9 @@ def assert_fn(config):
        result.checkpoint.path
    ).name == StorageContext._make_checkpoint_dir_name(0)

+    # Clean up this run's experiment directory immediately after.
+    _delete_fs_path(result.filesystem, Path(result.path).parent.as_posix())
+

def _assert_storage_contents(
    local_inspect_dir: Path,
@@ -299,7 +346,10 @@ def _assert_storage_contents(
    trainable_name: str,
    test_trainer: bool,
    no_checkpoint_ranks: List[int] = None,
+    constants: type = TestConstants,
):
+    no_checkpoint_ranks = no_checkpoint_ranks or []
+
    # Second, inspect the contents of the storage path
    storage_path_ls = list(local_inspect_dir.glob("*"))
    assert len(storage_path_ls) == 1  # Only expect 1 experiment dir
@@ -319,11 +369,13 @@ def _assert_storage_contents(
    assert (
        len(list(exp_dir.glob(f"{trainable_name}*"))) == 1
        if test_trainer
-        else NUM_TRIALS
+        else constants.NUM_TRIALS
    )
    for trial_dir in exp_dir.glob(f"{trainable_name}*"):
        # If set, expect num_to_keep. Otherwise, expect to see all of them.
-        expected_num_checkpoints = checkpoint_config.num_to_keep or NUM_ITERATIONS
+        expected_num_checkpoints = (
+            checkpoint_config.num_to_keep or constants.NUM_ITERATIONS
+        )

        assert len(list(trial_dir.glob("checkpoint_*"))) == expected_num_checkpoints
        checkpoint_idxs = sorted(
@@ -335,7 +387,10 @@ def _assert_storage_contents(
        # Ex: If num_to_keep=2 out of 6 total checkpoints,
        # expect checkpoint_004 and checkpoint_005.
        assert checkpoint_idxs == list(
-            range(NUM_ITERATIONS - expected_num_checkpoints, NUM_ITERATIONS)
+            range(
+                constants.NUM_ITERATIONS - expected_num_checkpoints,
+                constants.NUM_ITERATIONS,
+            )
        )

        for checkpoint_dir in trial_dir.glob("checkpoint_*"):
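For reference, a worked instance of the retention arithmetic asserted above (illustrative only, not part of the change): with `num_to_keep=2` and `TestConstants.NUM_ITERATIONS = 6`, the surviving indices are `range(6 - 2, 6)`, i.e. the two newest checkpoints, indices 4 and 5.

# Sanity check of the example in the comment above (illustrative only).
assert list(range(TestConstants.NUM_ITERATIONS - 2, TestConstants.NUM_ITERATIONS)) == [4, 5]
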
@@ -353,12 +408,16 @@ def _assert_storage_contents(
                for checkpoint_shard in checkpoint_dir.glob(
                    "checkpoint_shard-*.pkl"
                )
-            } == {i for i in range(NUM_WORKERS) if i not in no_checkpoint_ranks}
+            } == {
+                i
+                for i in range(constants.NUM_WORKERS)
+                if i not in no_checkpoint_ranks
+            }

        if test_trainer:
-            expected_num_artifacts = NUM_ITERATIONS * NUM_WORKERS
+            expected_num_artifacts = constants.NUM_ITERATIONS * constants.NUM_WORKERS
        else:
-            expected_num_artifacts = NUM_ITERATIONS
+            expected_num_artifacts = constants.NUM_ITERATIONS
        assert len(list(trial_dir.glob("artifact-*"))) == expected_num_artifacts

        # NOTE: This result file is synced by the driver.
@@ -419,7 +478,7 @@ def test_tuner(
    tuner = tune.Tuner(
        trainable,
        param_space={
-            "num_iterations": NUM_ITERATIONS,
+            "num_iterations": TestConstants.NUM_ITERATIONS,
            "fail_iters": [2, 4],
            # NOTE: This param is only used in the ClassTrainable.
            "save_checkpoint_as_dict": tune.grid_search([True, False]),
@@ -464,7 +523,7 @@ def test_tuner(
    experiment_fs_path = result_grid.experiment_path
    assert isinstance(result_grid.filesystem, pyarrow.fs.FileSystem), result_grid
    assert experiment_fs_path == os.path.join(storage_fs_path, exp_name)
-    assert len(result_grid) == NUM_TRIALS
+    assert len(result_grid) == TestConstants.NUM_TRIALS
    for result in result_grid:
        trial_fs_path = result.path
        assert isinstance(result.filesystem, pyarrow.fs.FileSystem), result
@@ -489,7 +548,7 @@ def test_tuner(
        train.CheckpointConfig(),
        train.CheckpointConfig(
            num_to_keep=1,
-            checkpoint_score_attribute=_SCORE_KEY,
+            checkpoint_score_attribute=TestConstants.SCORE_KEY,
            checkpoint_score_order="max",
        ),
    ],
@@ -538,14 +597,14 @@ def test_trainer(
        train_fn,
        train_loop_config={
            "in_trainer": True,
-            "num_iterations": NUM_ITERATIONS,
+            "num_iterations": TestConstants.NUM_ITERATIONS,
            "fail_iters": [2, 4],
            # TODO(justinvyu): This should be separated into its own test once
            # CI has been fully migrated.
            # Test that global rank 0 is not required to checkpoint.
            "no_checkpoint_ranks": no_checkpoint_ranks,
        },
-        scaling_config=train.ScalingConfig(num_workers=NUM_WORKERS),
+        scaling_config=train.ScalingConfig(num_workers=TestConstants.NUM_WORKERS),
        run_config=train.RunConfig(
            storage_path=storage_path,
            storage_filesystem=storage_filesystem,
@@ -574,7 +633,8 @@ def test_trainer(
        "RAY_AIR_LOCAL_CACHE_DIR", str(tmp_path / "resume_from_checkpoint")
    )
    _resume_from_checkpoint(
-        result.checkpoint, expected_state={"iter": NUM_ITERATIONS - 1}
+        result.checkpoint,
+        expected_state={"iter": TestConstants.NUM_ITERATIONS - 1},
    )

    local_inspect_dir, storage_fs_path = _get_local_inspect_dir(