@@ -14,7 +14,7 @@

"""P2P composite checkpoint argument."""

from typing import Any, final
from typing import final
from orbax.checkpoint import args as args_lib
from orbax.checkpoint.experimental.emergency.p2p import constants

@@ -30,10 +30,3 @@ def __init__(self, *args, **kwargs):
f'Composite must contain "{constants.STATE_SUBDIR}" key and no other'
f' keys: {list(self.keys())}'
)

def __setitem__(self, key: str, value: Any):
if key != constants.STATE_SUBDIR:
raise KeyError(
f'Invalid key: {key}. Only "{constants.STATE_SUBDIR}" is supported.'
)
self[key] = value
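Note on the `__setitem__` override shown above: assigning via `self[key] = value` from inside `__setitem__` re-enters `__setitem__` and recurses until a `RecursionError`. A non-recursive variant would delegate the actual write to the parent class. A minimal illustrative sketch (a plain `dict` subclass with a hypothetical `ALLOWED_KEY`, not the real `args_lib` class):

```python
from typing import Any


class StateOnlyComposite(dict):
  """Dict-like composite that only accepts a single allowed key (illustrative)."""

  ALLOWED_KEY = 'state'  # stands in for constants.STATE_SUBDIR

  def __setitem__(self, key: str, value: Any) -> None:
    if key != self.ALLOWED_KEY:
      raise KeyError(
          f'Invalid key: {key}. Only "{self.ALLOWED_KEY}" is supported.'
      )
    # Delegate to the parent implementation; `self[key] = value` here would
    # call __setitem__ again and recurse.
    super().__setitem__(key, value)
```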
@@ -14,6 +14,7 @@

"""Composite Checkpoint Manager handling P2P syncing with optional Persistent Fallback."""

import shutil
import threading
import time
from typing import Any, Iterable, Mapping, Optional, Sequence, Union, final
@@ -340,70 +341,123 @@ def save(

return p2p_saved or persistent_saved

def _restore_from_persistent_storage(
self, step: int, args: p2p_args_lib.Composite | None
) -> Any:
assert self._persistent_manager is not None
restore_args = args
if not restore_args:
restore_args = p2p_args_lib.Composite(
state=self._abstract_state,
)
return self._persistent_manager.restore(step, args=restore_args)

@override
def restore(
self, step: int | None, args: p2p_args_lib.Composite | None = None
) -> Union[Any, Mapping[str, Any], p2p_args_lib.Composite, None]:
start_time = time.time()
# 1. Registry Sync: Ensure P2P registry is current.
logging.info('1. Registry Sync - ensuring P2P registry is current.')
self._p2p.sync_registry_if_stale()

# TODO(exlin): Enhance restore logic:
# 1. Registry Sync: Ensure P2P registry is current.
# 2. Unified Restore: Attempt restore from local, then P2P.
# 3. Coordinated Fallback: Barrier sync before persistent storage restore
# if local/P2P fails.
if step is None:
step = self._local_manager.latest_step()
if step is None:
step = self._p2p.get_latest_complete_step()
# Prefer global view to ensure all hosts agree on the latest step.
# We explicitly do NOT fall back to local_manager.latest_step() here.
# If P2P doesn't see a complete step, falling back to local storage
# risks different hosts picking different steps, causing divergence.
step = self._p2p.get_latest_complete_step()

if step is None and self._persistent_manager:
# If P2P network is empty, check if we have a step in persistent
# storage. Persistent storage is assumed to be consistent across all
# hosts.
step = self._persistent_manager.latest_step()

if step is not None:
logging.info(
'Targeting restore step: %d (found in persistent storage)', step
)
# Found in persistent only. Directly restore.
restored_item = self._restore_from_persistent_storage(step, args)
logging.info(
'Restoration finished using Persistent Storage in %.2fs',
time.time() - start_time,
)
return restored_item

if step is None:
logging.warning('No restore step found in local storage or P2P registry.')
return None

logging.info('Targeting restore step: %d', step)
start_time = time.time()

# Strategy A: Local Restore
restored_item = None
restore_source = 'Unknown'

# 2. Local Restore
if step in self._local_manager.all_steps():
logging.info('Strategy A - Found locally. Restoring...')
logging.info('2. Local Restore - Found locally. Restoring...')
try:
res = self._local_manager.restore(step)
logging.info(
'Local restore finished in %.2fs', time.time() - start_time
)
return res
restored_item = self._local_manager.restore(step)
restore_source = 'Local'
except (OSError, ValueError) as e:
logging.exception('Local restore failed: %s', e)

# Strategy B: P2P Network Restore
logging.info('Strategy B - Not found locally. Attempting P2P fetch...')

fetch_succeeded = self._p2p.fetch(step)

if fetch_succeeded:
p2p_restore_dir = self._local_directory / constants.P2P_RESTORE_DIR_NAME
try:
res = self._local_manager.restore(step, directory=p2p_restore_dir)
logging.info(
'P2P restore finished in %.2fs',
time.time() - start_time,
)
return res
except (OSError, ValueError) as e:
logging.exception('P2P restore failed after download: %s', e)

    # Strategy C: Persistent Storage Fallback
    if self._persistent_manager:
      logging.warning(
          'Strategy C - P2P failed. Falling back to persistent storage.'
      )
      restore_args = args
      if not restore_args:
        restore_args = p2p_args_lib.Composite(
            state=self._abstract_state,
        )
      return self._persistent_manager.restore(step, args=restore_args)

    # 3. P2P Network Restore
    if restored_item is None:
      logging.info(
          '3. P2P Network Restore - Not found locally or failed. Attempting P2P'
          ' fetch...'
      )
      fetch_succeeded = self._p2p.fetch(step)

      if fetch_succeeded:
        p2p_restore_dir = self._local_directory / constants.P2P_RESTORE_DIR_NAME
        try:
          restored_item = self._local_manager.restore(
              step, directory=p2p_restore_dir
          )
          restore_source = 'P2P'
        except (OSError, ValueError) as e:
          logging.exception('P2P restore failed after download: %s', e)
        finally:
          if p2p_restore_dir.exists():
            logging.info(
                'Removing P2P restore directory: %s after restoration is'
                ' complete',
                p2p_restore_dir,
            )
            try:
              shutil.rmtree(str(p2p_restore_dir))
            except OSError as e:
              logging.exception('Failed to remove P2P restore directory: %s', e)

    # 4. Coordinated Fallback: Barrier sync before persistent storage restore
    # if local/P2P fails.
    if self._persistent_manager:
      # If any host failed to restore from Local/P2P, we must fall back to
      # persistent storage to ensure all hosts are in sync.
      local_failure = 1 if restored_item is None else 0
      any_failure_list = multihost.global_max([local_failure])
      any_failure = any_failure_list[0] if any_failure_list else 0

      if any_failure:
        logging.warning(
            '4. Coordinated Fallback - '
            'At least one host failed Local/P2P restore. '
            'All hosts falling back to persistent storage.'
        )
        restored_item = self._restore_from_persistent_storage(step, args)
        restore_source = 'Persistent Storage'

    if restored_item is not None:
      logging.info(
          'Restoration finished using %s in %.2fs',
          restore_source,
          time.time() - start_time,
      )
      return restored_item

raise FileNotFoundError(f'All restore strategies failed for step {step}.')

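The coordinated fallback above is an agreement problem: each host reports whether its own local/P2P restore failed, and every host must reach the same verdict before choosing a restore source, otherwise some hosts would serve a locally fetched copy while others reload from persistent storage. A minimal sketch of that pattern, assuming a caller-supplied `all_reduce_max` in place of the `multihost.global_max` call used in the diff:

```python
from typing import Any, Callable, Optional


def restore_with_coordinated_fallback(
    local_restore: Callable[[], Optional[Any]],
    persistent_restore: Callable[[], Any],
    all_reduce_max: Callable[[int], int],
) -> Any:
  """Every host ends up restoring from the same source."""
  try:
    # Local or P2P copy; may be missing or unreadable on this host.
    item = local_restore()
  except (OSError, ValueError):
    item = None

  # 1 if this host failed, 0 otherwise. The max over all hosts is 1 if any
  # host failed, so every host computes the same decision value.
  my_failure = 0 if item is not None else 1
  any_failure = all_reduce_max(my_failure)

  if any_failure:
    # Hosts that succeeded locally still re-restore from persistent storage,
    # so the loaded state is identical everywhere.
    item = persistent_restore()
  return item
```

On a single host `all_reduce_max` can simply be `lambda x: x`; in the diff it is the list-based `multihost.global_max([x])[0]`.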
@@ -97,6 +97,9 @@ def setUp(self):
self.enter_context(
mock.patch('socket.gethostname', return_value='localhost')
)
self.mock_global_max = self.enter_context(
mock.patch.object(multihost, 'global_max', return_value=[0])
)

# Mock instances returned by constructors
self.local_manager_instance = self.mock_local.return_value
@@ -243,6 +246,7 @@ def test_restore_strategy_c_persistent_fallback(self, process_index):
# P2P fetch fails
self.peer_selector_instance.get_source_peer.return_value = None
self.persistent_manager_instance.restore.return_value = {'a': 1}
self.mock_global_max.return_value = [1]

manager = p2p_cm.CheckpointManager(
self.mesh,
@@ -284,6 +288,147 @@ def test_restore_all_fail(self, _):
manager.restore(1)
manager.close()

@mock.patch.object(multihost, 'process_index', return_value=0)
def test_restore_coordinated_fallback_peer_failed(self, _):
"""Tests that we fall back to persistent if a peer fails, even if we succeeded locally."""
# 1. Setup: Local restore succeeds
self.local_manager_instance.scan_stored_steps.return_value = (0, [1])
self.local_manager_instance.all_steps.return_value = [1]
self.local_manager_instance.restore.return_value = {'a': 1} # Local success
self.mock_sync_global_data.return_value = []

# 2. Setup: Persistent manager returns a different value so we can verify
# fallback was used
self.persistent_manager_instance.restore.return_value = {'a': 999}

# 3. Setup: global_max returns 1, indicating SOMEONE failed
self.mock_global_max.return_value = [1]

manager = p2p_cm.CheckpointManager(
self.mesh,
self.abstract_state,
self.local_dir,
persistent_directory=self.persistent_dir,
)

# 4. Action
result = manager.restore(1)

# 5. Verification
# Should use persistent result
self.assertEqual(result, {'a': 999})

# Local restore WAS attempted
self.local_manager_instance.restore.assert_called_once_with(1)

# Persistent restore WAS called
self.persistent_manager_instance.restore.assert_called_once_with(
1, args=mock.ANY
)

# global_max was called with [0] because WE succeeded (my_failure=0)
self.mock_global_max.assert_called_once_with([0])

manager.close()

@mock.patch.object(multihost, 'process_index', return_value=0)
def test_restore_coordinated_fallback_local_failed(self, _):
"""Tests that we fall back to persistent if we fail locally."""
# 1. Setup: Local/P2P fail
self.local_manager_instance.scan_stored_steps.return_value = (0, [])
self.local_manager_instance.all_steps.return_value = []
self.mock_sync_global_data.return_value = []
self.peer_selector_instance.get_source_peer.return_value = None # P2P fails

self.persistent_manager_instance.restore.return_value = {'a': 999}
self.mock_global_max.return_value = [1] # Everyone knows someone failed

manager = p2p_cm.CheckpointManager(
self.mesh,
self.abstract_state,
self.local_dir,
persistent_directory=self.persistent_dir,
)

# 4. Action
result = manager.restore(1)

# 5. Verification
self.assertEqual(result, {'a': 999})
self.persistent_manager_instance.restore.assert_called_once()

# global_max was called with [1] because WE failed (my_failure=1)
self.mock_global_max.assert_called_once_with([1])

manager.close()

@mock.patch.object(multihost, 'process_index', return_value=0)
def test_restore_no_step_in_p2p_but_in_persistent(self, _):
"""Tests fallback to persistent step if P2P has no step."""
self.local_manager_instance.scan_stored_steps.return_value = (0, [])
self.mock_sync_global_data.return_value = []

# P2P has no latest complete step
self.peer_selector_instance.get_latest_complete_step.return_value = None

# Persistent has step 100
self.persistent_manager_instance.latest_step.return_value = 100
self.persistent_manager_instance.restore.return_value = {'a': 100}

manager = p2p_cm.CheckpointManager(
self.mesh,
self.abstract_state,
self.local_dir,
persistent_directory=self.persistent_dir,
)

# Reset mock to ensure we only check calls during restore.
self.mock_global_max.reset_mock()

result = manager.restore(None)

self.assertEqual(result, {'a': 100})
self.persistent_manager_instance.latest_step.assert_called_once()
self.persistent_manager_instance.restore.assert_called_once_with(
100, args=mock.ANY
)
# Persistent storage is trusted; no global sync needed.
self.mock_global_max.assert_not_called()

manager.close()

@mock.patch.object(p2p_cm.shutil, 'rmtree', autospec=True)
@mock.patch.object(multihost, 'process_index', return_value=0)
def test_restore_p2p_cleanup(self, unused_process_index, mock_rmtree):
"""Tests that P2P restore directory is cleaned up after restore."""
self.local_manager_instance.scan_stored_steps.return_value = (0, [])
self.local_manager_instance.all_steps.return_value = []
self.mock_sync_global_data.return_value = []

# P2P fetch succeeds
self.peer_selector_instance.get_source_peer.return_value = (
protocol.PeerDiscoveryInfo(
ip='1.2.3.4', port=5678, process_index=1, steps=[1]
)
)
self.p2p_node_instance.fetch_shard_from_peer.return_value = True
self.local_manager_instance.restore.return_value = {'a': 1}

manager = p2p_cm.CheckpointManager(
self.mesh,
self.abstract_state,
self.local_dir,
)

# Make p2p_restore_dir exist so cleanup is triggered
p2p_restore_dir = self.local_dir / service.constants.P2P_RESTORE_DIR_NAME
p2p_restore_dir.mkdir()

manager.restore(1)

mock_rmtree.assert_called_once_with(str(p2p_restore_dir))
manager.close()


if __name__ == '__main__':
absltest.main()
@@ -125,6 +125,9 @@ def __init__(
def directory(self) -> epath.Path:
return self._directory

def latest_step(self) -> int | None:
return self._manager.latest_step()

def save(
self,
step: int,
@@ -256,8 +256,6 @@ def fetch_shard_from_peer(
)
return False

# TODO(exlin): Remove this directory once the transfer is globally completed
# to save memory space.
stage_dir = self.directory / f'stage_{step}_{stored_process_index}'
if stage_dir.exists():
shutil.rmtree(str(stage_dir))