2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -21,7 +21,7 @@ repos:
rev: v0.12.2
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
args: [--fix, --exit-non-zero-on-fix, --ignore, S603,]
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
15 changes: 9 additions & 6 deletions checkpoint_engine/ps.py
@@ -678,6 +678,7 @@ def __init__(
self._all_hosts = []
self._global_device_uuids: list[str] = []
self._mem_fraction = mem_fraction or 0.9
self._logger_rank = 0

Comment on lines +681 to 682

Copilot AI Oct 20, 2025

The _logger_rank initialization should be documented to explain when and how this value changes during execution.

Suggested change
self._logger_rank = 0
# _logger_rank determines which rank is responsible for logging output.
# By default, it is set to 0, meaning only rank 0 will perform logging.
# If logging from other ranks is required, this value can be changed accordingly.
self._logger_rank = 0

assert self._rank is not None and self._rank >= 0, self._rank
assert self._world_size and self._world_size > 0, self._world_size
@@ -706,8 +707,8 @@ def __init__(
torch.cuda.set_device(device_index)
self._device_uuid = _get_physical_gpu_id(device_index)

def _logger_rank0(self, msg: str):
if self._local_rank == 0:
def _logger_once(self, msg: str):
if self._local_rank == self._logger_rank:
logger.info(msg)

def get_metas(self) -> dict[int, MemoryBufferMetaList]:
@@ -871,10 +872,12 @@ def update(
try:
# if ranks is None or [], fall back to a full broadcast update to all ranks
if not ranks:
self._logger_rank = 0
if self._auto_pg and not dist.is_initialized():
self.init_process_group()
self._update_per_bucket(checkpoint_name, req_func)
else:
self._logger_rank = ranks[0]
if not self._auto_pg and self._rank not in ranks:
return
if self._auto_pg:
@@ -936,15 +939,15 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
self._logger_once(f"[rank{self._rank}] use h2d buffer")
# using the h2d buffer lets all ranks run their h2d copies in parallel
# the cost is the extra GPU memory allocated for the h2d buffer
free_bytes = free_bytes_divided_3
else:
# if memory is insufficient, fall back to disable_h2d_buffer mode;
# bandwidth is then limited by a single machine's h2d throughput,
# but GPU memory is saved
self._logger_rank0(
self._logger_once(
f"[rank{self._rank}] disable h2d buffer when max_tensor_bytes {max_tensor_bytes} is larger than free_bytes {free_bytes} // 3"
)
free_bytes = free_bytes // (2 * _ALIGN_SIZE) * _ALIGN_SIZE
@@ -1074,7 +1077,7 @@ def _update_per_bucket_p2p(
req_thread.start()
socket.send_pyobj(handle)
for gidx, (owner_rank, bucket) in enumerate(buckets):
self._logger_rank0(
self._logger_once(
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
)
_buffer = buffer[gidx % 2 * bucket_size : gidx % 2 * bucket_size + bucket.size]
@@ -1178,7 +1181,7 @@ def _update_per_bucket(
torch.cuda.memory_allocated() / 1024 / 1024,
torch.cuda.memory_reserved() / 1024 / 1024,
)
self._logger_rank0(
self._logger_once(
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
f"Current CUDA allocated {alloc:.2f} MB, "
f"reserved {reserved:.2f} MB."
53 changes: 46 additions & 7 deletions tests/test_update.py
@@ -1,7 +1,9 @@
import os
import random
import subprocess
import time

import pytest
import torch
import zmq
from torch.multiprocessing import Queue, get_context
@@ -63,9 +65,8 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str,
check_weights(names_to_check, socket_paths)


def run():
def run_with_specified_ranks(ranks: list[int]):
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
ctx = get_context("spawn")
queue = ctx.Queue()
_device_uuid = _get_physical_gpu_id(rank)
@@ -76,15 +77,53 @@ def run():
proc.start()
ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
ps.gather_metas(checkpoint_name)
ranks_list = [[], list(range(world_size // 2)), [], list(range(world_size))]
for ranks in ranks_list:
ps.update(checkpoint_name, queue.put, ranks=ranks)
# sleep 3s to wait for the process group to be destroyed
time.sleep(3)
ps.update(checkpoint_name, queue.put, ranks=ranks)
time.sleep(5)
ps.unregister_checkpoint(checkpoint_name)
queue.put(None)
proc.join()


def run():
world_size = int(os.getenv("WORLD_SIZE"))
random.seed(42)
ranklist = [
list(random.sample(range(world_size), k=num_ranks)) for num_ranks in range(world_size + 1)
]
for ranks in ranklist:
run_with_specified_ranks(ranks)


@pytest.mark.gpu
def test_update():
world_size = torch.cuda.device_count()
assert world_size >= 2, "This test requires at least 2 GPUs."

master_addr = "localhost"
master_port = random.randint(20000, 30000)

Copilot AI Oct 20, 2025

Using random port selection for distributed training can lead to port conflicts in CI environments. Consider using a more robust port selection mechanism or environment variable override.

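A minimal sketch of the kind of fallback the reviewer is pointing at, not part of this diff: it assumes a `MASTER_PORT` environment override and a hypothetical `find_free_port` helper in place of `random.randint(20000, 30000)`.

```python
import os
import socket


def find_free_port() -> int:
    # Ask the OS for an unused port by binding to port 0, then release it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


# Prefer an explicit override; otherwise pick a free port dynamically.
master_port = int(os.environ.get("MASTER_PORT", "0")) or find_free_port()
```

Binding to port 0 still leaves a small window between releasing the port and torchrun rebinding it, but it avoids blind collisions inside a hard-coded 20000-30000 range.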

cmd = [
"torchrun",
"--nproc_per_node",
str(world_size),
"--master_addr",
master_addr,
"--master_port",
str(master_port),
"tests/test_update.py",
]

result = subprocess.run(
cmd,
capture_output=False,
text=True,
cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
shell=False,
check=False,
)

assert result.returncode == 0


if __name__ == "__main__":
run()