11 changes: 11 additions & 0 deletions vllm/config/model.py
@@ -470,6 +470,17 @@ def __post_init__(
)

self.hf_config = hf_config

# Ensure Gemma2 configs have hidden_act for backward compatibility.
# GGUF configs may only have hidden_activation; model code expects both.
if (
hasattr(hf_config, "model_type")
and hf_config.model_type == "gemma2"
and not hasattr(hf_config, "hidden_act")
and hasattr(hf_config, "hidden_activation")
):
hf_config.hidden_act = hf_config.hidden_activation

if dict_overrides:
self._apply_dict_overrides(hf_config, dict_overrides)
self.hf_text_config = get_hf_text_config(self.hf_config)
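
The hidden_act shim above can be exercised in isolation. The sketch below is illustrative only: the ensure_gemma2_hidden_act helper and the SimpleNamespace config are stand-ins rather than vLLM APIs, and the activation name is just an example.

from types import SimpleNamespace

def ensure_gemma2_hidden_act(hf_config):
    # Mirror hidden_activation into hidden_act for Gemma2 configs so that
    # model code reading either attribute keeps working (GGUF-derived
    # configs may only define hidden_activation).
    if (
        getattr(hf_config, "model_type", None) == "gemma2"
        and not hasattr(hf_config, "hidden_act")
        and hasattr(hf_config, "hidden_activation")
    ):
        hf_config.hidden_act = hf_config.hidden_activation
    return hf_config

# Hypothetical GGUF-style config that only carries hidden_activation.
cfg = SimpleNamespace(model_type="gemma2", hidden_activation="gelu_pytorch_tanh")
ensure_gemma2_hidden_act(cfg)
assert cfg.hidden_act == "gelu_pytorch_tanh"
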
209 changes: 107 additions & 102 deletions vllm/config/vllm.py
@@ -389,11 +389,39 @@ def _get_quantization_config(
)
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}"
)
# Handle dtype conflict between model restrictions and
# quantization restrictions (e.g., Gemma3 GGUF on Blackwell
# where Gemma3 blocks float16 and GGUF blocks bfloat16)
from vllm.config.model import _is_valid_dtype

model_type = getattr(model_config.hf_config, "model_type", None)
compatible_dtypes = [
d
for d in supported_dtypes
if model_type is None or _is_valid_dtype(model_type, d)
]
if compatible_dtypes:
# Prefer float16 > bfloat16 > float32 for performance
dtype_preference = [torch.float16, torch.bfloat16, torch.float32]
for preferred in dtype_preference:
if preferred in compatible_dtypes:
logger.warning(
"dtype=%s is not supported for quantization "
"method %s with model type %s. "
"Automatically selecting %s as a compatible dtype.",
model_config.dtype,
model_config.quantization,
model_type,
preferred,
)
model_config.dtype = preferred
break
else:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}"
)
quant_config.maybe_update_config(model_config.model)
return quant_config
return None
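
As a standalone illustration of the fallback above, here is a minimal sketch. _is_valid_dtype_stub and pick_compatible_dtype are hypothetical helpers; the assumption that Gemma3 rejects float16 comes from the comment in the diff, not from the real _is_valid_dtype.

import torch

def _is_valid_dtype_stub(model_type, dtype):
    # Simplified stand-in for vllm.config.model._is_valid_dtype: assume only
    # that Gemma3 rejects float16 (per the Blackwell GGUF example above).
    return not (model_type == "gemma3" and dtype == torch.float16)

def pick_compatible_dtype(requested, supported, model_type):
    # Intersect the quantization method's supported dtypes with the model's
    # valid dtypes, then prefer float16 > bfloat16 > float32.
    if requested in supported and _is_valid_dtype_stub(model_type, requested):
        return requested
    compatible = [d for d in supported if _is_valid_dtype_stub(model_type, d)]
    for preferred in (torch.float16, torch.bfloat16, torch.float32):
        if preferred in compatible:
            return preferred
    raise ValueError(f"{requested} is not supported; supported dtypes: {supported}")

# Gemma3 GGUF on Blackwell: GGUF allows fp16/fp32, the model rejects fp16,
# so float32 is the only compatible choice.
assert pick_compatible_dtype(torch.bfloat16, [torch.float16, torch.float32], "gemma3") == torch.float32
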
@@ -666,9 +694,8 @@ def has_blocked_weights():

default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
self._apply_optimization_level_defaults(default_config)

if (
self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
):
logger.info(
@@ -693,29 +720,22 @@ def has_blocked_weights():

if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support
if model_config := self.model_config:
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and model_config.pooler_config is not None
):
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and self.model_config is not None
):
if self.model_config.pooler_config is not None:
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif (
model_config.is_encoder_decoder
and self.compilation_config.cudagraph_mode
not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
):
logger.info_once(
"Encoder-decoder models do not support %s. "
"Overriding cudagraph_mode to FULL_DECODE_ONLY.",
self.compilation_config.cudagraph_mode.name,
)
self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY
elif self.model_config.is_encoder_decoder:
logger.warning_once(
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
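
As a standalone summary of the branch above: the diff shows two variants of the encoder-decoder fallback (one overriding to FULL_DECODE_ONLY, one to PIECEWISE); the sketch below follows the FULL_DECODE_ONLY variant. The Mode enum and helper are simplified stand-ins, not vLLM's actual CUDAGraphMode API.

from enum import Enum

class Mode(Enum):  # simplified stand-in for CUDAGraphMode
    NONE = 0
    PIECEWISE = 1
    FULL_DECODE_ONLY = 2
    FULL = 3

def adjust_mode(mode, is_pooling, is_encoder_decoder):
    # Pooling models with a full-cudagraph mode fall back to PIECEWISE;
    # encoder-decoder models fall back to FULL_DECODE_ONLY unless the mode
    # is already NONE or FULL_DECODE_ONLY.
    has_full = mode in (Mode.FULL, Mode.FULL_DECODE_ONLY)
    if has_full and is_pooling:
        return Mode.PIECEWISE
    if is_encoder_decoder and mode not in (Mode.NONE, Mode.FULL_DECODE_ONLY):
        return Mode.FULL_DECODE_ONLY
    return mode

assert adjust_mode(Mode.FULL, is_pooling=True, is_encoder_decoder=False) is Mode.PIECEWISE
assert adjust_mode(Mode.FULL, is_pooling=False, is_encoder_decoder=True) is Mode.FULL_DECODE_ONLY
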

# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
@@ -750,17 +770,27 @@ def has_blocked_weights():
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
self._set_compile_ranges()

if (
self.model_config
and self.model_config.architecture == "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
if self.model_config and self.model_config.is_encoder_decoder:
from vllm.multimodal import MULTIMODAL_REGISTRY

self.scheduler_config.max_num_encoder_input_tokens = (
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
)
logger.debug(
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)",
self.scheduler_config.max_num_encoder_input_tokens,
)
if (
self.model_config.architecture == "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
)
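
For reference, one way to apply the workaround suggested by that warning is to set the variable from Python before vLLM creates its worker processes; exporting it in the shell works equally well.

import os

# Must be set before worker processes are created so they are started with
# "spawn" rather than "fork" (see the Whisper warning above).
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
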

if (
self.kv_events_config is not None
@@ -810,6 +840,11 @@ def has_blocked_weights():
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)

assert (
self.parallel_config.cp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with cp_kv_cache_interleave_size > 1 is not currently supported."

# Do this after all the updates to compilation_config.mode
self.compilation_config.set_splitting_ops_for_v1(
all2all_backend=self.parallel_config.all2all_backend,
@@ -887,75 +922,51 @@ def has_blocked_weights():
if not self.instance_id:
self.instance_id = random_uuid()[:5]

# Hybrid KV cache manager (HMA) runtime rules:
# - Explicit enable (--no-disable-hybrid-kv-cache-manager): error if
# runtime disables it
# - No preference: auto-disable for unsupported features (e.g. kv connector)
# - Explicit disable (--disable-hybrid-kv-cache-manager): always respect it
need_disable_hybrid_kv_cache_manager = False
# The logger should only print a warning for hybrid models. Since we
# can't know yet whether the model is hybrid, we don't log the warning
# here; it will be logged later.
if not current_platform.support_hybrid_kv_cache():
# Hybrid KV cache manager is not supported on non-GPU platforms.
need_disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
need_disable_hybrid_kv_cache_manager = True
if (
self.model_config is not None
and self.model_config.attention_chunk_size is not None
):
if (
self.speculative_config is not None
and self.speculative_config.use_eagle()
):
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
need_disable_hybrid_kv_cache_manager = True
elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
logger.warning(
"There is a latency regression when using chunked local"
" attention with the hybrid KV cache manager. Disabling"
" it by default. To enable it, set the environment variable "
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
)
# Hybrid KV cache manager is not yet supported with chunked
# local attention.
need_disable_hybrid_kv_cache_manager = True

if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to disable HMA, but only if the user didn't express a preference.
if not self.scheduler_config.disable_hybrid_kv_cache_manager:
# The logger should only print a warning for hybrid models. Since we
# can't know yet whether the model is hybrid, we don't log the warning
# here; it will be logged later.
if not current_platform.support_hybrid_kv_cache():
# Hybrid KV cache manager is not supported on non-GPU platforms.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_transfer_config is not None:
# NOTE(Kuntai): turn HMA off for connectors unless explicitly enabled.
need_disable_hybrid_kv_cache_manager = True
# NOTE(Kuntai): turn HMA off for connectors for now.
# TODO(Kuntai): have a more elegant solution to check and
# turn off HMA for connectors that do not support HMA.
logger.warning(
"Turning off hybrid kv cache manager because "
"`--kv-transfer-config` is set. This will reduce the "
"performance of vLLM on LLMs with sliding window attention "
"or Mamba attention. If you are a developer of kv connector"
", please consider supporting hybrid kv cache manager for "
"your connector by making sure your connector is a subclass"
" of `SupportsHMA` defined in kv_connector/v1/base.py and"
" use --no-disable-hybrid-kv-cache-manager to start vLLM."
" of `SupportsHMA` defined in kv_connector/v1/base.py."
)
self.scheduler_config.disable_hybrid_kv_cache_manager = (
need_disable_hybrid_kv_cache_manager
)
elif (
self.scheduler_config.disable_hybrid_kv_cache_manager is False
and need_disable_hybrid_kv_cache_manager
):
raise ValueError(
"Hybrid KV cache manager was explicitly enabled but is not "
"supported in this configuration. Consider omitting the "
"--no-disable-hybrid-kv-cache-manager flag to let vLLM decide"
" automatically."
)

if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to enable HMA if not explicitly disabled by user or logic above.
self.scheduler_config.disable_hybrid_kv_cache_manager = False
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if (
self.model_config is not None
and self.model_config.attention_chunk_size is not None
):
if (
self.speculative_config is not None
and self.speculative_config.use_eagle()
):
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
logger.warning(
"There is a latency regression when using chunked local"
" attention with the hybrid KV cache manager. Disabling"
" it by default. To enable it, set the environment variable "
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
)
# Hybrid KV cache manager is not yet supported with chunked
# local attention.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
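
The three-way rule described in the "HMA runtime rules" comment above boils down to resolving an optional user preference against a runtime requirement. The helper below is a hypothetical summary of that rule, not vLLM code.

def resolve_disable_hma(user_disable_pref, runtime_needs_disable):
    # user_disable_pref: None (no flag), True (--disable-hybrid-kv-cache-manager)
    # or False (--no-disable-hybrid-kv-cache-manager).
    if user_disable_pref is None:
        # No preference: auto-disable when a feature requires it.
        return runtime_needs_disable
    if user_disable_pref is False and runtime_needs_disable:
        # Explicit enable conflicts with a runtime restriction.
        raise ValueError(
            "Hybrid KV cache manager was explicitly enabled but is not "
            "supported in this configuration."
        )
    # Explicit disable (or explicit enable with no conflict) is respected.
    return user_disable_pref

assert resolve_disable_hma(None, True) is True      # auto-disable
assert resolve_disable_hma(True, False) is True     # explicit disable respected
assert resolve_disable_hma(False, False) is False   # explicit enable, no conflict
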

if self.compilation_config.debug_dump_path:
self.compilation_config.debug_dump_path = (
@@ -1023,7 +1034,7 @@ def _set_cudagraph_sizes(self):
max_graph_size = min(max_num_seqs * 2, 512)
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
# up to max_graph_size
cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
range(256, max_graph_size + 1, 16))

In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
@@ -1064,14 +1075,8 @@ def _set_cudagraph_sizes(self):
self.compilation_config.max_cudagraph_capture_size
)
if max_cudagraph_capture_size is None:
decode_query_len = 1
if (
self.speculative_config
and self.speculative_config.num_speculative_tokens
):
decode_query_len += self.speculative_config.num_speculative_tokens
max_cudagraph_capture_size = min(
self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
self.scheduler_config.max_num_seqs * 2, 512
)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
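
The size schedule described in the docstring above can be reproduced with a short standalone helper; the function name is hypothetical and the cap of 512 follows the snippet quoted in the docstring.

def build_capture_sizes(max_num_seqs):
    # 1, 2, 4, then multiples of 8 up to 256, then multiples of 16 up to
    # max_graph_size = min(max_num_seqs * 2, 512).
    max_graph_size = min(max_num_seqs * 2, 512)
    sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, max_graph_size + 1, 16))
    return [s for s in sizes if s <= max_graph_size]

assert build_capture_sizes(128)[-3:] == [240, 248, 256]
assert build_capture_sizes(256)[-1] == 512
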
7 changes: 5 additions & 2 deletions vllm/model_executor/layers/quantization/gguf.py
@@ -54,9 +54,12 @@ def get_name(self) -> QuantizationMethods:

def get_supported_act_dtypes(self) -> list[torch.dtype]:
# GGUF dequantization kernels use half precision (fp16) internally.
# bfloat16 has precision issues on Blackwell devices.
# bfloat16 has precision issues on SM 10.0+ devices (Blackwell).
if current_platform.has_device_capability(100):
logger.warning_once("GGUF has precision issues with bfloat16 on Blackwell.")
logger.warning_once(
"GGUF has precision issues with bfloat16 on Blackwell (SM 10.0+). "
"bfloat16 activations are disabled."
)
return [torch.half, torch.float32]
return [torch.half, torch.bfloat16, torch.float32]
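
A minimal sketch of the capability gate above, with the device capability passed in explicitly. The helper is illustrative and does not mirror the actual current_platform API (which encodes SM 10.0 as 100 in has_device_capability).

import torch

def gguf_supported_act_dtypes(sm_major):
    # On SM 10.0+ (Blackwell), drop bfloat16 because of the precision issues
    # noted above; otherwise all three activation dtypes are allowed.
    if sm_major >= 10:
        return [torch.half, torch.float32]
    return [torch.half, torch.bfloat16, torch.float32]

assert torch.bfloat16 not in gguf_supported_act_dtypes(10)  # Blackwell
assert torch.bfloat16 in gguf_supported_act_dtypes(9)       # Hopper
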

5 changes: 5 additions & 0 deletions vllm/model_executor/model_loader/gguf_loader.py
@@ -147,6 +147,11 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
)
)

# For models with tied word embeddings, lm_head.weight is initialized
# from embed_tokens and doesn't need to be mapped from the GGUF file.
if getattr(config, "tie_word_embeddings", False):
sideload_params.append(re.compile(r"lm_head\.weight"))
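
To illustrate what adding lm_head.weight to sideload_params means in practice, here is a hypothetical filter; the real loader's matching semantics may differ, so treat this as a sketch of the intent (skip GGUF mapping for weights tied to embed_tokens).

import re

sideload_params = [re.compile(r"lm_head\.weight")]  # tied to embed_tokens

def mapped_from_gguf(param_name):
    # Parameters matching a sideload pattern are initialized elsewhere
    # (here: from embed_tokens) instead of being read from the GGUF file.
    return not any(p.fullmatch(param_name) for p in sideload_params)

assert not mapped_from_gguf("lm_head.weight")
assert mapped_from_gguf("model.embed_tokens.weight")
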

arch = None
for key, value in gguf.MODEL_ARCH_NAMES.items():
if value == model_type:
6 changes: 6 additions & 0 deletions vllm/model_executor/models/gemma2.py
@@ -266,6 +266,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
@@ -365,6 +367,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
continue
if is_pp_missing_parameter(name, self):
continue
# Skip parameters not in the model (e.g., GGUF quantization
# metadata like qweight_type for embeddings)
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
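
The skip guard above simply ignores checkpoint entries with no corresponding model parameter. A self-contained sketch of that pattern, with made-up tensor names:

import torch

params_dict = {"model.embed_tokens.weight": torch.nn.Parameter(torch.zeros(4, 4))}
loaded_weights = {
    "model.embed_tokens.weight": torch.ones(4, 4),
    "model.embed_tokens.qweight_type": torch.tensor(2),  # GGUF metadata, no parameter
}

for name, loaded_weight in loaded_weights.items():
    if name not in params_dict:
        continue  # skip entries the model has no parameter for
    params_dict[name].data.copy_(loaded_weight)
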
28 changes: 20 additions & 8 deletions vllm/transformers_utils/config.py
@@ -505,14 +505,26 @@ def maybe_override_with_speculators(
else:
gguf_model_repo = None
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
config_dict, _ = PretrainedConfig.get_config_dict(
model if gguf_model_repo is None else gguf_model_repo,
revision=revision,
trust_remote_code=trust_remote_code,
token=_get_hf_token(),
**kwargs,
)
speculators_config = config_dict.get("speculators_config")
try:
config_dict, _ = PretrainedConfig.get_config_dict(
model if gguf_model_repo is None else gguf_model_repo,
revision=revision,
trust_remote_code=trust_remote_code,
token=_get_hf_token(),
**kwargs,
)
speculators_config = config_dict.get("speculators_config")
except OSError as e:
# GGUF models without config.json cannot have a speculators config
# (speculators is defined in config.json), so skip gracefully.
# We only suppress "file not found" errors, not other OS errors like
# permission denied.
is_file_not_found = isinstance(
e, FileNotFoundError
) or "does not appear to have a file named" in str(e)
if gguf_model_repo is not None and is_file_not_found:
return model, tokenizer, vllm_speculative_config
raise
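
The error classification above can be captured in a small predicate; the helper name is hypothetical and the Hugging Face error-message substring is the one already used in the code.

def should_skip_speculators_lookup(err, is_gguf_repo):
    # Suppress only "config.json not found" style errors for GGUF repos;
    # anything else (e.g. permission denied) still propagates.
    is_file_not_found = isinstance(err, FileNotFoundError) or (
        "does not appear to have a file named" in str(err)
    )
    return is_gguf_repo and is_file_not_found

assert should_skip_speculators_lookup(FileNotFoundError("config.json"), True)
assert not should_skip_speculators_lookup(PermissionError("denied"), True)
assert not should_skip_speculators_lookup(FileNotFoundError("config.json"), False)
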

if speculators_config is None:
# No speculators config found, return original values