@@ -560,7 +560,9 @@ def restore_from_dir(
560
560
)
561
561
return self ._restore_tensorstore_state (state , ckpt_dir = ckpt_dir , spec = spec )
562
562
563
- def _restore_tensorstore_state (self , state , * , ckpt_dir : str , spec : CheckpointSpec ):
563
+ def _restore_tensorstore_state (
564
+ self , state , * , ckpt_dir : str , spec : CheckpointSpec , sync : bool = True
565
+ ):
564
566
restored_gda_values = self ._manager .deserialize (
565
567
shardings = spec .shardings ,
566
568
tensorstore_specs = spec .tensorstore_specs ,
@@ -584,7 +586,8 @@ def _restore_tensorstore_state(self, state, *, ckpt_dir: str, spec: CheckpointSp
584
586
restored_state = jax .tree_util .tree_unflatten (
585
587
jax .tree_util .tree_structure (state ), state_leaves
586
588
)
587
- multihost_utils .sync_global_devices (ckpt_dir )
589
+ if sync :
590
+ multihost_utils .sync_global_devices (ckpt_dir )
588
591
return restored_state
589
592
590
593
def stop (self ):
@@ -906,7 +909,11 @@ class Config(BaseCheckpointer.Config):
906
909
def _all_checkpoint_paths (cls , base_dir : str ) -> list [str ]:
907
910
"""Like `checkpoint_paths`, but also include non-committed checkpoints."""
908
911
try :
909
- return [path for path in fs .listdir (base_dir ) if path .startswith (STEP_PREFIX )]
912
+ return [
913
+ os .path .join (base_dir , path .rstrip ("/" ))
914
+ for path in fs .listdir (base_dir )
915
+ if path .startswith (STEP_PREFIX )
916
+ ]
910
917
except fs .NotFoundError :
911
918
return []
912
919
@@ -918,7 +925,7 @@ def checkpoint_paths(cls, base_dir: str) -> list[str]:
918
925
# gcs when there are many checkpoint files, even if using a "native" solution like
919
926
# `google-cloud-python` SDK.
920
927
paths = cls ._all_checkpoint_paths (base_dir )
921
- paths = [os .path .join (base_dir , path , "index" ) for path in paths ]
928
+ paths = [os .path .join (path , "index" ) for path in paths ]
922
929
with futures .ThreadPoolExecutor () as pool :
923
930
index_exists = pool .map (fs .exists , paths )
924
931
return [os .path .dirname (path ) for path , committed in zip (paths , index_exists ) if committed ]
@@ -1042,12 +1049,12 @@ def _run_garbage_collection(self):
1042
1049
remaining_dirs , gc_dirs = [], []
1043
1050
1044
1051
try :
1045
- step_dirs = [ step . rstrip ( "/" ) for step in self ._all_checkpoint_paths (cfg .dir )]
1052
+ step_dirs = self ._all_checkpoint_paths (cfg .dir )
1046
1053
except fs .NotFoundError :
1047
1054
step_dirs = []
1048
1055
1049
1056
# Gather all candidate checkpoint dirs, as well as all committed checkpoint dirs.
1050
- dirs = sorted ([ os . path . join ( cfg . dir , step ) for step in step_dirs ] , reverse = True )
1057
+ dirs = sorted (step_dirs , reverse = True )
1051
1058
committed_dirs = set (self .checkpoint_paths (cfg .dir ))
1052
1059
1053
1060
# Collect the recent non-committed checkpoints, since any of them could be in-progress.
0 commit comments