Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions checkpoint/orbax/checkpoint/experimental/emergency/p2p/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,38 @@
from typing import Any, final
from orbax.checkpoint import args as args_lib
from orbax.checkpoint.experimental.emergency.p2p import constants
from orbax.checkpoint.experimental.emergency.p2p import utils


def _check_data_iter(value: Any):
"""Checks if data_iter is valid."""
if utils.pygrain() is None:
raise ImportError(
'grain library is not available. Please install grain to use data_iter.'
)
if not isinstance(
value,
(
utils.pygrain().PyGrainCheckpointSave,
utils.pygrain().PyGrainCheckpointRestore,
),
):
raise TypeError(f'Unsupported type for data_iter: {type(value)}')


@final
class Composite(args_lib.Composite):
"""Composite argument that only supports 'state' key."""
"""Composite argument that supports 'state' and 'data_iter' keys."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if constants.STATE_SUBDIR not in self or len(self) > 1:
if constants.STATE_SUBDIR not in self:
raise ValueError(
f'Composite must contain "{constants.STATE_SUBDIR}" key and no other'
f' keys: {list(self.keys())}'
)

def __setitem__(self, key: str, value: Any):
if key != constants.STATE_SUBDIR:
raise KeyError(
f'Invalid key: {key}. Only "{constants.STATE_SUBDIR}" is supported.'
f'Composite must contain "{constants.STATE_SUBDIR}" key:'
f' {list(self.keys())}'
)
self[key] = value
for key in self:
if key not in [constants.STATE_SUBDIR, constants.DATA_ITER_KEY]:
raise ValueError(f'Unsupported key in Composite: {key}')
if key == constants.DATA_ITER_KEY:
_check_data_iter(self[key])
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Composite Checkpoint Manager handling P2P syncing with optional Persistent Fallback."""

import shutil
import threading
import time
from typing import Any, Iterable, Mapping, Optional, Sequence, Union, final
Expand Down Expand Up @@ -174,6 +175,14 @@ def get_latest_complete_step(self) -> int | None:
assert self._peer_selector is not None
return self._peer_selector.get_latest_complete_step()

def has_shard_for_step(self, step: int) -> bool:
"""Checks if this process's shard for a given step exists in the P2P network."""
assert self._peer_selector is not None
return (
self._peer_selector.get_source_peer(step, self._process_index)
is not None
)

def fetch(self, step: int) -> bool:
"""Fetches from peers, or waits for background fetch if it's in progress."""
if step == self._bkg_fetch_step:
Expand Down Expand Up @@ -340,70 +349,114 @@ def save(

return p2p_saved or persistent_saved

def _restore_from_persistent_storage(
self, step: int, args: p2p_args_lib.Composite
) -> Any:
"""Restores from persistent storage."""
assert self._persistent_manager is not None
logging.info('Restoring step %d from persistent storage.', step)
return self._persistent_manager.restore(step, args=args)

def _restore_from_local_or_p2p(
self, step: int, args: p2p_args_lib.Composite
) -> Any:
"""Restores from local storage or P2P network."""
logging.info('Attempting to restore step %d from local or P2P.', step)
if step in self._local_manager.all_steps():
logging.info('Step %d found in local storage.', step)
return self._local_manager.restore(step, args=args)
else:
logging.info('Step %d not found locally, fetching from P2P.', step)
p2p_restore_dir = self._local_directory / constants.P2P_RESTORE_DIR_NAME
try:
if self._p2p.fetch(step):
return self._local_manager.restore(
step, args=args, directory=p2p_restore_dir
)
else:
raise FileNotFoundError(
f'Failed to fetch step {step} from P2P network.'
)
finally:
if p2p_restore_dir.exists():
logging.info(
'Removing P2P restore directory: %s after restoration is'
' complete',
p2p_restore_dir,
)
try:
shutil.rmtree(str(p2p_restore_dir))
except OSError as e:
logging.exception('Failed to remove P2P restore directory: %s', e)

@override
def restore(
self, step: int | None, args: p2p_args_lib.Composite | None = None
self, step: int | None, args: p2p_args_lib.Composite | None
) -> Union[Any, Mapping[str, Any], p2p_args_lib.Composite, None]:
self._p2p.sync_registry_if_stale()
if args is None:
raise ValueError('The `args` parameter is required for restore.')

# TODO(exlin): Enhance restore logic:
start_time = time.time()
# 1. Registry Sync: Ensure P2P registry is current.
# 2. Unified Restore: Attempt restore from local, then P2P.
# 3. Coordinated Fallback: Barrier sync before persistent storage restore
# if local/P2P fails.
logging.info('Registry Sync - ensuring P2P registry is current.')
self._p2p.sync_registry_if_stale()

use_persistent = False
if step is None:
step = self._local_manager.latest_step()
if step is None:
step = self._p2p.get_latest_complete_step()
step = self._p2p.get_latest_complete_step()
logging.info('P2P latest_step=%s', step)

if step is None and self._persistent_manager:
step = self._persistent_manager.latest_step()
logging.info('Persistent latest_step=%s', step)
if step is not None:
use_persistent = True

if step is None:
logging.warning('No restore step found in local storage or P2P registry.')
return None

logging.info('Targeting restore step: %d', step)
start_time = time.time()

# Strategy A: Local Restore
if step in self._local_manager.all_steps():
logging.info('Strategy A - Found locally. Restoring...')
try:
res = self._local_manager.restore(step)
logging.info(
'Local restore finished in %.2fs', time.time() - start_time
)
return res
except (OSError, ValueError) as e:
logging.exception('Local restore failed: %s', e)

# Strategy B: P2P Network Restore
logging.info('Strategy B - Not found locally. Attempting P2P fetch...')

fetch_succeeded = self._p2p.fetch(step)

if fetch_succeeded:
p2p_restore_dir = self._local_directory / constants.P2P_RESTORE_DIR_NAME
try:
res = self._local_manager.restore(step, directory=p2p_restore_dir)
logging.info(
'P2P restore finished in %.2fs',
time.time() - start_time,
# 2. Try P2P/Local Restore
restored = None
restore_source = 'Unknown'
if not use_persistent:
if self._p2p.has_shard_for_step(step):
try:
restored = self._restore_from_local_or_p2p(step, args)
restore_source = 'P2P'
except (OSError, ValueError, FileNotFoundError) as e:
logging.exception('Local/P2P restore for step %d failed: %s', step, e)
else:
logging.warning(
'Step %d not available in P2P network, falling back to'
' persistent storage.',
step,
)
return res
except (OSError, ValueError) as e:
logging.exception('P2P restore failed after download: %s', e)

# Strategy C: Persistent Storage Fallback
# 3. Coordinated Fallback to Persistent Storage
if self._persistent_manager:
logging.warning(
'Strategy C - P2P failed. Falling back to persistent storage.'
# If any host failed local/P2P restore, all hosts must use persistent.
fallback_to_persistent = 1 if restored is None else 0
any_host_needs_fallback_list = multihost.global_max(
[fallback_to_persistent]
)
restore_args = args
if not restore_args:
restore_args = p2p_args_lib.Composite(
state=self._abstract_state,
if any_host_needs_fallback_list and any_host_needs_fallback_list[0]:
logging.warning(
'At least one host failed Local/P2P restore or step not available'
' in P2P. All hosts falling back to persistent storage.'
)
restored = self._restore_from_persistent_storage(step, args)
restore_source = 'Persistent Storage'

return self._persistent_manager.restore(step, args=restore_args)
if restored is not None:
logging.info(
'Restoration finished using %s in %.2fs',
restore_source,
time.time() - start_time,
)
return restored

raise FileNotFoundError(f'All restore strategies failed for step {step}.')

Expand Down
Loading
Loading