@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import ast
 import copy
 import enum
@@ -22,42 +24,45 @@
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
 from torch.distributed import ProcessGroup, ReduceOp
-from transformers import PretrainedConfig
 
 import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
-                                                     get_quantization_config)
-from vllm.model_executor.models import ModelRegistry
-from vllm.platforms import CpuArchEnum, current_platform
+from vllm.platforms import CpuArchEnum
 from vllm.sampling_params import GuidedDecodingParams
-from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
     try_get_generation_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
-from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
-                        get_cpu_memory, get_open_port, is_torch_equal_or_newer,
-                        random_uuid, resolve_obj_by_qualname)
+from vllm.utils import (GiB_bytes, LayerBlockType, LazyLoader,
+                        cuda_device_count_stateless, get_cpu_memory,
+                        get_open_port, is_torch_equal_or_newer, random_uuid,
+                        resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
     from ray.util.placement_group import PlacementGroup
+    from transformers import PretrainedConfig
 
     from vllm.executor.executor_base import ExecutorBase
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
     from vllm.model_executor.model_loader.loader import BaseModelLoader
 
     ConfigType = type[DataclassInstance]
+    HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
+                                                 PretrainedConfig]]
 else:
-    QuantizationConfig = None
+    HfOverrides = None
     ConfigType = type
 
+me_quant = LazyLoader("model_executor", globals(),
+                      "vllm.model_executor.layers.quantization")
+me_models = LazyLoader("model_executor", globals(),
+                       "vllm.model_executor.models")
 logger = init_logger(__name__)
 
 ConfigT = TypeVar("ConfigT", bound=ConfigType)
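
The two `LazyLoader` proxies above are the core of this change: they stand in for `vllm.model_executor.layers.quantization` and `vllm.model_executor.models` so that neither package is imported until an attribute is actually read. A minimal sketch of the idea, assuming only what the call sites show (the real `LazyLoader` lives in `vllm.utils`; this illustrative version follows the well-known TensorFlow-style lazy module proxy and is not vllm's exact code):

```python
import importlib
import types


class LazyLoader(types.ModuleType):
    """Module proxy that performs the real import on first attribute access.

    The signature mirrors the call sites above: a local alias, the caller's
    globals(), and the fully qualified name of the module to load lazily.
    """

    def __init__(self, local_name: str, parent_module_globals: dict,
                 name: str):
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals
        super().__init__(name)

    def _load(self) -> types.ModuleType:
        # Import the target module and swap the proxy out of the caller's
        # namespace so subsequent lookups hit the real module directly.
        module = importlib.import_module(self.__name__)
        self._parent_module_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: str):
        return getattr(self._load(), item)
```

Because the proxy replaces itself in the caller's namespace after the first load, only the very first attribute access pays the indirection.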
@@ -89,9 +94,6 @@
     for task in tasks
 }
 
-HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
-                                             PretrainedConfig]]
-
 
 class SupportsHash(Protocol):
 
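
Relocating the `HfOverrides` alias (and the `PretrainedConfig` import it needs) under `TYPE_CHECKING` is enabled by the `from __future__ import annotations` added at the top: with PEP 563, annotations are stored as strings and never evaluated at runtime, which is also why the quoted forward references (`"PoolerConfig"`, `"ParallelConfig"`, ...) are unquoted throughout the rest of this diff. A self-contained sketch of the pattern, where `apply_overrides` is a hypothetical helper for illustration, not vllm code:

```python
from __future__ import annotations  # PEP 563: lazy, string-valued annotations

from typing import TYPE_CHECKING, Any, Callable, Union

if TYPE_CHECKING:
    # Only evaluated by static type checkers; no runtime transformers import.
    from transformers import PretrainedConfig

    HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
                                                 PretrainedConfig]]
else:
    # Runtime placeholder so "from this_module import HfOverrides" still works.
    HfOverrides = None


def apply_overrides(config: PretrainedConfig,
                    overrides: HfOverrides) -> PretrainedConfig:
    # The annotations above are never evaluated at runtime, so the quotes
    # previously required for forward references can simply be dropped.
    if callable(overrides):
        return overrides(config)
    for key, value in (overrides or {}).items():
        setattr(config, key, value)
    return config
```

The `else: HfOverrides = None` fallback is still needed because the alias is a real module attribute that other modules may import, not just a name used inside annotations.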
@@ -365,7 +367,7 @@ def __init__(
         mm_processor_kwargs: Optional[dict[str, Any]] = None,
         disable_mm_preprocessor_cache: bool = False,
         override_neuron_config: Optional[dict[str, Any]] = None,
-        override_pooler_config: Optional["PoolerConfig"] = None,
+        override_pooler_config: Optional[PoolerConfig] = None,
         logits_processor_pattern: Optional[str] = None,
         generation_config: str = "auto",
         enable_sleep_mode: bool = False,
@@ -548,7 +550,7 @@ def __init__(
 
     @property
     def registry(self):
-        return ModelRegistry
+        return me_models.ModelRegistry
 
     @property
     def architectures(self) -> list[str]:
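
With the proxy in place, the `registry` property resolves `ModelRegistry` on first use instead of at `import vllm.config` time. A usage sketch against the illustrative `LazyLoader` above (`decimal` is a stand-in heavy module; whether it is preloaded can vary by interpreter setup):

```python
import sys

# Assumes the LazyLoader sketch from earlier; "decimal" stands in for a
# heavy package such as vllm.model_executor.models.
lazy_decimal = LazyLoader("lazy_decimal", globals(), "decimal")

print("decimal" in sys.modules)      # typically False: the proxy is inert
value = lazy_decimal.Decimal("1.5")  # first attribute access imports it
print("decimal" in sys.modules)      # True: import happened on demand
print(value)                         # 1.5
```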
@@ -581,7 +583,7 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str,
 
     def _init_multimodal_config(
         self, limit_mm_per_prompt: Optional[dict[str, int]]
-    ) -> Optional["MultiModalConfig"]:
+    ) -> Optional[MultiModalConfig]:
         if self.registry.is_multimodal_model(self.architectures):
             return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
 
@@ -597,8 +599,8 @@ def _get_encoder_config(self):
 
     def _init_pooler_config(
         self,
-        override_pooler_config: Optional["PoolerConfig"],
-    ) -> Optional["PoolerConfig"]:
+        override_pooler_config: Optional[PoolerConfig],
+    ) -> Optional[PoolerConfig]:
 
         if self.runner_type == "pooling":
             user_config = override_pooler_config or PoolerConfig()
@@ -749,7 +751,8 @@ def _parse_quant_hf_config(self):
         return quant_cfg
 
     def _verify_quantization(self) -> None:
-        supported_quantization = QUANTIZATION_METHODS
+        supported_quantization = me_quant.QUANTIZATION_METHODS
+
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed_tensors",
@@ -766,8 +769,8 @@ def _verify_quantization(self) -> None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for name in QUANTIZATION_METHODS:
-                method = get_quantization_config(name)
+            for name in me_quant.QUANTIZATION_METHODS:
+                method = me_quant.get_quantization_config(name)
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization)
                 if quantization_override:
@@ -799,6 +802,8 @@ def _verify_quantization(self) -> None:
                 "non-quantized models.", self.quantization)
 
     def _verify_cuda_graph(self) -> None:
+        from vllm.platforms import current_platform
+
         if self.max_seq_len_to_capture is None:
             self.max_seq_len_to_capture = self.max_model_len
         self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
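
The same deferral shows up as function-scoped imports: `_verify_cuda_graph` here (and the `__post_init__` that imports `vllm.tracing` further down) pull in their dependencies inside the function body, so the cost is paid on the first call and later calls hit the `sys.modules` cache. A minimal sketch of the trade-off, using the stdlib `platform` module as a stand-in for `vllm.platforms`:

```python
import sys
import time


def check_platform() -> str:
    # Deferred import: resolved on the first call, cached in sys.modules
    # afterwards, so repeated calls only pay a dict lookup.
    import platform  # stand-in for a heavier module like vllm.platforms
    return platform.machine()


t0 = time.perf_counter()
check_platform()                      # pays the import once
t1 = time.perf_counter()
check_platform()                      # near-free: sys.modules cache hit
t2 = time.perf_counter()

print("platform" in sys.modules)      # True after the first call
print(f"first {t1 - t0:.6f}s, second {t2 - t1:.6f}s")
```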
@@ -885,7 +890,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
 
     def verify_with_parallel_config(
         self,
-        parallel_config: "ParallelConfig",
+        parallel_config: ParallelConfig,
     ) -> None:
 
         if parallel_config.distributed_executor_backend == "external_launcher":
@@ -1038,7 +1043,7 @@ def get_total_num_kv_heads(self) -> int:
             # equal to the number of attention heads.
             return self.hf_text_config.num_attention_heads
 
-    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+    def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int:
         """Returns the number of KV heads per GPU."""
         if self.use_mla:
             # When using MLA during decode it becomes MQA
@@ -1052,13 +1057,12 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         return max(1,
                    total_num_kv_heads // parallel_config.tensor_parallel_size)
 
-    def get_num_attention_heads(self,
-                                parallel_config: "ParallelConfig") -> int:
+    def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int:
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
         return num_heads // parallel_config.tensor_parallel_size
 
     def get_layers_start_end_indices(
-            self, parallel_config: "ParallelConfig") -> tuple[int, int]:
+            self, parallel_config: ParallelConfig) -> tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
         if self.hf_text_config.model_type == "deepseek_mtp":
             total_num_hidden_layers = getattr(self.hf_text_config,
@@ -1073,13 +1077,13 @@ def get_layers_start_end_indices(
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
         return start, end
 
-    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+    def get_num_layers(self, parallel_config: ParallelConfig) -> int:
         start, end = self.get_layers_start_end_indices(parallel_config)
         return end - start
 
     def get_num_layers_by_block_type(
         self,
-        parallel_config: "ParallelConfig",
+        parallel_config: ParallelConfig,
         block_type: LayerBlockType = LayerBlockType.attention,
     ) -> int:
         # This function relies on 'layers_block_type' in hf_config,
@@ -1132,7 +1136,7 @@ def get_num_layers_by_block_type(
 
         return sum(t == 1 for t in attn_type_list[start:end])
 
-    def get_multimodal_config(self) -> "MultiModalConfig":
+    def get_multimodal_config(self) -> MultiModalConfig:
         """
         Get the multimodal configuration of the model.
 
@@ -1241,7 +1245,7 @@ def runner_type(self) -> RunnerType:
     @property
     def is_v1_compatible(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
-        return ModelRegistry.is_v1_compatible(architectures)
+        return me_models.ModelRegistry.is_v1_compatible(architectures)
 
     @property
     def is_matryoshka(self) -> bool:
@@ -1392,7 +1396,7 @@ def _verify_prefix_caching(self) -> None:
 
     def verify_with_parallel_config(
         self,
-        parallel_config: "ParallelConfig",
+        parallel_config: ParallelConfig,
     ) -> None:
         total_cpu_memory = get_cpu_memory()
         # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
@@ -1460,7 +1464,7 @@ class LoadConfig:
     """Configuration for loading the model weights."""
 
     load_format: Union[str, LoadFormat,
-                       "BaseModelLoader"] = LoadFormat.AUTO.value
+                       BaseModelLoader] = LoadFormat.AUTO.value
     """The format of the model weights to load:\n
     - "auto" will try to load the weights in the safetensors format and fall
     back to the pytorch bin format if safetensors format is not available.\n
@@ -1582,11 +1586,11 @@ def data_parallel_rank_local(self, value: int) -> None:
     ray_workers_use_nsight: bool = False
     """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
 
-    placement_group: Optional["PlacementGroup"] = None
+    placement_group: Optional[PlacementGroup] = None
     """ray distributed model workers placement group."""
 
     distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
-                                                 type["ExecutorBase"]]] = None
+                                                 type[ExecutorBase]]] = None
     """Backend to use for distributed model
     workers, either "ray" or "mp" (multiprocessing). If the product
     of pipeline_parallel_size and tensor_parallel_size is less than
@@ -1629,7 +1633,7 @@ def get_next_dp_init_port(self) -> int:
         self.data_parallel_master_port += 1
         return answer
 
-    def stateless_init_dp_group(self) -> "ProcessGroup":
+    def stateless_init_dp_group(self) -> ProcessGroup:
         from vllm.distributed.utils import (
             stateless_init_torch_distributed_process_group)
 
@@ -1644,7 +1648,7 @@ def stateless_init_dp_group(self) -> "ProcessGroup":
         return dp_group
 
     @staticmethod
-    def has_unfinished_dp(dp_group: "ProcessGroup",
+    def has_unfinished_dp(dp_group: ProcessGroup,
                           has_unfinished: bool) -> bool:
         tensor = torch.tensor([has_unfinished],
                               dtype=torch.int32,
@@ -2227,7 +2231,7 @@ def compute_hash(self) -> str:
         return hash_str
 
     @classmethod
-    def from_dict(cls, dict_value: dict) -> "SpeculativeConfig":
+    def from_dict(cls, dict_value: dict) -> SpeculativeConfig:
         """Parse the CLI value for the speculative config."""
         return cls(**dict_value)
 
@@ -2819,7 +2823,7 @@ def compute_hash(self) -> str:
         return hash_str
 
     @staticmethod
-    def from_json(json_str: str) -> "PoolerConfig":
+    def from_json(json_str: str) -> PoolerConfig:
         return PoolerConfig(**json.loads(json_str))
 
 
@@ -3176,6 +3180,7 @@ def compute_hash(self) -> str:
         return hash_str
 
     def __post_init__(self):
+        from vllm.tracing import is_otel_available, otel_import_error_traceback
         if not is_otel_available() and self.otlp_traces_endpoint is not None:
             raise ValueError(
                 "OpenTelemetry is not available. Unable to configure "
@@ -3239,7 +3244,7 @@ def compute_hash(self) -> str:
         return hash_str
 
     @classmethod
-    def from_cli(cls, cli_value: str) -> "KVTransferConfig":
+    def from_cli(cls, cli_value: str) -> KVTransferConfig:
         """Parse the CLI value for the kv cache transfer config."""
         return KVTransferConfig.model_validate_json(cli_value)
 
@@ -3476,7 +3481,7 @@ def __repr__(self) -> str:
     __str__ = __repr__
 
     @classmethod
-    def from_cli(cls, cli_value: str) -> "CompilationConfig":
+    def from_cli(cls, cli_value: str) -> CompilationConfig:
         """Parse the CLI value for the compilation config."""
         if cli_value in ["0", "1", "2", "3"]:
             return cls(level=int(cli_value))
@@ -3528,7 +3533,7 @@ def model_post_init(self, __context: Any) -> None:
         self.static_forward_context = {}
         self.compilation_time = 0.0
 
-    def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
+    def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]:
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
 
@@ -3744,9 +3749,7 @@ def _get_quantization_config(
         """Get the quantization config."""
         from vllm.platforms import current_platform
         if model_config.quantization is not None:
-            from vllm.model_executor.model_loader.weight_utils import (
-                get_quant_config)
-            quant_config = get_quant_config(model_config, load_config)
+            quant_config = me_quant.get_quant_config(model_config, load_config)
             capability_tuple = current_platform.get_device_capability()
 
             if capability_tuple is not None:
@@ -3770,7 +3773,7 @@ def with_hf_config(
         self,
         hf_config: PretrainedConfig,
         architectures: Optional[list[str]] = None,
-    ) -> "VllmConfig":
+    ) -> VllmConfig:
         if architectures is not None:
             hf_config = copy.deepcopy(hf_config)
             hf_config.architectures = architectures