
Commit f88ff73

Revert "Fix after #16229, mm (#286)"
This reverts commit 18ead2d.
1 parent 11a34e7 commit f88ff73


2 files changed (+27, -50 lines)

tests/full_tests/ci_gsm8k_tests.sh

Lines changed: 19 additions & 15 deletions
@@ -10,12 +10,14 @@ set -e
 VLLM_GAUDI_PREFIX=${VLLM_GAUDI_PREFIX:-"vllm-gaudi"}
 echo $VLLM_GAUDI_PREFIX
 
-# Gemma3 with image input
-run_gemma3_test() {
-    echo "➡️ Testing gemma-3-4b-it..."
-    VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 python -u "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/generation_mm.py" --model-card-path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/gemma-3-4b-it.yaml"
-    echo "✅ Test with multimodal-support with gemma-3-4b-it passed."
-}
+# NOTE(Chendi): temporarily disable gemma3 test due to upstream change.
+# Expect fixing from https://github.com/vllm-project/vllm-gaudi/pull/286
+# # Gemma3 with image input
+# run_gemma3_test() {
+#     echo "➡️ Testing gemma-3-4b-it..."
+#     VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 python -u "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/generation_mm.py" --model-card-path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/gemma-3-4b-it.yaml"
+#     echo "✅ Test with multimodal-support with gemma-3-4b-it passed."
+# }
 
 # Basic model test
 run_basic_model_test() {
@@ -182,13 +184,15 @@ run_gsm8k_qwen3_30b_test() {
     echo "✅ Test with QWEN3-30B-A3B passed."
 }
 
-# Multimodal-support with qwen2.5-vl
-run_qwen2_5_vl_test() {
-    echo "➡️ Testing Qwen2.5-VL-7B..."
-    VLLM_SKIP_WARMUP=true VLLM_CONTIGUOUS_PA=False PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 \
-        python -u "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/generation_mm.py" --model-card-path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/qwen2.5-vl-7b.yaml"
-    echo "✅ Test with multimodal-support with qwen2.5-vl-7b passed."
-}
+# NOTE(Chendi): Disabled due to upstream change #16229
+# Expect fixing from https://github.com/vllm-project/vllm-gaudi/pull/286
+# # Multimodal-support with qwen2.5-vl
+# run_qwen2_5_vl_test() {
+#     echo "➡️ Testing Qwen2.5-VL-7B..."
+#     VLLM_SKIP_WARMUP=true VLLM_CONTIGUOUS_PA=False PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 \
+#         python -u "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/generation_mm.py" --model-card-path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/qwen2.5-vl-7b.yaml"
+#     echo "✅ Test with multimodal-support with qwen2.5-vl-7b passed."
+# }
 
 # Spec decode with ngram
 run_spec_decode_ngram_test() {
@@ -213,7 +217,7 @@ run_embedding_model_test() {
 # Function to run all tests sequentially
 launch_all_tests() {
     echo "🚀 Starting all test suites..."
-    run_gemma3_test
+    # run_gemma3_test
     run_basic_model_test
     run_tp2_test
     run_mla_moe_test
@@ -233,7 +237,7 @@ launch_all_tests() {
     run_gsm8k_granite_async_test
     run_gsm8k_deepseek_test
     run_gsm8k_qwen3_30b_test
-    run_qwen2_5_vl_test
+    #run_qwen2_5_vl_test
     run_spec_decode_ngram_test
     #run_embedding_model_test
     echo "🎉 All test suites passed successfully!"

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 8 additions & 35 deletions
@@ -53,7 +53,6 @@
     AsyncModelRunnerOutput, KVConnectorOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.utils import bind_kv_cache
-from vllm.v1.utils import CpuGpuBuffer
 from vllm_gaudi.v1.worker.hpu_input_batch import InputBatch, CachedRequestState
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.models.interfaces import (SupportsMultiModal, supports_eagle3, supports_transcription)
@@ -543,10 +542,8 @@ def forward(self, *args, **kwargs):
         self._reset_rotary_cos_sin()
         return hidden_states
 
-    def get_input_embeddings(self, input_ids, multimodal_embeddings=None, is_multimodal=False):
-        return self.model.get_input_embeddings(input_ids=input_ids,
-                                               multimodal_embeddings=multimodal_embeddings,
-                                               is_multimodal=is_multimodal)
+    def get_input_embeddings(self, input_ids, multimodal_embeddings=None):
+        return self.model.get_input_embeddings(input_ids=input_ids, multimodal_embeddings=multimodal_embeddings)
 
     def get_multimodal_embeddings(self, **batched_mm_inputs):
         return self.model.get_multimodal_embeddings(**batched_mm_inputs)
@@ -741,8 +738,6 @@ def __init__(
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
         self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(model_config)
-        if self.supports_mm_inputs:
-            self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
         self.is_multimodal_raw_input_supported = (model_config.is_multimodal_raw_input_only_model)
 
         # Lazy initialization
@@ -853,9 +848,6 @@ def __init__(
         assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!'
         assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!'
 
-    def _make_buffer(self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True) -> CpuGpuBuffer:
-        return CpuGpuBuffer(*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory, with_numpy=numpy)
-
     def unified_bucketing_fn(self, is_causal, query_len, shared_blocks, unique_blocks, logits):
         if not get_config().use_bucketing:
             return query_len, shared_blocks, unique_blocks, logits
@@ -1319,15 +1311,8 @@ def _gather_mm_embeddings(
         scheduler_output: "SchedulerOutput",
         req_ids: list[str],
         shift_computed_tokens: int = 0,
-        total_num_scheduled_tokens: Optional[int] = None,
-    ) -> tuple[list[torch.Tensor], torch.Tensor]:
-        total_num_scheduled_tokens = total_num_scheduled_tokens or scheduler_output.total_num_scheduled_tokens
-
-        mm_embeds = list[torch.Tensor]()
-        is_mm_embed = self.is_mm_embed.cpu
-        is_mm_embed[:total_num_scheduled_tokens] = False
-
-        req_start_idx = 0
+    ) -> list[torch.Tensor]:
+        mm_embeds: list[torch.Tensor] = []
         for req_id in req_ids:
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
             req_state = self.requests[req_id]
@@ -1366,17 +1351,8 @@ def _gather_mm_embeddings(
                     encoder_output[start_idx:end_idx],
                     is_embed=is_embed,
                 )
-                req_start_pos = req_start_idx + start_pos - num_computed_tokens
-                is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
-                    = True
-
-                # Only whole mm items are processed
                 mm_embeds.append(mm_embeds_item)
-            req_start_idx += num_scheduled_tokens
-
-        is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
-
-        return mm_embeds, is_mm_embed
+        return mm_embeds
 
     def get_model(self) -> torch.nn.Module:
         assert self.model is not None
@@ -2963,16 +2939,13 @@ def execute_model(
             with self.profiler.record_event('internal', 'prepare_input_encoders'):
                 self._execute_mm_encoder(scheduler_output, req_id)
 
-            mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output,
-                                                                req_id,
-                                                                total_num_scheduled_tokens=token_ids.shape[-1])
+            mm_embeds = self._gather_mm_embeddings(scheduler_output, req_id)
             # TODO: Only get embeddings for valid token_ids. Ignore token_ids[<pad_idxs>] # noqa E501
             # This may require moving multimodal input preps into _prepare_inputs, # noqa E501
             # to avoid padding issues.
             inputs_embeds = self.model.get_input_embeddings(
-                token_ids,
-                multimodal_embeddings=mm_embeds,
-                is_multimodal=is_mm_embed,
+                input_ids=token_ids,
+                multimodal_embeddings=mm_embeds or None,
             )
 
             model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
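The model-runner changes above revert to the older embedding interface: _gather_mm_embeddings again returns only a list of per-item multimodal embeddings, and get_input_embeddings is called with multimodal_embeddings=mm_embeds or None, so placeholder handling stays on the model side. The reverted code had instead built a boolean is_mm_embed mask over the flattened scheduled tokens (via a CpuGpuBuffer) and passed it through as is_multimodal. The sketch below is only an illustration of that mask-based merge style, with made-up names and shapes; it is not the vLLM or vllm-gaudi implementation.

# Illustration only: merging multimodal embeddings into text embeddings using a
# boolean placeholder mask, the style of interface the reverted code targeted.
# Tensor names and shapes are assumptions for this sketch, not vLLM API.
import torch


def merge_mm_embeddings(
    text_embeds: torch.Tensor,       # [num_tokens, hidden_size]
    mm_embeds: list[torch.Tensor],   # each [num_mm_tokens_i, hidden_size]
    is_mm_embed: torch.Tensor,       # [num_tokens] bool, True at placeholder slots
) -> torch.Tensor:
    merged = text_embeds.clone()
    if mm_embeds:
        # Concatenate the per-item embeddings in scheduling order and scatter
        # them into the positions flagged by the mask.
        merged[is_mm_embed] = torch.cat(mm_embeds, dim=0)
    return merged


# Tiny self-check with dummy tensors.
if __name__ == "__main__":
    text = torch.zeros(6, 4)
    mm = [torch.ones(2, 4)]
    mask = torch.tensor([False, True, True, False, False, False])
    out = merge_mm_embeddings(text, mm, mask)
    assert out[mask].eq(1).all() and out[~mask].eq(0).all()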
