tests/distributed/test_shm_broadcast.py (10 changes: 5 additions & 5 deletions)

@@ -9,7 +9,7 @@

 from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
 from vllm.distributed.utils import StatelessProcessGroup
-from vllm.utils import get_ip, get_open_port, update_environment_variables
+from vllm.utils import get_open_port, update_environment_variables


 def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
@@ -60,12 +60,12 @@ def worker_fn():
     rank = dist.get_rank()
     if rank == 0:
         port = get_open_port()
-        ip = get_ip()
+        ip = '127.0.0.1'
         dist.broadcast_object_list([ip, port], src=0)
     else:
         recv = [None, None]
         dist.broadcast_object_list(recv, src=0)
-        ip, port = recv
+        ip, port = recv  # type: ignore

     stateless_pg = StatelessProcessGroup.create(ip, port, rank,
                                                 dist.get_world_size())
@@ -107,10 +107,10 @@ def worker_fn():

     if pg == dist.group.WORLD:
         dist.barrier()
-        print("torch distributed passed the test!")
+        print(f"torch distributed passed the test! Rank {rank}")
     else:
         pg.barrier()
-        print("StatelessProcessGroup passed the test!")
+        print(f"StatelessProcessGroup passed the test! Rank {rank}")


 def test_shm_broadcast():
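
For context, the rendezvous pattern the updated test relies on can be sketched as a standalone helper. This is a hypothetical snippet (make_stateless_pg is not vLLM API); it assumes torch.distributed is already initialized, and it uses the loopback address, which is sufficient when all ranks run on a single node, as in this test:

    import torch.distributed as dist

    from vllm.distributed.utils import StatelessProcessGroup
    from vllm.utils import get_open_port


    def make_stateless_pg() -> StatelessProcessGroup:
        """Broadcast (ip, port) from rank 0, then build the side-channel group."""
        rank = dist.get_rank()
        # broadcast_object_list fills the list in place on receiving ranks.
        addr = ["127.0.0.1", get_open_port()] if rank == 0 else [None, None]
        dist.broadcast_object_list(addr, src=0)
        ip, port = addr
        return StatelessProcessGroup.create(ip, port, rank,
                                            dist.get_world_size())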
vllm/distributed/device_communicators/shm_broadcast.py (18 changes: 1 addition & 17 deletions)

@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-import os
 import pickle
-import sys
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field

@@ -19,7 +17,7 @@
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore

 import vllm.envs as envs
-from vllm.distributed.utils import StatelessProcessGroup
+from vllm.distributed.utils import StatelessProcessGroup, sched_yield
 from vllm.logger import init_logger
 from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
                         is_valid_ipv6_address)

@@ -28,20 +26,6 @@

 logger = init_logger(__name__)

-# We prefer to use os.sched_yield as it results in tighter polling loops,
-# measured to be around 3e-7 seconds. However on earlier versions of Python
-# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
-USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
-                   or (sys.version_info[:2] == (3, 10)
-                       and sys.version_info[2] >= 8))
-
-
-def sched_yield():
-    if USE_SCHED_YIELD:
-        os.sched_yield()
-    else:
-        time.sleep(0)
-

 class ShmRingBuffer:
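
The relocated sched_yield is the building block for the polling loops in this module: the acquire loops in MessageQueue spin on it while waiting for a buffer block to become available. A minimal sketch of that pattern, assuming the new import location from this PR (spin_until itself is a hypothetical helper, not vLLM API):

    import time
    from typing import Callable

    from vllm.distributed.utils import sched_yield


    def spin_until(predicate: Callable[[], bool], timeout: float = 5.0) -> bool:
        """Poll `predicate` with minimal latency until it holds or we time out."""
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if predicate():
                return True
            # os.sched_yield() on Python >= 3.10.8 / 3.11.1 (~3e-7 s per pass);
            # time.sleep(0) otherwise, since older os.sched_yield holds the GIL.
            sched_yield()
        return False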
vllm/distributed/utils.py (154 changes: 151 additions & 3 deletions)

@@ -6,9 +6,12 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
 import datetime
+import os
 import pickle
 import socket
+import sys
 import time
+import uuid
 from collections import deque
 from collections.abc import Sequence
 from typing import Any, Optional

@@ -27,6 +30,20 @@

 logger = init_logger(__name__)

+# We prefer to use os.sched_yield as it results in tighter polling loops,
+# measured to be around 3e-7 seconds. However on earlier versions of Python
+# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
+USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
+                   or (sys.version_info[:2] == (3, 10)
+                       and sys.version_info[2] >= 8))
+
+
+def sched_yield():
+    if USE_SCHED_YIELD:
+        os.sched_yield()
+    else:
+        time.sleep(0)
+

 def ensure_divisibility(numerator, denominator):
     """Ensure that numerator is divisible by the denominator."""

@@ -212,10 +229,141 @@ def all_gather_obj(self, obj: Any) -> list[Any]:
             gathered_objs.append(recv_obj)
         return gathered_objs

-    def barrier(self):
-        """A barrier to synchronize all ranks."""
+    def barrier(self, timeout: float = 30.0):
+        """A robust barrier to synchronize all ranks.
+
+        Uses a multi-phase approach to ensure all processes reach the
+        barrier before proceeding:
+
+        1. Each process signals that it has reached the barrier.
+
+        2. Each process signals that it has confirmed the arrival of all
+           other ranks.
+
+        3. Rank 0 waits for all other ranks to signal their departure, to
+           ensure that all ranks have departed the barrier first.
+
+        Args:
+            timeout: Maximum time in seconds to wait for each phase.
+
+        Raises:
+            RuntimeError: If coordination fails or times out.
+        """
+        # Generate a barrier ID that is globally unique
+        try:
+            if self.rank == 0:
+                barrier_id = f"barrier_{uuid.uuid4()}"
+                self.broadcast_obj(barrier_id, src=0)
+            else:
+                barrier_id = self.broadcast_obj(None, src=0)
+        except Exception as e:
+            raise RuntimeError("Failed to broadcast barrier_id") from e
+
+        # Phase 1: Signal arrival at the barrier, then wait for all
+        # processes to arrive. Every rank must confirm the arrival of all
+        # other ranks; this is the key synchronization point.
+        arrival_key = f"arrival_{barrier_id}_{self.rank}"
+        try:
+            self.store.set(arrival_key, b"1")
+        except Exception as e:
+            raise RuntimeError("Failed to signal barrier arrival") from e
+
+        start_time = time.time()
+        processes_arrived: set[int] = set()
+
+        while len(processes_arrived) < self.world_size:
+            # Check for timeout
+            cur_time = time.time()
+            if cur_time - start_time > timeout:
+                raise RuntimeError(
+                    f"Barrier timed out after {timeout} seconds")
+
+            # Check for each process
+            for i in range(self.world_size):
+                if i in processes_arrived:
+                    continue
+
+                key = f"arrival_{barrier_id}_{i}"
+                try:
+                    # Try to get the key - if it exists, we'll get a value;
+                    # if it doesn't exist yet, a KeyError is raised.
+                    self.store.get(key)
+                    processes_arrived.add(i)
+                except KeyError:
+                    # Key doesn't exist yet
+                    pass
+                except Exception as check_e:
+                    logger.debug("Error checking key existence: %s", check_e)
+                    sched_yield()
+
+            # Yield briefly to avoid tight polling
+            if len(processes_arrived) < self.world_size:
+                sched_yield()
+        # Phase 2: Signal departure from the barrier.
+        # We only care to block at this stage in rank 0, which runs the
+        # server side of the TCPStore. We want to make sure that all
+        # clients have departed the barrier before rank 0 in case the
+        # next thing after the barrier is a shutdown, including tearing
+        # down the TCPStore. Other ranks can exit the barrier immediately
+        # after signaling their departure.
+        departure_key = f"departure_{barrier_id}_{self.rank}"
+        try:
+            self.store.set(departure_key, b"1")
+        except Exception as e:
+            raise RuntimeError("Failed to signal barrier departure") from e
+
+        if self.rank != 0:
+            return
+
+        # Make rank 0 wait for all processes to signal departure
+        start_time = time.time()
+        processes_departed: set[int] = set()
+
+        while len(processes_departed) < self.world_size:
+            # Check for timeout
+            if time.time() - start_time > timeout:
+                raise RuntimeError(
+                    f"Barrier departure timed out after {timeout} seconds")
+
+            # Check for each process
+            for i in range(self.world_size):
+                if i in processes_departed:
+                    continue
+
+                key = f"departure_{barrier_id}_{i}"
+                try:
+                    # Try to get the key - if it exists, we'll get a value;
+                    # if it doesn't exist yet, a KeyError is raised.
+                    self.store.get(key)
+                    processes_departed.add(i)
+                except KeyError:
+                    # Key doesn't exist yet
+                    pass
+                except Exception as check_e:
+                    logger.debug("Error checking key existence: %s", check_e)
+                    sched_yield()
+
+            # Yield briefly to avoid tight polling
+            if len(processes_departed) < self.world_size:
+                sched_yield()

+        # Clean up keys to avoid leaking memory in the store
         for i in range(self.world_size):
-            self.broadcast_obj(None, src=i)
+            try:
+                self.store.delete_key(f"arrival_{barrier_id}_{i}")
+            except Exception:
+                logger.debug("Error deleting key: %s",
+                             f'arrival_{barrier_id}_{i}')
+
+            try:
+                self.store.delete_key(f"departure_{barrier_id}_{i}")
+            except Exception:
+                logger.debug("Error deleting key: %s",
+                             f'departure_{barrier_id}_{i}')

     @staticmethod
     def create(
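
To see the new barrier end to end, here is a hypothetical usage sketch (the fixed port and the process layout are made up, not part of the PR): every rank builds the same StatelessProcessGroup and calls barrier(), which raises RuntimeError if any rank fails to arrive or depart within the timeout.

    from multiprocessing import Process

    from vllm.distributed.utils import StatelessProcessGroup


    def worker(rank: int, world_size: int, port: int) -> None:
        pg = StatelessProcessGroup.create("127.0.0.1", port, rank, world_size)
        # ... rank-local work ...
        pg.barrier(timeout=30.0)  # all ranks must reach this point within 30 s
        print(f"rank {rank} passed the barrier")


    if __name__ == "__main__":
        world_size, port = 4, 29555  # port is arbitrary, for illustration only
        procs = [
            Process(target=worker, args=(rank, world_size, port))
            for rank in range(world_size)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join()

Note that rank 0 runs the server side of the TCPStore backing the group, which is why the new barrier makes rank 0 wait for every other rank's departure signal before returning.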