[Perf][CLI] Improve overall startup time #19941

Merged 5 commits on Jun 22, 2025
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
@@ -115,6 +115,11 @@ repos:
entry: python tools/check_spdx_header.py
language: python
types: [python]
- id: check-root-lazy-imports
name: Check root lazy imports
entry: python tools/check_init_lazy_imports.py
language: python
types: [python]
- id: check-filenames
name: Check for spaces in all filenames
entry: bash
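For local verification, the new hook can be run on demand with standard pre-commit tooling (typical usage, not part of this diff):

    pre-commit run check-root-lazy-imports --all-files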
108 changes: 108 additions & 0 deletions tools/check_init_lazy_imports.py
@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Ensure we perform lazy loading in vllm/__init__.py.
i.e: appears only within the ``if typing.TYPE_CHECKING:`` guard,
**except** for a short whitelist.
"""

from __future__ import annotations

import ast
import pathlib
import sys
from collections.abc import Iterable
from typing import Final

REPO_ROOT: Final = pathlib.Path(__file__).resolve().parent.parent
INIT_PATH: Final = REPO_ROOT / "vllm" / "__init__.py"

# If you need to add items to the whitelist, do it here.
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset({
"vllm.env_override",
})
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset({
".version",
})


def _is_internal(name: str | None, *, level: int = 0) -> bool:
if level > 0:
return True
if name is None:
return False
return name.startswith("vllm.") or name == "vllm"


def _fail(violations: Iterable[tuple[int, str]]) -> None:
print("ERROR: Disallowed eager imports in vllm/__init__.py:\n",
file=sys.stderr)
for lineno, msg in violations:
print(f" Line {lineno}: {msg}", file=sys.stderr)
sys.exit(1)


def main() -> None:
source = INIT_PATH.read_text(encoding="utf-8")
tree = ast.parse(source, filename=str(INIT_PATH))

violations: list[tuple[int, str]] = []

class Visitor(ast.NodeVisitor):

def __init__(self) -> None:
super().__init__()
self._in_type_checking = False

def visit_If(self, node: ast.If) -> None:
guard_is_type_checking = False
test = node.test
if isinstance(test, ast.Attribute) and isinstance(
test.value, ast.Name):
guard_is_type_checking = (test.value.id == "typing"
and test.attr == "TYPE_CHECKING")
elif isinstance(test, ast.Name):
guard_is_type_checking = test.id == "TYPE_CHECKING"

if guard_is_type_checking:
prev = self._in_type_checking
self._in_type_checking = True
for child in node.body:
self.visit(child)
self._in_type_checking = prev
for child in node.orelse:
self.visit(child)
else:
self.generic_visit(node)

def visit_Import(self, node: ast.Import) -> None:
if self._in_type_checking:
return
for alias in node.names:
module_name = alias.name
if _is_internal(
module_name) and module_name not in ALLOWED_IMPORTS:
violations.append((
node.lineno,
f"import '{module_name}' must be inside typing.TYPE_CHECKING", # noqa: E501
))

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if self._in_type_checking:
return
module_as_written = ("." * node.level) + (node.module or "")
if _is_internal(
node.module, level=node.level
) and module_as_written not in ALLOWED_FROM_MODULES:
violations.append((
node.lineno,
f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING", # noqa: E501
))

Visitor().visit(tree)

if violations:
_fail(violations)


if __name__ == "__main__":
main()
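As a sketch of what the checker flags (module paths below are illustrative), an eager internal import at the top level of vllm/__init__.py fails, while the guarded form passes:

    import vllm.engine.llm_engine          # flagged: eager internal import

    if typing.TYPE_CHECKING:
        import vllm.engine.llm_engine      # allowed: type-checking only

Third-party imports and the whitelisted modules (vllm.env_override and .version) are never flagged.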
75 changes: 59 additions & 16 deletions vllm/__init__.py
@@ -1,29 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""

# version.py should be an independent module, and we always import it
# first. This assumption is critical for some customizations.
from .version import __version__, __version_tuple__ # isort:skip

import typing

# The environment-variable overrides should be imported before any other
# vLLM modules to ensure the variables are set before anything else is
# imported.
import vllm.env_override # isort:skip # noqa: F401

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
import vllm.env_override # noqa: F401

MODULE_ATTRS = {
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
"EngineArgs": ".engine.arg_utils:EngineArgs",
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
"LLMEngine": ".engine.llm_engine:LLMEngine",
"LLM": ".entrypoints.llm:LLM",
"initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
"PromptType": ".inputs:PromptType",
"TextPrompt": ".inputs:TextPrompt",
"TokensPrompt": ".inputs:TokensPrompt",
"ModelRegistry": ".model_executor.models:ModelRegistry",
"SamplingParams": ".sampling_params:SamplingParams",
"PoolingParams": ".pooling_params:PoolingParams",
"ClassificationOutput": ".outputs:ClassificationOutput",
"ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
"CompletionOutput": ".outputs:CompletionOutput",
"EmbeddingOutput": ".outputs:EmbeddingOutput",
"EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
"PoolingOutput": ".outputs:PoolingOutput",
"PoolingRequestOutput": ".outputs:PoolingRequestOutput",
"RequestOutput": ".outputs:RequestOutput",
"ScoringOutput": ".outputs:ScoringOutput",
"ScoringRequestOutput": ".outputs:ScoringRequestOutput",
}

if typing.TYPE_CHECKING:
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (ClassificationOutput,
ClassificationRequestOutput, CompletionOutput,
EmbeddingOutput, EmbeddingRequestOutput,
PoolingOutput, PoolingRequestOutput,
RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
else:

def __getattr__(name: str) -> typing.Any:
from importlib import import_module

if name in MODULE_ATTRS:
module_name, attr_name = MODULE_ATTRS[name].split(":")
module = import_module(module_name, __package__)
return getattr(module, attr_name)
else:
raise AttributeError(
f'module {__package__} has no attribute {name}')


__all__ = [
"__version__",
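The replacement relies on PEP 562's module-level __getattr__: a symbol's module is imported only on first attribute access, so `import vllm` no longer pulls in the whole engine stack. A minimal self-contained sketch of the same pattern (package and module names are illustrative, not vLLM's):

    # lazy_pkg/__init__.py
    import typing

    _MODULE_ATTRS = {"Engine": ".engine:Engine"}  # public attr -> "module:attr"

    if typing.TYPE_CHECKING:
        from .engine import Engine  # real import, for type checkers only
    else:

        def __getattr__(name: str) -> typing.Any:
            # PEP 562: called only when normal attribute lookup fails,
            # so the submodule is imported on first access.
            from importlib import import_module

            if name in _MODULE_ATTRS:
                module_name, attr_name = _MODULE_ATTRS[name].split(":")
                return getattr(import_module(module_name, __package__),
                               attr_name)
            raise AttributeError(
                f"module {__package__!r} has no attribute {name!r}")

One trade-off: `from vllm import LLM` still works at runtime, but static tools that do not understand PEP 562 need the TYPE_CHECKING branch to resolve the names.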
30 changes: 16 additions & 14 deletions vllm/config.py
@@ -28,7 +28,7 @@
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
from typing_extensions import deprecated, runtime_checkable
from typing_extensions import Self, deprecated, runtime_checkable

import vllm.envs as envs
from vllm import version
@@ -1537,7 +1537,6 @@ def compute_hash(self) -> str:
def __post_init__(self) -> None:
self.swap_space_bytes = self.swap_space * GiB_bytes

self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()

@@ -1546,7 +1545,8 @@ def metrics_info(self):
# metrics info
return {key: str(value) for key, value in self.__dict__.items()}

def _verify_args(self) -> None:
@model_validator(mode='after')
def _verify_args(self) -> Self:
if self.cpu_offload_gb < 0:
raise ValueError("CPU offload space must be non-negative"
f", but got {self.cpu_offload_gb}")
@@ -1556,6 +1556,8 @@ def _verify_args(self) -> None:
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")

return self

def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
@@ -1942,15 +1944,14 @@ def __post_init__(self) -> None:
if self.distributed_executor_backend is None and self.world_size == 1:
self.distributed_executor_backend = "uni"

self._verify_args()

@property
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (
isinstance(self.distributed_executor_backend, type)
and self.distributed_executor_backend.uses_ray)

def _verify_args(self) -> None:
@model_validator(mode='after')
def _verify_args(self) -> Self:
# Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform
@@ -1977,8 +1978,7 @@ def _verify_args(self) -> None:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")

assert isinstance(self.worker_extension_cls, str), (
"worker_extension_cls must be a string (qualified class name).")
return self


PreemptionMode = Literal["swap", "recompute"]
@@ -2202,9 +2202,8 @@ def __post_init__(self) -> None:
self.max_num_partial_prefills, self.max_long_partial_prefills,
self.long_prefill_token_threshold)

self._verify_args()

def _verify_args(self) -> None:
@model_validator(mode='after')
def _verify_args(self) -> Self:
if (self.max_num_batched_tokens < self.max_model_len
and not self.chunked_prefill_enabled):
raise ValueError(
@@ -2263,6 +2262,8 @@ def _verify_args(self) -> None:
"must be greater than or equal to 1 and less than or equal to "
f"max_num_partial_prefills ({self.max_num_partial_prefills}).")

return self

@property
def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1
@@ -2669,8 +2670,6 @@ def __post_init__(self):
if self.posterior_alpha is None:
self.posterior_alpha = 0.3

self._verify_args()

@staticmethod
def _maybe_override_draft_max_model_len(
speculative_max_model_len: Optional[int],
@@ -2761,7 +2760,8 @@ def create_draft_parallel_config(

return draft_parallel_config

def _verify_args(self) -> None:
@model_validator(mode='after')
def _verify_args(self) -> Self:
if self.num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
@@ -2812,6 +2812,8 @@ def _verify_args(self) -> None:
"Eagle3 is only supported for Llama models. "
f"Got {self.target_model_config.hf_text_config.model_type=}")

return self

@property
def num_lookahead_slots(self) -> int:
"""The number of additional slots the scheduler should allocate per
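The config.py changes replace manual `self._verify_args()` calls in `__post_init__` with pydantic's `@model_validator(mode='after')`, which runs automatically once the instance is constructed and must return the validated instance. A minimal sketch of the pattern (vLLM's configs are pydantic-validated dataclasses; a plain BaseModel and an illustrative field are used here for brevity):

    from pydantic import BaseModel, model_validator
    from typing_extensions import Self

    class CacheConfigSketch(BaseModel):
        gpu_memory_utilization: float = 0.9

        @model_validator(mode="after")
        def _verify_args(self) -> Self:
            # Runs after field validation; must return the instance.
            if self.gpu_memory_utilization > 1.0:
                raise ValueError(
                    "GPU memory utilization must be less than 1.0. Got "
                    f"{self.gpu_memory_utilization}.")
            return self

    CacheConfigSketch(gpu_memory_utilization=0.9)    # ok
    # CacheConfigSketch(gpu_memory_utilization=1.5)  # raises ValidationError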
15 changes: 14 additions & 1 deletion vllm/engine/arg_utils.py
@@ -3,7 +3,9 @@

# yapf: disable
import argparse
import copy
import dataclasses
import functools
import json
import sys
import threading
@@ -168,7 +170,8 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
return type_hints


def get_kwargs(cls: ConfigType) -> dict[str, Any]:
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
@@ -269,6 +272,16 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
return kwargs


def get_kwargs(cls: ConfigType) -> dict[str, Any]:
"""Return argparse kwargs for the given Config dataclass.

The heavy computation is cached via functools.lru_cache, and a deep copy
is returned so callers can mutate the dictionary without affecting the
cached version.
"""
return copy.deepcopy(_compute_kwargs(cls))


@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
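The split works because `_compute_kwargs` is deterministic per config class, while callers of `get_kwargs` may tweak the returned dict before handing it to argparse; deep-copying keeps those tweaks out of the cache. A minimal sketch of the cache-then-copy pattern (the function body is an illustrative stand-in):

    import copy
    import functools

    @functools.lru_cache(maxsize=30)
    def _compute_kwargs(key: str) -> dict[str, dict[str, str]]:
        # Stand-in for the expensive introspection of a config class.
        return {key: {"help": f"options for {key}"}}

    def get_kwargs(key: str) -> dict[str, dict[str, str]]:
        # Deep copy so callers can mutate the result without
        # corrupting the cached value.
        return copy.deepcopy(_compute_kwargs(key))

    kwargs = get_kwargs("model")
    kwargs["model"]["help"] = "mutated locally"  # cache unaffected
    assert get_kwargs("model")["model"]["help"] == "options for model"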
9 changes: 7 additions & 2 deletions vllm/entrypoints/cli/benchmark/main.py
@@ -1,10 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import argparse
import typing

from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser

if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser


class BenchmarkSubcommand(CLISubcommand):
@@ -23,7 +29,6 @@ def validate(self, args: argparse.Namespace) -> None:
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:

bench_parser = subparsers.add_parser(
self.name,
help=self.help,
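The benchmark CLI change applies the same startup trick at module scope: with `from __future__ import annotations`, annotations are never evaluated at runtime, so FlexibleArgumentParser can live under `typing.TYPE_CHECKING` and is imported only by type checkers. A minimal sketch of the pattern (argparse stands in for a heavy module):

    from __future__ import annotations

    import typing

    if typing.TYPE_CHECKING:
        # Imported for annotations only; never executed at runtime.
        from argparse import ArgumentParser

    def build_parser(name: str) -> ArgumentParser:
        # Deferred to call time, so importing this module stays cheap.
        from argparse import ArgumentParser
        return ArgumentParser(prog=name)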