[Frontend] Reduce vLLM's import time #15128

Open: wants to merge 1 commit into main
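For context (not part of the patch itself): the diff below cuts the work done at `import vllm.config` time in three ways. It adds `from __future__ import annotations` so names used only in type hints no longer need runtime imports, it moves heavy imports such as `transformers.PretrainedConfig`, Ray's `PlacementGroup`, and the OTel tracing helpers under `TYPE_CHECKING` or into the functions that use them, and it replaces eager `vllm.model_executor` imports with `LazyLoader` proxies. The snippet below is only a minimal sketch of the lazy-loader idea, assuming the familiar TensorFlow-style pattern; it is not a copy of `vllm.utils.LazyLoader`, and the `me_quant` usage line at the end is illustrative.

```python
import importlib
import types


class LazyLoader(types.ModuleType):
    """Module proxy that defers importing `name` until first attribute access."""

    def __init__(self, local_name: str, parent_module_globals: dict, name: str):
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals
        super().__init__(name)

    def _load(self) -> types.ModuleType:
        # Import for real, then swap the proxy out of the caller's globals
        # so later lookups go straight to the module.
        module = importlib.import_module(self.__name__)
        self._parent_module_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: str):
        return getattr(self._load(), item)


# Illustrative use, in the spirit of the diff below: nothing under
# vllm.model_executor is imported until an attribute is first touched.
me_quant = LazyLoader("me_quant", globals(),
                      "vllm.model_executor.layers.quantization")
```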
93 changes: 48 additions & 45 deletions vllm/config.py
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import ast
import copy
import enum
@@ -22,42 +24,45 @@
import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms import CpuArchEnum
from vllm.sampling_params import GuidedDecodingParams
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
random_uuid, resolve_obj_by_qualname)
from vllm.utils import (GiB_bytes, LayerBlockType, LazyLoader,
cuda_device_count_stateless, get_cpu_memory,
get_open_port, is_torch_equal_or_newer, random_uuid,
resolve_obj_by_qualname)

if TYPE_CHECKING:
from _typeshed import DataclassInstance
from ray.util.placement_group import PlacementGroup
from transformers import PretrainedConfig

from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.loader import BaseModelLoader

ConfigType = type[DataclassInstance]
HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig]]
else:
QuantizationConfig = None
HfOverrides = None
ConfigType = type

me_quant = LazyLoader("model_executor", globals(),
"vllm.model_executor.layers.quantization")
me_models = LazyLoader("model_executor", globals(),
"vllm.model_executor.models")
logger = init_logger(__name__)

ConfigT = TypeVar("ConfigT", bound=ConfigType)
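The hunk above pairs `from __future__ import annotations` with a `TYPE_CHECKING` block. Below is a self-contained sketch of that pattern; `some_heavy_package` and `HeavyConfig` are placeholders, not vLLM dependencies.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never imported at runtime.
    from some_heavy_package import HeavyConfig  # placeholder name


def configure(cfg: HeavyConfig) -> None:
    # With postponed evaluation of annotations (PEP 563), `HeavyConfig`
    # stays a string at runtime, so the heavy import above is never
    # executed unless a type checker is running.
    print(type(cfg).__name__)
```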
@@ -89,9 +94,6 @@
for task in tasks
}

HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig]]


class SupportsHash(Protocol):

@@ -365,7 +367,7 @@ def __init__(
mm_processor_kwargs: Optional[dict[str, Any]] = None,
disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
override_pooler_config: Optional[PoolerConfig] = None,
logits_processor_pattern: Optional[str] = None,
generation_config: str = "auto",
enable_sleep_mode: bool = False,
@@ -548,7 +550,7 @@ def __init__(

@property
def registry(self):
return ModelRegistry
return me_models.ModelRegistry

@property
def architectures(self) -> list[str]:
@@ -581,7 +583,7 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str,

def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[dict[str, int]]
) -> Optional["MultiModalConfig"]:
) -> Optional[MultiModalConfig]:
if self.registry.is_multimodal_model(self.architectures):
return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})

@@ -597,8 +599,8 @@ def _get_encoder_config(self):

def _init_pooler_config(
self,
override_pooler_config: Optional["PoolerConfig"],
) -> Optional["PoolerConfig"]:
override_pooler_config: Optional[PoolerConfig],
) -> Optional[PoolerConfig]:

if self.runner_type == "pooling":
user_config = override_pooler_config or PoolerConfig()
@@ -749,7 +751,8 @@ def _parse_quant_hf_config(self):
return quant_cfg

def _verify_quantization(self) -> None:
supported_quantization = QUANTIZATION_METHODS
supported_quantization = me_quant.QUANTIZATION_METHODS

optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
@@ -766,8 +769,8 @@ def _verify_quantization(self) -> None:
quant_method = quant_cfg.get("quant_method", "").lower()

# Detect which checkpoint is it
for name in QUANTIZATION_METHODS:
method = get_quantization_config(name)
for name in me_quant.QUANTIZATION_METHODS:
method = me_quant.get_quantization_config(name)
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization)
if quantization_override:
@@ -799,6 +802,8 @@ def _verify_quantization(self) -> None:
"non-quantized models.", self.quantization)

def _verify_cuda_graph(self) -> None:
from vllm.platforms import current_platform

if self.max_seq_len_to_capture is None:
self.max_seq_len_to_capture = self.max_model_len
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
@@ -885,7 +890,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config,

def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
parallel_config: ParallelConfig,
) -> None:

if parallel_config.distributed_executor_backend == "external_launcher":
@@ -1038,7 +1043,7 @@ def get_total_num_kv_heads(self) -> int:
# equal to the number of attention heads.
return self.hf_text_config.num_attention_heads

def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int:
"""Returns the number of KV heads per GPU."""
if self.use_mla:
# When using MLA during decode it becomes MQA
@@ -1052,13 +1057,12 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
return max(1,
total_num_kv_heads // parallel_config.tensor_parallel_size)

def get_num_attention_heads(self,
parallel_config: "ParallelConfig") -> int:
def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int:
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size

def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> tuple[int, int]:
self, parallel_config: ParallelConfig) -> tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
if self.hf_text_config.model_type == "deepseek_mtp":
total_num_hidden_layers = getattr(self.hf_text_config,
@@ -1073,13 +1077,13 @@ def get_layers_start_end_indices(
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
return start, end

def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
def get_num_layers(self, parallel_config: ParallelConfig) -> int:
start, end = self.get_layers_start_end_indices(parallel_config)
return end - start

def get_num_layers_by_block_type(
self,
parallel_config: "ParallelConfig",
parallel_config: ParallelConfig,
block_type: LayerBlockType = LayerBlockType.attention,
) -> int:
# This function relies on 'layers_block_type' in hf_config,
@@ -1132,7 +1136,7 @@ def get_num_layers_by_block_type(

return sum(t == 1 for t in attn_type_list[start:end])

def get_multimodal_config(self) -> "MultiModalConfig":
def get_multimodal_config(self) -> MultiModalConfig:
"""
Get the multimodal configuration of the model.

@@ -1241,7 +1245,7 @@ def runner_type(self) -> RunnerType:
@property
def is_v1_compatible(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_v1_compatible(architectures)
return me_models.ModelRegistry.is_v1_compatible(architectures)

@property
def is_matryoshka(self) -> bool:
@@ -1392,7 +1396,7 @@ def _verify_prefix_caching(self) -> None:

def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
parallel_config: ParallelConfig,
) -> None:
total_cpu_memory = get_cpu_memory()
# FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
@@ -1460,7 +1464,7 @@ class LoadConfig:
"""Configuration for loading the model weights."""

load_format: Union[str, LoadFormat,
"BaseModelLoader"] = LoadFormat.AUTO.value
BaseModelLoader] = LoadFormat.AUTO.value
"""The format of the model weights to load:\n
- "auto" will try to load the weights in the safetensors format and fall
back to the pytorch bin format if safetensors format is not available.\n
@@ -1582,11 +1586,11 @@ def data_parallel_rank_local(self, value: int) -> None:
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

placement_group: Optional["PlacementGroup"] = None
placement_group: Optional[PlacementGroup] = None
"""ray distributed model workers placement group."""

distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
type["ExecutorBase"]]] = None
type[ExecutorBase]]] = None
"""Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product
of pipeline_parallel_size and tensor_parallel_size is less than
@@ -1629,7 +1633,7 @@ def get_next_dp_init_port(self) -> int:
self.data_parallel_master_port += 1
return answer

def stateless_init_dp_group(self) -> "ProcessGroup":
def stateless_init_dp_group(self) -> ProcessGroup:
from vllm.distributed.utils import (
stateless_init_torch_distributed_process_group)

@@ -1644,7 +1648,7 @@ def stateless_init_dp_group(self) -> "ProcessGroup":
return dp_group

@staticmethod
def has_unfinished_dp(dp_group: "ProcessGroup",
def has_unfinished_dp(dp_group: ProcessGroup,
has_unfinished: bool) -> bool:
tensor = torch.tensor([has_unfinished],
dtype=torch.int32,
@@ -2227,7 +2231,7 @@ def compute_hash(self) -> str:
return hash_str

@classmethod
def from_dict(cls, dict_value: dict) -> "SpeculativeConfig":
def from_dict(cls, dict_value: dict) -> SpeculativeConfig:
"""Parse the CLI value for the speculative config."""
return cls(**dict_value)

@@ -2819,7 +2823,7 @@ def compute_hash(self) -> str:
return hash_str

@staticmethod
def from_json(json_str: str) -> "PoolerConfig":
def from_json(json_str: str) -> PoolerConfig:
return PoolerConfig(**json.loads(json_str))


@@ -3176,6 +3180,7 @@ def compute_hash(self) -> str:
return hash_str

def __post_init__(self):
from vllm.tracing import is_otel_available, otel_import_error_traceback
if not is_otel_available() and self.otlp_traces_endpoint is not None:
raise ValueError(
"OpenTelemetry is not available. Unable to configure "
@@ -3239,7 +3244,7 @@ def compute_hash(self) -> str:
return hash_str

@classmethod
def from_cli(cls, cli_value: str) -> "KVTransferConfig":
def from_cli(cls, cli_value: str) -> KVTransferConfig:
"""Parse the CLI value for the kv cache transfer config."""
return KVTransferConfig.model_validate_json(cli_value)

@@ -3476,7 +3481,7 @@ def __repr__(self) -> str:
__str__ = __repr__

@classmethod
def from_cli(cls, cli_value: str) -> "CompilationConfig":
def from_cli(cls, cli_value: str) -> CompilationConfig:
"""Parse the CLI value for the compilation config."""
if cli_value in ["0", "1", "2", "3"]:
return cls(level=int(cli_value))
@@ -3528,7 +3533,7 @@ def model_post_init(self, __context: Any) -> None:
self.static_forward_context = {}
self.compilation_time = 0.0

def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")

@@ -3744,9 +3749,7 @@ def _get_quantization_config(
"""Get the quantization config."""
from vllm.platforms import current_platform
if model_config.quantization is not None:
from vllm.model_executor.model_loader.weight_utils import (
get_quant_config)
quant_config = get_quant_config(model_config, load_config)
quant_config = me_quant.get_quant_config(model_config, load_config)
capability_tuple = current_platform.get_device_capability()

if capability_tuple is not None:
@@ -3770,7 +3773,7 @@ def with_hf_config(
self,
hf_config: PretrainedConfig,
architectures: Optional[list[str]] = None,
) -> "VllmConfig":
) -> VllmConfig:
if architectures is not None:
hf_config = copy.deepcopy(hf_config)
hf_config.architectures = architectures
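A quick, environment-dependent way to sanity-check a change like this (not part of the PR) is CPython's built-in import profiler, `python -X importtime -c "import vllm"`, or a coarse timer such as the sketch below.

```python
# Coarse import-cost check (illustrative only; results vary by machine,
# installed extras, and warm vs. cold filesystem caches).
import time

start = time.perf_counter()
import vllm.config  # noqa: E402  # the module this PR touches
print(f"import vllm.config took {time.perf_counter() - start:.2f}s")
```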