Skip to content

CheckpointServer: start in disallowed state + tests #90

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions torchft/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import threading
import urllib.request
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import timedelta
from http.server import BaseHTTPRequestHandler
from typing import Generic, List, Optional, TypeVar
from typing import Generator, Generic, List, Optional, TypeVar

import torch

Expand Down Expand Up @@ -87,6 +88,25 @@ def shutdown(self, wait: bool = True) -> None:
"""


@contextmanager
def _timed_acquire(
lock: threading.Lock, timeout: timedelta
) -> Generator[None, None, None]:
"""
Acquire a lock with a timeout.

Args:
lock: the lock to acquire
timeout: the timeout to acquire the lock
"""
if not lock.acquire(timeout=timeout.total_seconds()):
raise TimeoutError(f"timed out acquiring lock after {timeout}")
try:
yield
finally:
lock.release()


class CheckpointServer(CheckpointTransport[T]):
"""
This is an HTTP server that can be used to transfer checkpoints
Expand All @@ -106,6 +126,10 @@ def __init__(self, timeout: timedelta) -> None:
self._timeout = timeout
self._state_dict: Optional[T] = None

# We don't allow checkpoints until the first send_checkpoint to avoid
# serving the default step=-1 invalid checkpoint.
self.disallow_checkpoint()

ckpt_server = self

class RequestHandler(BaseHTTPRequestHandler):
Expand All @@ -117,7 +141,9 @@ def do_GET(self):
# validate socket timeout is actually set
assert self.connection.gettimeout() == self.timeout

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to add an assert to check self._step == -1 (or maybe initialize step to None and check that) to fail quickly rather than to wait and timeout?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually are doing that which is the issue. Down below we check if self._step == the requested step

We end up with a race condition between send_checkpoint and the remote worker requesting the checkpoint which causes the requester to get an error due to mismatched step

This lock allows us to have some tolerance when requesting a checkpoint. We could busy loop/retry on caller side but the lock makes the code simpler and a bit more efficient

with ckpt_server._checkpoint_lock:
with _timed_acquire(
ckpt_server._checkpoint_lock, ckpt_server._timeout
):
step = ckpt_server._step

if self.path != f"/checkpoint/{step}":
Expand Down
50 changes: 49 additions & 1 deletion torchft/checkpointing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import threading
import urllib.error
from datetime import timedelta
from unittest import TestCase
from unittest.mock import MagicMock

from torchft.checkpointing import CheckpointServer
from torchft.checkpointing import CheckpointServer, _timed_acquire


class TestCheckpointing(TestCase):
Expand Down Expand Up @@ -55,3 +56,50 @@ def test_checkpoint_server(self) -> None:
)

server.shutdown()

def test_checkpoint_server_locking(self) -> None:
server = CheckpointServer(
timeout=timedelta(seconds=10),
)

# server should start up in a disallowed state this will block incoming
# requests until allow_checkpoint is called
self.assertTrue(server._checkpoint_lock.locked())
self.assertTrue(server._disallowed)
self.assertEqual(server._step, -1)

# allow requests
server.allow_checkpoint(1)

self.assertFalse(server._checkpoint_lock.locked())
self.assertFalse(server._disallowed)
self.assertEqual(server._step, 1)

# duplicate allow/disallow is fine
server.allow_checkpoint(2)
self.assertEqual(server._step, 2)

server.disallow_checkpoint()
server.disallow_checkpoint()
self.assertTrue(server._checkpoint_lock.locked())
self.assertTrue(server._disallowed)

server.shutdown()

def test_timed_acquire(self) -> None:
lock = threading.Lock()

with _timed_acquire(lock, timedelta(seconds=10)):
self.assertTrue(lock.locked())

self.assertFalse(lock.locked())

lock.acquire()

with self.assertRaisesRegex(
TimeoutError, r"timed out acquiring lock after 0.0"
):
with _timed_acquire(lock, timedelta(seconds=0.0)):
pass

self.assertTrue(lock.locked())