Skip to content

[MISC][Bugfix] Use less CPU when message queue has been empty for some time #16226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions tests/basic_correctness/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,21 @@ def test_models(
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model, distributed_executor_backend, attention_backend, "
"test_suite", [
("distilbert/distilgpt2", "ray", "", "L4"),
("distilbert/distilgpt2", "mp", "", "L4"),
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4"),
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4"),
("distilbert/distilgpt2", "ray", "", "A100"),
("distilbert/distilgpt2", "mp", "", "A100"),
("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
"test_suite, extra_env", [
("distilbert/distilgpt2", "ray", "", "L4", {}),
("distilbert/distilgpt2", "mp", "", "L4", {}),
("distilbert/distilgpt2", "ray", "", "L4", {
"VLLM_SLEEP_WHEN_IDLE": "1"
}),
("distilbert/distilgpt2", "mp", "", "L4", {
"VLLM_SLEEP_WHEN_IDLE": "1"
}),
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
("distilbert/distilgpt2", "ray", "", "A100", {}),
("distilbert/distilgpt2", "mp", "", "A100", {}),
("distilbert/distilgpt2", "mp", "FLASHINFER", "A100", {}),
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100", {}),
])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models_distributed(
Expand All @@ -148,6 +154,7 @@ def test_models_distributed(
distributed_executor_backend: str,
attention_backend: str,
test_suite: str,
extra_env: dict[str, str],
enable_prompt_embeds: bool,
) -> None:

Expand All @@ -173,6 +180,9 @@ def test_models_distributed(
attention_backend,
)

for k, v in extra_env.items():
monkeypatch_context.setenv(k, v)

dtype = "half"
max_tokens = 5

Expand Down
47 changes: 45 additions & 2 deletions vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,43 @@
logger = init_logger(__name__)


class SpinTimer:

def record_activity(self):
pass

def spin(self):
sched_yield()


class SpinSleepTimer(SpinTimer):
"""
In setups which have long inactivity periods it is desirable to reduce
system power consumption when vllm does nothing. This would lead to more
CPU thermal headroom when a request eventually comes, especially when
multiple GPUs are connected as each GPU would otherwise pin one thread at
100% CPU usage.

The simplest solution is to reduce polling frequency when there is no
activity for a certain period of time.
"""

def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
self.last_activity = time.monotonic()
self.busy_loop_s = busy_loop_s
self.wait_sleep_s = wait_sleep_s

def record_activity(self):
self.last_activity = time.monotonic()

def spin(self):
curr_time = time.monotonic()
if curr_time >= self.last_activity + self.busy_loop_s:
time.sleep(self.wait_sleep_s)
else:
sched_yield()


class ShmRingBuffer:

def __init__(self,
Expand All @@ -42,7 +79,7 @@ def __init__(self,
of items that can be stored in the buffer are known in advance.
In this case, we don't need to synchronize the access to
the buffer.

Buffer memory layout:
data metadata
| |
Expand Down Expand Up @@ -238,6 +275,7 @@ def __init__(
self.local_reader_rank = -1
# rank does not matter for remote readers
self._is_remote_reader = False
self._read_spin_timer = SpinTimer()

self.handle = Handle(
local_reader_ranks=local_reader_ranks,
Expand Down Expand Up @@ -276,6 +314,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
self.local_socket.connect(socket_addr)

self.remote_socket = None

self._read_spin_timer = SpinSleepTimer(
) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
else:
self.buffer = None # type: ignore
self.current_idx = -1
Expand Down Expand Up @@ -407,7 +448,7 @@ def acquire_read(self,
# we need to wait until it is written

# Release the processor to other threads
sched_yield()
self._read_spin_timer.spin()

# if we wait for a long time, log a message
if (time.monotonic() - start_time
Expand Down Expand Up @@ -438,6 +479,8 @@ def acquire_read(self,
metadata_buffer[self.local_reader_rank + 1] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks

self._read_spin_timer.record_activity()
break

def enqueue(self, obj, timeout: Optional[float] = None):
Expand Down
6 changes: 6 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_SLEEP_WHEN_IDLE: bool = False


def get_default_cache_root():
Expand Down Expand Up @@ -841,6 +842,11 @@ def get_vllm_port() -> Optional[int]:
# Regex timeout for use by the vLLM tool parsing plugins.
"VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS":
lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")),

# Reduce CPU usage when vLLM is idle. Enabling this will incur small
# latency penalty when a request eventually comes.
"VLLM_SLEEP_WHEN_IDLE":
lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))),
}

# --8<-- [end:env-vars-definition]
Expand Down
Loading