Commit 0113f50

russellb authored and minpeter committed
[CI] Fix race condition with StatelessProcessGroup.barrier (vllm-project#18506)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent 88fa0c5 commit 0113f50

3 files changed: +157 -25 lines changed

tests/distributed/test_shm_broadcast.py

Lines changed: 5 additions & 5 deletions

@@ -9,7 +9,7 @@

 from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
 from vllm.distributed.utils import StatelessProcessGroup
-from vllm.utils import get_ip, get_open_port, update_environment_variables
+from vllm.utils import get_open_port, update_environment_variables


 def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
@@ -60,12 +60,12 @@ def worker_fn():
     rank = dist.get_rank()
     if rank == 0:
         port = get_open_port()
-        ip = get_ip()
+        ip = '127.0.0.1'
         dist.broadcast_object_list([ip, port], src=0)
     else:
         recv = [None, None]
         dist.broadcast_object_list(recv, src=0)
-        ip, port = recv
+        ip, port = recv  # type: ignore

     stateless_pg = StatelessProcessGroup.create(ip, port, rank,
                                                 dist.get_world_size())
@@ -107,10 +107,10 @@ def worker_fn():

     if pg == dist.group.WORLD:
         dist.barrier()
-        print("torch distributed passed the test!")
+        print(f"torch distributed passed the test! Rank {rank}")
     else:
         pg.barrier()
-        print("StatelessProcessGroup passed the test!")
+        print(f"StatelessProcessGroup passed the test! Rank {rank}")


 def test_shm_broadcast():
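
The switch from get_ip() to '127.0.0.1' pins the rendezvous to loopback, so the test no longer depends on how the host's primary IP resolves on a CI machine. For reference, a minimal standalone sketch of what the port rendezvous does, assuming get_open_port() behaves roughly like binding to port 0 on loopback (find_open_port is a hypothetical stand-in, not vLLM's API):

import socket

def find_open_port() -> int:
    # Hypothetical stand-in for vllm.utils.get_open_port: bind to port 0
    # so the OS assigns a free port; pinning the address to 127.0.0.1
    # mirrors the test's switch away from get_ip().
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]

if __name__ == "__main__":
    print(find_open_port())  # e.g. 54321; varies per run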

vllm/distributed/device_communicators/shm_broadcast.py

Lines changed: 1 addition & 17 deletions

@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-import os
 import pickle
-import sys
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
@@ -19,7 +17,7 @@
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore

 import vllm.envs as envs
-from vllm.distributed.utils import StatelessProcessGroup
+from vllm.distributed.utils import StatelessProcessGroup, sched_yield
 from vllm.logger import init_logger
 from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
                         is_valid_ipv6_address)
@@ -28,20 +26,6 @@

 logger = init_logger(__name__)

-# We prefer to use os.sched_yield as it results in tighter polling loops,
-# measured to be around 3e-7 seconds. However on earlier versions of Python
-# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
-USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
-                   or (sys.version_info[:2] == (3, 10)
-                       and sys.version_info[2] >= 8))
-
-
-def sched_yield():
-    if USE_SCHED_YIELD:
-        os.sched_yield()
-    else:
-        time.sleep(0)
-

 class ShmRingBuffer:

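
The sched_yield helper and its USE_SCHED_YIELD guard move verbatim into vllm/distributed/utils.py (next file) so the new barrier there can reuse them; shm_broadcast.py now imports the helper instead. As a quick sanity check of the comment's ~3e-7 s claim, a rough micro-benchmark sketch (not part of the commit; os.sched_yield is POSIX-only, and absolute numbers vary by host and Python version):

import os
import sys
import time

def measure(fn, iters: int = 100_000) -> float:
    # Average wall-clock cost of a single call to fn
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    # os.sched_yield gives the tighter loop; time.sleep(0) is the fallback
    print(f"os.sched_yield: {measure(os.sched_yield):.1e} s/call")
    print(f"time.sleep(0):  {measure(lambda: time.sleep(0)):.1e} s/call")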

vllm/distributed/utils.py

Lines changed: 151 additions & 3 deletions

@@ -6,9 +6,12 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
 import datetime
+import os
 import pickle
 import socket
+import sys
 import time
+import uuid
 from collections import deque
 from collections.abc import Sequence
 from typing import Any, Optional
@@ -27,6 +30,20 @@

 logger = init_logger(__name__)

+# We prefer to use os.sched_yield as it results in tighter polling loops,
+# measured to be around 3e-7 seconds. However on earlier versions of Python
+# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
+USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
+                   or (sys.version_info[:2] == (3, 10)
+                       and sys.version_info[2] >= 8))
+
+
+def sched_yield():
+    if USE_SCHED_YIELD:
+        os.sched_yield()
+    else:
+        time.sleep(0)
+

 def ensure_divisibility(numerator, denominator):
     """Ensure that numerator is divisible by the denominator."""
@@ -212,10 +229,141 @@ def all_gather_obj(self, obj: Any) -> list[Any]:
                 gathered_objs.append(recv_obj)
         return gathered_objs

-    def barrier(self):
-        """A barrier to synchronize all ranks."""
+    def barrier(self, timeout: float = 30.0):
+        """A robust barrier to synchronize all ranks.
+
+        Uses a multi-phase approach to ensure all processes reach the
+        barrier before proceeding:
+
+        1. Each process signals it has reached the barrier.
+
+        2. Each process signals that it has confirmed the arrival of all
+           other ranks.
+
+        3. Rank 0 waits for all other ranks to signal their departure to
+           ensure that all ranks have departed the barrier first.
+
+        Args:
+            timeout: Maximum time in seconds to wait for each phase.
+
+        Raises:
+            RuntimeError: If coordination fails or times out.
+        """
+        # Generate a barrier ID that is globally unique
+        try:
+            if self.rank == 0:
+                barrier_id = f"barrier_{uuid.uuid4()}"
+                self.broadcast_obj(barrier_id, src=0)
+            else:
+                barrier_id = self.broadcast_obj(None, src=0)
+        except Exception as e:
+            raise RuntimeError("Failed to broadcast barrier_id") from e
+
+        # Phase 1: Signal arrival at the barrier, then wait for all
+        # processes to arrive. Every rank must confirm the arrival of
+        # all other ranks; this is the key synchronization point.
+        arrival_key = f"arrival_{barrier_id}_{self.rank}"
+        try:
+            self.store.set(arrival_key, b"1")
+        except Exception as e:
+            raise RuntimeError("Failed to signal barrier arrival") from e
+
+        start_time = time.time()
+        processes_arrived: set[int] = set()
+
+        while len(processes_arrived) < self.world_size:
+            # Check for timeout
+            if time.time() - start_time > timeout:
+                raise RuntimeError(
+                    f"Barrier timed out after {timeout} seconds")
+
+            # Check for each process
+            for i in range(self.world_size):
+                if i in processes_arrived:
+                    continue
+
+                key = f"arrival_{barrier_id}_{i}"
+                try:
+                    # Try to get the key - if it exists, we'll get a value.
+                    # If it doesn't exist, it will throw an exception.
+                    self.store.get(key)
+                    processes_arrived.add(i)
+                except KeyError:
+                    # Key doesn't exist yet
+                    pass
+                except Exception as check_e:
+                    logger.debug("Error checking key existence: %s", check_e)
+                    sched_yield()
+
+            # Short sleep to avoid tight polling
+            if len(processes_arrived) < self.world_size:
+                sched_yield()
+
+        # Phase 2: Signal departure from the barrier. We only need to
+        # block at this stage on rank 0, which runs the server side of
+        # the TCPStore. We want to make sure that all clients have
+        # departed the barrier before rank 0 in case the next thing
+        # after the barrier is a shutdown, including tearing down the
+        # TCPStore. Other ranks can exit the barrier immediately after
+        # signaling their departure.
+        departure_key = f"departure_{barrier_id}_{self.rank}"
+        try:
+            self.store.set(departure_key, b"1")
+        except Exception as e:
+            raise RuntimeError("Failed to signal barrier departure") from e
+
+        if self.rank != 0:
+            return
+
+        # Make rank 0 wait for all processes to signal departure
+        start_time = time.time()
+        processes_departed: set[int] = set()
+
+        while len(processes_departed) < self.world_size:
+            # Check for timeout
+            if time.time() - start_time > timeout:
+                raise RuntimeError(
+                    f"Barrier departure timed out after {timeout} seconds")
+
+            # Check for each process
+            for i in range(self.world_size):
+                if i in processes_departed:
+                    continue
+
+                key = f"departure_{barrier_id}_{i}"
+                try:
+                    # Try to get the key - if it exists, we'll get a value.
+                    # If it doesn't exist, it will throw an exception.
+                    self.store.get(key)
+                    processes_departed.add(i)
+                except KeyError:
+                    # Key doesn't exist yet
+                    pass
+                except Exception as check_e:
+                    logger.debug("Error checking key existence: %s", check_e)
+                    sched_yield()
+
+            # Short sleep to avoid tight polling
+            if len(processes_departed) < self.world_size:
+                sched_yield()
+
+        # Clean up keys to avoid leaking memory in the store
         for i in range(self.world_size):
-            self.broadcast_obj(None, src=i)
+            try:
+                self.store.delete_key(f"arrival_{barrier_id}_{i}")
+            except Exception:
+                logger.debug("Error deleting key: %s",
+                             f"arrival_{barrier_id}_{i}")
+
+            try:
+                self.store.delete_key(f"departure_{barrier_id}_{i}")
+            except Exception:
+                logger.debug("Error deleting key: %s",
+                             f"departure_{barrier_id}_{i}")

     @staticmethod
     def create(
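
For context, a minimal sketch of exercising the reworked barrier end to end on one machine. The port number and world size are illustrative assumptions (the test above obtains a free port via get_open_port() and broadcasts it instead of hardcoding one):

import multiprocessing

from vllm.distributed.utils import StatelessProcessGroup

def worker(rank: int, world_size: int, port: int) -> None:
    # Rank 0 hosts the TCPStore; the other ranks connect to it.
    pg = StatelessProcessGroup.create("127.0.0.1", port, rank, world_size)
    # Every rank blocks until all ranks have arrived; rank 0 additionally
    # waits for all departures, so tearing down the TCPStore right after
    # this call no longer races with slower ranks.
    pg.barrier()
    print(f"rank {rank} passed the barrier")

if __name__ == "__main__":
    world_size, port = 4, 29555  # illustrative values, not from the commit
    procs = [
        multiprocessing.Process(target=worker, args=(r, world_size, port))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()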
