
[CI] Enable test_initialization to run on V1 #16736


Merged: 11 commits, May 23, 2025
5 changes: 1 addition & 4 deletions .buildkite/test-pipeline.yaml
@@ -472,10 +472,7 @@ steps:
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_utils.py
- pytest -v -s models/test_vision.py
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
- pytest -v -s models/test_initialization.py

- label: Language Models Test (Standard)
mirror_hardwares: [amdexperimental]
32 changes: 24 additions & 8 deletions tests/models/registry.py
@@ -55,9 +55,18 @@ class _HfExamplesInfo:
trust_remote_code: bool = False
"""The ``trust_remote_code`` level required to load the model."""

v0_only: bool = False
"""The model is only available with the vLLM V0 engine."""

hf_overrides: dict[str, Any] = field(default_factory=dict)
"""The ``hf_overrides`` required to load the model."""

max_model_len: Optional[int] = None
"""
The maximum model length to use for this model. Some models default to a
length that is too large to fit into memory in CI.
"""

def check_transformers_version(
self,
*,
@@ -215,10 +224,11 @@ def check_available_online(
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
trust_remote_code=True),
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True),
Collaborator:
QQ - do we know why these are failing? Is it the head_dim?

Member Author:
Yes, it is head_dim for most of these. The only exceptions IIRC are Kimi and Pixtral, which have a tracking issue.

"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True),
trust_remote_code=True,
v0_only=True),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
@@ -234,7 +244,8 @@ def check_available_online(
is_available_online=False),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
is_available_online=False),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
v0_only=True),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@@ -303,7 +314,8 @@ def check_available_online(
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}, # noqa: E501
v0_only=True),
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
@@ -328,9 +340,11 @@ def check_available_online(
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True),
trust_remote_code=True,
v0_only=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
min_transformers_version="4.51"),
min_transformers_version="4.51",
max_model_len=10240),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501
"mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501
@@ -349,7 +363,8 @@ def check_available_online(
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
trust_remote_code=True),
"MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501
trust_remote_code=True),
trust_remote_code=True,
v0_only=True),
"Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501
extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@@ -372,7 +387,8 @@ def check_available_online(
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
tokenizer_mode="mistral"),
tokenizer_mode="mistral",
v0_only=True),
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501
trust_remote_code=True,
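The new `v0_only` and `max_model_len` registry fields are plain dataclass attributes, so the test harness can inspect them before constructing an engine. Below is a minimal sketch of that consumption pattern; `ExampleInfo` and `resolve_engine_env` are hypothetical names for illustration, not part of vLLM.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass(frozen=True)
class ExampleInfo:
    """Illustrative stand-in for _HfExamplesInfo; field names match the diff."""
    default: str
    trust_remote_code: bool = False
    v0_only: bool = False
    hf_overrides: dict[str, Any] = field(default_factory=dict)
    max_model_len: Optional[int] = None


def resolve_engine_env(info: ExampleInfo) -> dict[str, str]:
    # V0-only models pin VLLM_USE_V1=0 for their test; everything else runs
    # on whichever engine the environment selects (V1 by default).
    return {"VLLM_USE_V1": "0"} if info.v0_only else {}


phi2 = ExampleInfo("microsoft/phi-2", v0_only=True)
assert resolve_engine_env(phi2) == {"VLLM_USE_V1": "0"}

scout = ExampleInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",
                    max_model_len=10240)
assert resolve_engine_env(scout) == {}
```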
17 changes: 13 additions & 4 deletions tests/models/test_initialization.py
@@ -15,12 +15,12 @@


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch):
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")

# Avoid OOM
# Avoid OOM and reduce initialization time by only using 1 layer
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(model_info.hf_overrides)

@@ -34,6 +34,12 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
"num_local_experts": 2,
})

if hasattr(hf_config, "vision_config"):
hf_config.vision_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})

return hf_config

# Avoid calling model.forward()
@@ -46,7 +52,7 @@ def _initialize_kv_caches_v1(self, vllm_config):
scheduler_kv_cache_config = get_kv_cache_config(
vllm_config,
kv_cache_specs[0],
20 * GiB_bytes,
10 * GiB_bytes,
)

# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
@@ -55,7 +61,9 @@ def _initialize_kv_caches_v1(self, vllm_config):
with (patch.object(V0LLMEngine, "_initialize_kv_caches",
_initialize_kv_caches_v0),
patch.object(V1EngineCore, "_initialize_kv_caches",
_initialize_kv_caches_v1)):
_initialize_kv_caches_v1), monkeypatch.context() as m):
if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0")
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
Expand All @@ -65,6 +73,7 @@ def _initialize_kv_caches_v1(self, vllm_config):
"num_speculative_tokens": 1,
} if model_info.speculative_model else None,
trust_remote_code=model_info.trust_remote_code,
max_model_len=model_info.max_model_len,
load_format="dummy",
hf_overrides=hf_overrides,
)
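The V0 fallback above is scoped with `monkeypatch.context()`, so the `VLLM_USE_V1=0` override disappears as soon as the `with` block exits and cannot leak into other parametrized cases. A minimal, self-contained sketch of that pattern (not tied to any specific model):

```python
import os

import pytest


def test_env_override_is_scoped(monkeypatch: pytest.MonkeyPatch):
    before = os.environ.get("VLLM_USE_V1")
    # Same pattern as test_can_initialize: set the engine selector only
    # inside the context so later tests in the same worker are unaffected.
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "0")
        assert os.environ["VLLM_USE_V1"] == "0"
    assert os.environ.get("VLLM_USE_V1") == before
```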
30 changes: 7 additions & 23 deletions vllm/model_executor/models/grok1.py
@@ -28,7 +28,7 @@
import torch.nn.functional as F
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -182,25 +182,20 @@ def __init__(
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap,
prefix=f"{prefix}.attn")
self.attn_multiplier = getattr(self.config, "attn_output_multiplier",
1.0) if self.config else 1.0

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)

# Apply attention output multiplier if specified in config
attn_multiplier = getattr(self.config, "attn_output_multiplier",
None) if self.config else None
if attn_multiplier is not None:
output = output * attn_multiplier
output *= self.attn_multiplier
return output


@@ -261,8 +256,6 @@ def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -276,8 +269,6 @@
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)

# Post attention normalization
@@ -341,8 +332,6 @@ def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -359,9 +348,7 @@

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
hidden_states, residual = layer(positions, hidden_states, residual)

if not get_pp_group().is_last_rank:
return IntermediateTensors({
@@ -529,13 +516,10 @@ def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states

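The Grok-1 change folds the per-call `getattr(self.config, "attn_output_multiplier", ...)` lookup into `__init__` and drops the explicit `kv_cache`/`attn_metadata` arguments, which the `Attention` layer now handles internally. A stripped-down sketch of the precompute-in-`__init__` pattern; `ScaledAttnOutput` is an illustrative module, not the actual Grok1Attention class:

```python
import torch
from torch import nn


class ScaledAttnOutput(nn.Module):
    """Resolve an optional config multiplier once instead of on every forward."""

    def __init__(self, config=None):
        super().__init__()
        # Defaults to 1.0 so the forward path needs no branch.
        self.attn_multiplier = getattr(config, "attn_output_multiplier",
                                       1.0) if config else 1.0

    def forward(self, attn_output: torch.Tensor) -> torch.Tensor:
        return attn_output * self.attn_multiplier


layer = ScaledAttnOutput()
assert torch.equal(layer(torch.ones(2)), torch.ones(2))
```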
15 changes: 9 additions & 6 deletions vllm/utils.py
@@ -2794,14 +2794,17 @@ def wrapper(*args, **kwargs):

# Only relevant for models using ALiBi (e.g, MPT)
def check_use_alibi(model_config: ModelConfig) -> bool:
return (getattr(model_config.hf_text_config, "alibi", False) # Falcon
cfg = model_config.hf_text_config
return (getattr(cfg, "alibi", False) # Falcon
or ("BloomForCausalLM" in getattr(model_config.hf_config,
"architectures", [])) # Bloom
or getattr(model_config.hf_text_config, "position_encoding_type",
"") == "alibi" # codellm_1b_alibi
or
(hasattr(model_config.hf_text_config, "attn_config") # MPT
and model_config.hf_text_config.attn_config.get("alibi", False)))
or getattr(cfg, "position_encoding_type", "") ==
"alibi" # codellm_1b_alibi
or (hasattr(cfg, "attn_config") # MPT
and ((isinstance(cfg.attn_config, dict)
and cfg.attn_config.get("alibi", False)) or
(not isinstance(cfg.attn_config, dict)
and getattr(cfg.attn_config, "alibi", False)))))


def sha256(input) -> int:
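The rewritten `check_use_alibi` has to cope with MPT-style configs that expose `attn_config` either as a plain dict or as a config object. A small sketch of that dict-or-object probe, using a hypothetical helper name that is not part of vLLM:

```python
from typing import Any


def get_alibi_flag(attn_config: Any) -> bool:
    # Mirrors the branch added above: probe dict-style configs with .get()
    # and object-style configs with getattr(), defaulting to False.
    if isinstance(attn_config, dict):
        return attn_config.get("alibi", False)
    return getattr(attn_config, "alibi", False)


assert get_alibi_flag({"alibi": True}) is True
assert get_alibi_flag(object()) is False
```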