Skip to content

Commit a3bf5e2

Browse files
authored (author name lost in page extraction)
Minor changes to Checkpointer (apple#1024)
1 parent 55e1841 commit a3bf5e2

File tree

3 files changed

+26
-10
lines changed

3 files changed

+26
-10
lines changed

axlearn/common/array_serialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ async def _run_deserializer():
459459
return fut.result()
460460

461461

462-
class BoundedDataShardedAsyncCheckpointManager(serialization.GlobalAsyncCheckpointManager):
462+
class BoundedDataShardedAsyncCheckpointManager(GlobalAsyncCheckpointManager):
463463
"""Similar to GlobalAsyncCheckpointManager but with few improvements:
464464
465465
1. Writing to tensorstore requires no host-to-host copy most of the time. This reduces host

axlearn/common/checkpointer.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,9 @@ def restore_from_dir(
560560
)
561561
return self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec)
562562

563-
def _restore_tensorstore_state(self, state, *, ckpt_dir: str, spec: CheckpointSpec):
563+
def _restore_tensorstore_state(
564+
self, state, *, ckpt_dir: str, spec: CheckpointSpec, sync: bool = True
565+
):
564566
restored_gda_values = self._manager.deserialize(
565567
shardings=spec.shardings,
566568
tensorstore_specs=spec.tensorstore_specs,
@@ -584,7 +586,8 @@ def _restore_tensorstore_state(self, state, *, ckpt_dir: str, spec: CheckpointSp
584586
restored_state = jax.tree_util.tree_unflatten(
585587
jax.tree_util.tree_structure(state), state_leaves
586588
)
587-
multihost_utils.sync_global_devices(ckpt_dir)
589+
if sync:
590+
multihost_utils.sync_global_devices(ckpt_dir)
588591
return restored_state
589592

590593
def stop(self):
@@ -906,7 +909,11 @@ class Config(BaseCheckpointer.Config):
906909
def _all_checkpoint_paths(cls, base_dir: str) -> list[str]:
907910
"""Like `checkpoint_paths`, but also include non-committed checkpoints."""
908911
try:
909-
return [path for path in fs.listdir(base_dir) if path.startswith(STEP_PREFIX)]
912+
return [
913+
os.path.join(base_dir, path.rstrip("/"))
914+
for path in fs.listdir(base_dir)
915+
if path.startswith(STEP_PREFIX)
916+
]
910917
except fs.NotFoundError:
911918
return []
912919

@@ -918,7 +925,7 @@ def checkpoint_paths(cls, base_dir: str) -> list[str]:
918925
# gcs when there are many checkpoint files, even if using a "native" solution like
919926
# `google-cloud-python` SDK.
920927
paths = cls._all_checkpoint_paths(base_dir)
921-
paths = [os.path.join(base_dir, path, "index") for path in paths]
928+
paths = [os.path.join(path, "index") for path in paths]
922929
with futures.ThreadPoolExecutor() as pool:
923930
index_exists = pool.map(fs.exists, paths)
924931
return [os.path.dirname(path) for path, committed in zip(paths, index_exists) if committed]
@@ -1042,12 +1049,12 @@ def _run_garbage_collection(self):
10421049
remaining_dirs, gc_dirs = [], []
10431050

10441051
try:
1045-
step_dirs = [step.rstrip("/") for step in self._all_checkpoint_paths(cfg.dir)]
1052+
step_dirs = self._all_checkpoint_paths(cfg.dir)
10461053
except fs.NotFoundError:
10471054
step_dirs = []
10481055

10491056
# Gather all candidate checkpoint dirs, as well as all committed checkpoint dirs.
1050-
dirs = sorted([os.path.join(cfg.dir, step) for step in step_dirs], reverse=True)
1057+
dirs = sorted(step_dirs, reverse=True)
10511058
committed_dirs = set(self.checkpoint_paths(cfg.dir))
10521059

10531060
# Collect the recent non-committed checkpoints, since any of them could be in-progress.

axlearn/common/checkpointer_test.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,8 +1137,13 @@ def make_state(float_dtype):
11371137
),
11381138
)
11391139

1140-
@parameterized.parameters(jnp.float32, jnp.bfloat16, jnp.int32, jnp.int16)
1141-
def test_save_and_restore_from_dir_async(self, restore_floats_as: jnp.dtype):
1140+
@parameterized.product(
1141+
restore_floats_as=[jnp.float32, jnp.bfloat16, jnp.int32, jnp.int16],
1142+
max_concurrent_gb=[None, 1],
1143+
)
1144+
def test_save_and_restore_from_dir_async(
1145+
self, restore_floats_as: jnp.dtype, max_concurrent_gb: Optional[int]
1146+
):
11421147
mesh_shape = (1, 1)
11431148
if not test_utils.is_supported_mesh_shape(mesh_shape):
11441149
return
@@ -1148,7 +1153,11 @@ def make_state(float_dtype):
11481153

11491154
with _mesh(mesh_shape):
11501155
state = make_state(float_dtype=jnp.float32)
1151-
storage = TensorStoreStateStorage.default_config().instantiate()
1156+
storage = (
1157+
TensorStoreStateStorage.default_config()
1158+
.set(max_concurrent_gb=max_concurrent_gb)
1159+
.instantiate()
1160+
)
11521161
with tempfile.TemporaryDirectory() as root_dir:
11531162
step = 1000
11541163
# Save ckpt.

0 commit comments

Comments (0)