Skip to content

Commit 65397e4

Browse files
authored
[Bugfix] Allow CUDA_VISIBLE_DEVICES='' in Platform.device_id_to_physical_device_id (#18979)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
1 parent 9502c38 commit 65397e4

File tree

3 files changed

+114
-10
lines changed

3 files changed

+114
-10
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import pytest
4+
5+
from vllm.engine.arg_utils import EngineArgs
6+
from vllm.model_executor.layers.quantization.quark.utils import deep_compare
7+
8+
9+
def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
10+
"""Test that configs created with normal (untouched) CUDA_VISIBLE_DEVICES
11+
and CUDA_VISIBLE_DEVICES="" are equivalent. This ensures consistent
12+
behavior regardless of whether GPU visibility is disabled via empty string
13+
or left in its normal state.
14+
"""
15+
16+
def create_config():
17+
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
18+
trust_remote_code=True)
19+
return engine_args.create_engine_config()
20+
21+
# Create config with CUDA_VISIBLE_DEVICES set normally
22+
normal_config = create_config()
23+
24+
# Create config with CUDA_VISIBLE_DEVICES=""
25+
with monkeypatch.context() as m:
26+
m.setenv("CUDA_VISIBLE_DEVICES", "")
27+
empty_config = create_config()
28+
29+
normal_config_dict = vars(normal_config)
30+
empty_config_dict = vars(empty_config)
31+
32+
# Remove instance_id before comparison as it's expected to be different
33+
normal_config_dict.pop("instance_id", None)
34+
empty_config_dict.pop("instance_id", None)
35+
36+
assert deep_compare(normal_config_dict, empty_config_dict), (
37+
"Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\""
38+
" should be equivalent")

tests/v1/engine/test_engine_core_client.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import uuid
99
from threading import Thread
1010
from typing import Optional
11+
from unittest.mock import MagicMock
1112

1213
import pytest
14+
import torch
1315
from transformers import AutoTokenizer
1416

1517
from tests.utils import multi_gpu_test
@@ -517,3 +519,72 @@ def kill_first_child():
517519
)
518520

519521
assert "Engine core initialization failed" in str(e_info.value)
522+
523+
524+
@create_new_process_for_each_test()
525+
def test_engine_core_proc_instantiation_cuda_empty(
526+
monkeypatch: pytest.MonkeyPatch):
527+
"""
528+
Test that EngineCoreProc can be instantiated when CUDA_VISIBLE_DEVICES
529+
is empty. This ensures the engine frontend does not need access to GPUs.
530+
"""
531+
532+
from vllm.v1.engine.core import EngineCoreProc
533+
from vllm.v1.executor.abstract import Executor
534+
535+
# Create a simple mock executor instead of a complex custom class
536+
mock_executor_class = MagicMock(spec=Executor)
537+
538+
def create_mock_executor(vllm_config):
539+
mock_executor = MagicMock()
540+
541+
# Only implement the methods that are actually called during init
542+
from vllm.v1.kv_cache_interface import FullAttentionSpec
543+
mock_spec = FullAttentionSpec(block_size=16,
544+
num_kv_heads=1,
545+
head_size=64,
546+
dtype=torch.float16,
547+
use_mla=False)
548+
549+
mock_executor.get_kv_cache_specs.return_value = [{
550+
"default": mock_spec
551+
}]
552+
mock_executor.determine_available_memory.return_value = [
553+
1024 * 1024 * 1024
554+
]
555+
mock_executor.initialize_from_config.return_value = None
556+
mock_executor.max_concurrent_batches = 1
557+
558+
return mock_executor
559+
560+
mock_executor_class.side_effect = create_mock_executor
561+
562+
with monkeypatch.context() as m:
563+
m.setenv("VLLM_USE_V1", "1")
564+
m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices
565+
566+
from vllm.v1.utils import EngineZmqAddresses
567+
568+
def mock_startup_handshake(self, handshake_socket, on_head_node,
569+
parallel_config):
570+
return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
571+
outputs=["tcp://127.0.0.1:5556"],
572+
coordinator_input=None,
573+
coordinator_output=None)
574+
575+
# Background processes are not important here
576+
m.setattr(EngineCoreProc, "startup_handshake", mock_startup_handshake)
577+
578+
vllm_config = EngineArgs(
579+
model="deepseek-ai/DeepSeek-V2-Lite",
580+
trust_remote_code=True).create_engine_config()
581+
engine_core_proc = EngineCoreProc(
582+
vllm_config=vllm_config,
583+
on_head_node=True,
584+
handshake_address="tcp://127.0.0.1:12345",
585+
executor_class=mock_executor_class,
586+
log_stats=False,
587+
engine_index=0,
588+
)
589+
590+
engine_core_proc.shutdown()

vllm/platforms/interface.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -173,17 +173,12 @@ def is_sleep_mode_available(self) -> bool:
173173

174174
@classmethod
175175
def device_id_to_physical_device_id(cls, device_id: int):
176-
if cls.device_control_env_var in os.environ:
176+
# Treat empty device control env var as unset. This is a valid
177+
# configuration in Ray setups where the engine is launched in
178+
# a CPU-only placement group located on a GPU node.
179+
if cls.device_control_env_var in os.environ and os.environ[
180+
cls.device_control_env_var] != "":
177181
device_ids = os.environ[cls.device_control_env_var].split(",")
178-
if device_ids == [""]:
179-
msg = (f"{cls.device_control_env_var} is set to empty string, "
180-
"which means current platform support is disabled. If "
181-
"you are using ray, please unset the environment "
182-
f"variable `{cls.device_control_env_var}` inside the "
183-
"worker/actor. Check "
184-
"https://github.com/vllm-project/vllm/issues/8402 for "
185-
"more information.")
186-
raise RuntimeError(msg)
187182
physical_device_id = device_ids[device_id]
188183
return int(physical_device_id)
189184
else:

0 commit comments

Comments
 (0)