
Commit c67453d

mgoin authored and xuebwang-amd committed
[UX] Enforce valid choices for envs like VLLM_ATTENTION_BACKEND, etc (vllm-project#24761)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent db18c86 commit c67453d

1 file changed (+77, -24)

vllm/envs.py

Lines changed: 77 additions & 24 deletions
@@ -6,7 +6,7 @@
 import os
 import sys
 import tempfile
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
 
 if TYPE_CHECKING:
     VLLM_HOST_IP: str = ""
@@ -56,11 +56,12 @@
     VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True
     VLLM_USE_RAY_SPMD_WORKER: bool = False
     VLLM_USE_RAY_COMPILED_DAG: bool = False
-    VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto"
+    VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl",
+                                                    "shm"] = "auto"
     VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
     VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
     VLLM_XLA_USE_SPMD: bool = False
-    VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
+    VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork"
     VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
     VLLM_VIDEO_FETCH_TIMEOUT: int = 30
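Note: these Literal annotations sit under if TYPE_CHECKING:, so they cost nothing at runtime; their value is that static type checkers now see the closed set of accepted strings. A minimal standalone sketch of what that buys (not part of the commit; checked with a tool such as mypy):

    from typing import Literal

    VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork"

    VLLM_WORKER_MULTIPROC_METHOD = "spawn"   # accepted
    VLLM_WORKER_MULTIPROC_METHOD = "thread"  # flagged by the type checker:
                                             # not one of the Literal values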
@@ -77,7 +78,8 @@
     VLLM_DOCKER_BUILD_CONTEXT: bool = False
     VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False
     VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False
-    CMAKE_BUILD_TYPE: Optional[str] = None
+    CMAKE_BUILD_TYPE: Optional[Literal["Debug", "Release",
+                                       "RelWithDebInfo"]] = None
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_RPC_TIMEOUT: int = 10000  # ms
@@ -140,22 +142,25 @@
     VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
-    VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
+    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput",
+                                         "latency"] = "throughput"
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
-    VLLM_ALL2ALL_BACKEND: str = "naive"
+    VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", "deepep_high_throughput",
+                                  "deepep_low_latency"] = "naive"
     VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
     VLLM_SLEEP_WHEN_IDLE: bool = False
     VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
     VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
-    VLLM_KV_CACHE_LAYOUT: Optional[str] = None
+    VLLM_KV_CACHE_LAYOUT: Optional[Literal["NHD", "HND"]] = None
     VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
     VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
-    VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
+    VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal["FP", "INT8", "INT6", "INT4",
+                                                 "NONE"] = "NONE"
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
@@ -207,6 +212,48 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
     return bool(int(value))
 
 
+def env_with_choices(
+        env_name: str,
+        default: Optional[str],
+        choices: Union[list[str], Callable[[], list[str]]],
+        case_sensitive: bool = True) -> Callable[[], Optional[str]]:
+    """
+    Create a lambda that validates environment variable against allowed choices
+
+    Args:
+        env_name: Name of the environment variable
+        default: Default value if not set (can be None)
+        choices: List of valid string options or callable that returns list
+        case_sensitive: Whether validation should be case sensitive
+
+    Returns:
+        Lambda function for environment_variables dict
+    """
+
+    def _get_validated_env() -> Optional[str]:
+        value = os.getenv(env_name)
+        if value is None:
+            return default
+
+        # Resolve choices if it's a callable (for lazy loading)
+        actual_choices = choices() if callable(choices) else choices
+
+        if not case_sensitive:
+            check_value = value.lower()
+            check_choices = [choice.lower() for choice in actual_choices]
+        else:
+            check_value = value
+            check_choices = actual_choices
+
+        if check_value not in check_choices:
+            raise ValueError(f"Invalid value '{value}' for {env_name}. "
+                             f"Valid options: {actual_choices}.")
+
+        return value
+
+    return _get_validated_env
+
+
 def get_vllm_port() -> Optional[int]:
     """Get the port from VLLM_PORT environment variable.
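The helper returns a zero-argument callable rather than a value, so it can be stored directly in the environment_variables dict and validation runs when the variable is first read, not at import time. A standalone sketch of the resulting behavior (DEMO_MODE is a hypothetical variable, used only for illustration):

    import os

    getter = env_with_choices("DEMO_MODE", "off", ["off", "on"],
                              case_sensitive=False)

    os.environ["DEMO_MODE"] = "ON"
    print(getter())  # prints "ON": matched against "on" case-insensitively,
                     # and the original spelling is returned unchanged

    os.environ["DEMO_MODE"] = "fast"
    getter()  # raises ValueError: Invalid value 'fast' for DEMO_MODE.
              # Valid options: ['off', 'on'].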
@@ -287,7 +334,8 @@ def get_vllm_port() -> Optional[int]:
     # If not set, defaults to "Debug" or "RelWithDebInfo"
     # Available options: "Debug", "Release", "RelWithDebInfo"
     "CMAKE_BUILD_TYPE":
-    lambda: os.getenv("CMAKE_BUILD_TYPE"),
+    env_with_choices("CMAKE_BUILD_TYPE", None,
+                     ["Debug", "Release", "RelWithDebInfo"]),
 
     # If set, vllm will print verbose logs during installation
     "VERBOSE":
@@ -476,7 +524,7 @@ def get_vllm_port() -> Optional[int]:
     lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")),
 
     # Backend for attention computation
-    # Available options:
+    # Example options:
     # - "TORCH_SDPA": use torch.nn.MultiheadAttention
     # - "FLASH_ATTN": use FlashAttention
     # - "XFORMERS": use XFormers
@@ -486,8 +534,11 @@ def get_vllm_port() -> Optional[int]:
     # - "FLASH_ATTN_MLA": use FlashAttention for MLA
     # - "FLASHINFER_MLA": use FlashInfer for MLA
     # - "CUTLASS_MLA": use CUTLASS for MLA
+    # All possible options loaded dynamically from _Backend enum
     "VLLM_ATTENTION_BACKEND":
-    lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
+    env_with_choices("VLLM_ATTENTION_BACKEND", None,
+                     lambda: list(__import__('vllm.platforms.interface', \
+                        fromlist=['_Backend'])._Backend.__members__.keys())),
 
     # If set, vllm will use flashinfer sampler
     "VLLM_USE_FLASHINFER_SAMPLER":
@@ -550,7 +601,8 @@ def get_vllm_port() -> Optional[int]:
     # - "shm": use shared memory and gRPC for communication
     # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
-    lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto"),
+    env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto",
+                     ["auto", "nccl", "shm"]),
 
     # If the env var is set, it enables GPU communication overlap
     # (experimental feature) in Ray's Compiled Graph. This flag is ignored if
@@ -569,7 +621,8 @@ def get_vllm_port() -> Optional[int]:
     # Use dedicated multiprocess context for workers.
     # Both spawn and fork work
     "VLLM_WORKER_MULTIPROC_METHOD":
-    lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "fork"),
+    env_with_choices("VLLM_WORKER_MULTIPROC_METHOD", "fork",
+                     ["spawn", "fork"]),
 
     # Path to the cache for storing downloaded assets
     "VLLM_ASSETS_CACHE":
@@ -833,7 +886,8 @@ def get_vllm_port() -> Optional[int]:
     # Choice of quantization level: FP, INT8, INT6, INT4 or NONE
     # Recommended for large models to get allreduce
     "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION":
-    lambda: os.getenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "NONE").upper(),
+    env_with_choices("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "NONE",
+                     ["FP", "INT8", "INT6", "INT4", "NONE"]),
 
     # Custom quick allreduce kernel for MI3* cards
     # Due to the lack of the bfloat16 asm instruction, bfloat16
@@ -1075,21 +1129,20 @@ def get_vllm_port() -> Optional[int]:
     # - "deepep_high_throughput", use deepep high-throughput kernels
     # - "deepep_low_latency", use deepep low-latency kernels
     "VLLM_ALL2ALL_BACKEND":
-    lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
+    env_with_choices("VLLM_ALL2ALL_BACKEND", "naive",
+                     ["naive", "pplx",
+                      "deepep_high_throughput", "deepep_low_latency"]),
 
-    # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. Both
-    # require compute capability 10.0 or above.
+    # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support.
+    # Both require compute capability 10.0 or above.
     # Available options:
     # - "throughput": [default]
     #   Uses CUTLASS kernels optimized for high-throughput batch inference.
     # - "latency":
     #   Uses TensorRT-LLM kernels optimized for low-latency inference.
-    # To set this backend, define the environment variable:
-    # export VLLM_FLASHINFER_MOE_BACKEND=latency.
-    # If not set, defaults to "throughput".
-    "VLLM_FLASHINFER_MOE_BACKEND": lambda: os.getenv(
-        "VLLM_FLASHINFER_MOE_BACKEND", "throughput"
-    ),
+    "VLLM_FLASHINFER_MOE_BACKEND":
+    env_with_choices("VLLM_FLASHINFER_MOE_BACKEND", "throughput",
+                     ["throughput", "latency"]),
 
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
@@ -1145,7 +1198,7 @@ def get_vllm_port() -> Optional[int]:
     # leave the layout choice to the backend. Mind that backends may only
     # implement and support a subset of all possible layouts.
     "VLLM_KV_CACHE_LAYOUT":
-    lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None),
+    env_with_choices("VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"]),
 
     # Enable checking whether the generated logits contain NaNs,
     # indicating corrupted output. Useful for debugging low level bugs
