Implements gRPC Checkpointer #1005

Merged: 1 commit, Feb 19, 2025
67 changes: 59 additions & 8 deletions axlearn/common/array_serialization.py
@@ -16,7 +16,6 @@
https://github.com/google/orbax/blob/3cc343c63c769e4b2df44f3e57f6b5b43569df32/checkpoint/orbax/checkpoint/serialization.py
https://github.com/google/jax/blob/595a620804e810335a870e93975a78504b2e95e5/jax/experimental/array_serialization/serialization.py
"""

import asyncio
import functools
import threading
@@ -52,6 +51,21 @@ class _ShardInfo:
slice_arg: Optional[tuple[int, int, int]]
replica_count: int

def shard_coordinate(self):
"""Gets the shard coordinate according to the zarr format used by tensorstore."""
coords = []
for s in self.index:
if s.start is None:
coords.append(0)
continue
size = s.stop - s.start
assert s.start % size == 0
coords.append(s.start // size)
# Special case for scalar.
if len(coords) == 0:
return "0"
return ".".join(str(x) for x in coords)


# Tuple (and thus hashable) representation of a slice object (start, end, step).
_SliceTuple = tuple[Optional[int], Optional[int], Optional[int]]
@@ -145,7 +159,7 @@ def _transfer_to_host(data: Tensor) -> Tensor:
return data


async def _slice_shard_and_copy_to_host(shard_infos: list[_ShardInfo], d2h_future: futures.Future):
async def _slice_shard_and_copy_to_host(shard_infos: list[_ShardInfo]):
"""Slices each shard according to shard info and then copy the sliced result to host.

The .data field of each shard_info is modified in-place.
@@ -162,9 +176,6 @@ async def _slice_shard_and_copy_to_host(shard_infos: list[_ShardInfo], d2h_future: futures.Future):
# against consumers like tensorstore that would otherwise copy silently.
info.data = np.array(data, copy=False)

d2h_future.set_result(shard_infos)
await asyncio.sleep(0) # Allow other D2Hs to set result.


def _slice_fn(info: _ShardInfo) -> Tensor:
"""Performs slicing according to a shard_info and returns the sliced array."""
@@ -183,16 +194,22 @@ def _fix_metadata(tspec: dict[str, Any], shard_infos: list[_ShardInfo]):
"""Revises the medadata of a tensorspec based on `shard_infos`."""
if len(shard_infos) != 0:
# All shards have the same shape after data-sharding, so using [0] is sufficient.
tspec["chunks"] = np.array(np.maximum(1, shard_infos[0].data.shape))
tspec["chunks"] = tuple(int(x) for x in np.maximum(1, shard_infos[0].data.shape))
return tspec


class TensorstoreSpecModifier:
def __call__(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]):
...


async def _async_serialize(
arr_inp: Tensor,
tensorstore_spec: dict[str, Any],
d2h_future: futures.Future,
*,
limiter: Optional[serialization._LimitInFlightBytes] = None,
limiter: Optional[serialization._LimitInFlightBytes],
tensorstore_spec_modifier: Optional[TensorstoreSpecModifier] = None,
max_data_shard_degree: int,
shard_threshold_bytes: int,
):
@@ -236,9 +253,15 @@ async def _async_serialize(
tensorstore_spec["dtype"] = jax.numpy.dtype(arr_inp.dtype).name

# Original `arr_inp` might be deleted after this point.
await _slice_shard_and_copy_to_host(shard_infos, d2h_future)
await _slice_shard_and_copy_to_host(shard_infos)
# Fix metadata after slicing to get the right shape.
_fix_metadata(tensorstore_spec["metadata"], shard_infos)
if tensorstore_spec_modifier is not None:
tensorstore_spec_modifier(tensorstore_spec, shard_infos=shard_infos)

# Set the future after the tensorstore spec has been updated.
d2h_future.set_result(shard_infos)
await asyncio.sleep(0) # Allow other D2Hs to set result.

# `ts.open` runs twice for process 0 because for the first time, we just get the future to be
# awaited upon in the background thread. The second one runs with `assume_metadata=True` which
@@ -278,6 +301,7 @@ async def _run_serializer(
d2h_futures: list[futures.Future],
*,
max_concurrent_bytes: Optional[int] = None,
tensorstore_spec_modifier: Optional[TensorstoreSpecModifier] = None,
max_data_shard_degree: int,
shard_threshold_bytes: int,
):
@@ -296,6 +320,7 @@
limiter=limiter,
max_data_shard_degree=max_data_shard_degree,
shard_threshold_bytes=shard_threshold_bytes,
tensorstore_spec_modifier=tensorstore_spec_modifier,
),
arrays,
tensorstore_specs,
@@ -435,6 +460,23 @@ def __init__(
raise NotImplementedError("max_data_shard_degree cannot be 0.")
self._shard_threshold_bytes = shard_threshold_bytes or 0

def _tensorstore_spec_modifier(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]):
"""A function that modifies the tensorstore spec for an array in-place.

This function will be called after tensorstore metadata is populated and the shard infos
for the array are computed.
"""
del spec, shard_infos

def _tensorstore_spec_log_fn(self, specs: list[dict[str, Any]]):
"""A function that will be called **once** after the tensorstore specs are populated.

Specifically, this function will be called **once** during the first checkpoint after
`self._tensorstore_spec_modifier` is invoked for each array. `specs` is a list of specs
corresponding to the `arrays` argument in `self.serialize`.
"""
del specs

def serialize(
self,
arrays: list[Tensor],
@@ -478,6 +520,7 @@ def serialize(
arrays,
tensorstore_specs,
d2h_futures,
tensorstore_spec_modifier=self._tensorstore_spec_modifier, # type: ignore
max_concurrent_bytes=max_concurrent_bytes,
max_data_shard_degree=self._max_data_shard_degree,
shard_threshold_bytes=self._shard_threshold_bytes,
@@ -501,6 +544,14 @@ def serialize(
str(spec),
)
self._logged_spec = True
self._tensorstore_spec_log_fn(tensorstore_specs)

logging.info("D2H during save took %fs. Starting async commit.", time.time() - start_t)
self._start_async_commit(on_commit_callback)

def stop(self):
"""Disposes and cleanup any internal resources."""

def __del__(self):
super().__del__()
self.stop()
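
The two hooks added above, `_tensorstore_spec_modifier` and `_tensorstore_spec_log_fn`, are intentionally no-ops on `BoundedDataShardedAsyncCheckpointManager`: they are override points for subclasses that need to rewrite or inspect each array's tensorstore spec before the write begins. A minimal sketch of such a subclass is shown below; the class name, the gRPC kvstore driver, and its fields are illustrative assumptions and are not part of this diff.

```python
import logging
from typing import Any

from axlearn.common.array_serialization import (
    BoundedDataShardedAsyncCheckpointManager,
    _ShardInfo,
)


class GRPCAsyncCheckpointManager(BoundedDataShardedAsyncCheckpointManager):
    """Hypothetical subclass that redirects array writes to a gRPC-backed kvstore."""

    def __init__(self, *args, target: str = "localhost:50051", **kwargs):
        super().__init__(*args, **kwargs)
        self._target = target

    def _tensorstore_spec_modifier(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]):
        # Called once per array, after metadata is populated and shard infos are computed.
        # The driver name and its fields below are placeholders, not a real tensorstore spec.
        spec["kvstore"] = {"driver": "grpc_kvstore", "address": self._target}

    def _tensorstore_spec_log_fn(self, specs: list[dict[str, Any]]):
        # Called once, during the first checkpoint, with the fully populated specs.
        logging.info("Serializing %d arrays via %s", len(specs), self._target)
```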
31 changes: 27 additions & 4 deletions axlearn/common/array_serialization_test.py
@@ -256,13 +256,12 @@ async def write(self, data: jax.Array, **_):
async def open_patch(*_, **__):
return FakeTs()

async def _copy_to_host_patch(shard_infos: list[_ShardInfo], d2h_future: futures.Future):
async def _copy_to_host_patch(shard_infos: list[_ShardInfo]):
nonlocal concurrent_bytes
for info in shard_infos:
concurrent_bytes += info.data.nbytes
# In-flight bytes should not exceed the expected max bytes.
self.assertLessEqual(concurrent_bytes, expect_max_concurrent_bytes)
d2h_future.set_result(shard_infos)

manager = BoundedDataShardedAsyncCheckpointManager(max_concurrent_gb=max_concurrent_gb)
with (
@@ -328,8 +327,7 @@ def _verify_shard_info(
# single device array. If same, that means all shards should cover all
# indices of the original array.
out_array = np.empty_like(single_device_arr)
d2h_future = futures.Future()
asyncio.run(_slice_shard_and_copy_to_host(shard_infos, d2h_future))
asyncio.run(_slice_shard_and_copy_to_host(shard_infos))
for info in shard_infos:
out_array[info.index] = info.data
self.assertTrue(np.all(out_array == np.array(single_device_arr)))
@@ -404,3 +402,28 @@ def test_shard_info_fully_replicated(
self._verify_shard_info(
single_device_arr, arr, max_data_shard_degree, shard_threshold_bytes
)

@parameterized.parameters(
dict(
index=(slice(2, 4, None), slice(None, None, None)),
expected="1.0",
),
dict(
index=(slice(None, None, None), slice(2, 3, None), slice(None, None, None)),
expected="0.2.0",
),
dict(
index=(), # Scalar.
expected="0",
),
dict(
index=(slice(None, None, None),), # Replicated.
expected="0",
),
)
def test_shard_coordinate(self, index, expected):
data = jnp.zeros(())
self.assertEqual(
_ShardInfo(data=data, index=index, slice_arg=None, replica_count=1).shard_coordinate(),
expected,
)
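
Working through the first parameterized case above: a (4, 8) array split into (2, 8) shards stores the shard covering rows [2, 4) under index `(slice(2, 4), slice(None))`; dimension 0 yields `2 // 2 == 1` and the unsliced dimension defaults to `0`, so `shard_coordinate()` returns `"1.0"`, the dot-separated zarr chunk key tensorstore uses for that chunk. A small runnable illustration, mirroring the test above:

```python
import jax.numpy as jnp

from axlearn.common.array_serialization import _ShardInfo

# The shard covering rows [2, 4) of a (4, 8) array split into (2, 8) chunks.
info = _ShardInfo(
    data=jnp.zeros((2, 8)),
    index=(slice(2, 4, None), slice(None, None, None)),
    slice_arg=None,
    replica_count=1,
)
assert info.shard_coordinate() == "1.0"  # Chunk at grid cell (1, 0).
```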
45 changes: 32 additions & 13 deletions axlearn/common/checkpointer.py
@@ -30,11 +30,13 @@
)
from axlearn.common.config import (
REQUIRED,
ConfigOr,
Configurable,
InstantiableConfig,
Required,
config_class,
config_for_function,
maybe_instantiate,
)
from axlearn.common.metrics import WeightedScalar
from axlearn.common.module import (
@@ -287,6 +289,11 @@ def stop(self):
raise NotImplementedError(type(self))


class IndexFileWriter(Protocol):
def __call__(self, ckpt_dir: str, index: Any):
...


def write_index_file(*, ckpt_dir: str, index: Any):
"""An on_commit_callback that writes an index file to ckpt_dir."""
index_path = os.path.join(ckpt_dir, "index")
@@ -551,7 +558,9 @@ def restore_from_dir(
maybe_restore_grain_savables(
spec.grain_ckpt_map, dir=os.path.join(ckpt_dir, f"grain_{jax.process_index()}")
)
return self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec)

def _restore_tensorstore_state(self, state, *, ckpt_dir: str, spec: CheckpointSpec):
restored_gda_values = self._manager.deserialize(
shardings=spec.shardings,
tensorstore_specs=spec.tensorstore_specs,
@@ -890,6 +899,16 @@ class Config(BaseCheckpointer.Config):
storage: StateStorage.Config = TensorStoreStateStorage.default_config()
# A config that instantiates an optional SummaryWriter, and is used to log checkpoints.
summary_writer: Optional[SummaryWriter.Config] = None
# An optional custom index file writer, invoked after the checkpoint write completes.
index_writer: Optional[ConfigOr[IndexFileWriter]] = None

@classmethod
def _all_checkpoint_paths(cls, base_dir: str) -> list[str]:
"""Like `checkpoint_paths`, but also include non-committed checkpoints."""
try:
return [path for path in fs.listdir(base_dir) if path.startswith(STEP_PREFIX)]
except fs.NotFoundError:
return []

@classmethod
def checkpoint_paths(cls, base_dir: str) -> list[str]:
@@ -898,14 +917,8 @@ def checkpoint_paths(cls, base_dir: str) -> list[str]:
# concurrent `exists` check for the index file can be several times faster than `glob` on
# gcs when there are many checkpoint files, even if using a "native" solution like
# `google-cloud-python` SDK.
try:
paths = fs.listdir(base_dir)
except fs.NotFoundError:
return []

paths = [
os.path.join(base_dir, path, "index") for path in paths if path.startswith(STEP_PREFIX)
]
paths = cls._all_checkpoint_paths(base_dir)
paths = [os.path.join(base_dir, path, "index") for path in paths]
with futures.ThreadPoolExecutor() as pool:
index_exists = pool.map(fs.exists, paths)
return [os.path.dirname(path) for path, committed in zip(paths, index_exists) if committed]
@@ -938,6 +951,10 @@ def cleanup_checkpoint(cls, ckpt_dir: str, *, sync: bool = True):
def __init__(self, cfg: Config, *, parent: Optional[Module]):
super().__init__(cfg, parent=parent)
cfg: Checkpointer.Config = self.config
if cfg.index_writer is None:
self._index_writer = write_index_file
else:
self._index_writer = maybe_instantiate(cfg.index_writer)

self._storage: StateStorage = cfg.storage.instantiate()
self._gc_stopping = None
@@ -1002,7 +1019,7 @@ def save(
raise ValueError(f"Out-of-range: {step}")
ckpt_dir = self.ckpt_dir(step)
self._storage.save_to_dir(
step=step, state=state, ckpt_dir=ckpt_dir, on_commit_callback=write_index_file
step=step, state=state, ckpt_dir=ckpt_dir, on_commit_callback=self._index_writer
)
if "summary_writer" in self.children:
self.summary_writer.log_checkpoint(
@@ -1025,9 +1042,7 @@ def _run_garbage_collection(self):
remaining_dirs, gc_dirs = [], []

try:
step_dirs = [
step.rstrip("/") for step in fs.listdir(cfg.dir) if step.startswith(STEP_PREFIX)
]
step_dirs = [step.rstrip("/") for step in self._all_checkpoint_paths(cfg.dir)]
except fs.NotFoundError:
step_dirs = []

@@ -1096,6 +1111,10 @@ def wait_until_finished(self):
"""See `BaseCheckpointer.wait_until_finished` docstring for details."""
self._storage.wait_until_finished()

def _index_exists(self, ckpt_dir: str):
ckpt_index = os.path.join(ckpt_dir, "index")
return fs.exists(ckpt_index)

def restore(
self,
*,
@@ -1111,7 +1130,7 @@ def restore(

def validate_and_restore(*, step: int, ckpt_dir: str):
ckpt_index = os.path.join(ckpt_dir, "index")
if not fs.exists(ckpt_index):
if not self._index_exists(ckpt_dir):
raise ValueError(
f"Checkpoint {ckpt_dir} is incomplete -- expected {ckpt_index} to be present."
)
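
The new `index_writer` config makes the on-commit callback pluggable: when unset, the checkpointer keeps using `write_index_file`; otherwise `maybe_instantiate` resolves the configured `IndexFileWriter`. A hypothetical configuration sketch is below; `grpc_index_writer`, its `endpoint` argument, and the notification step are illustrative assumptions, not part of this change.

```python
import logging
from typing import Any

from axlearn.common import checkpointer
from axlearn.common.config import config_for_function


def grpc_index_writer(endpoint: str = "localhost:50051") -> checkpointer.IndexFileWriter:
    """Hypothetical factory for an IndexFileWriter that also notifies a remote service."""

    def fn(*, ckpt_dir: str, index: Any):
        # Keep the default behavior: write the index file that marks the checkpoint committed.
        checkpointer.write_index_file(ckpt_dir=ckpt_dir, index=index)
        # A real implementation would issue a gRPC call to `endpoint` here.
        logging.info("Would notify %s that %s is committed.", endpoint, ckpt_dir)

    return fn


cfg = checkpointer.Checkpointer.default_config().set(
    name="ckpt",
    dir="/tmp/ckpt",
    index_writer=config_for_function(grpc_index_writer).set(endpoint="ckpt-index.example:443"),
)
```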