Merged
2 changes: 1 addition & 1 deletion .jenkins/vision/configs/Qwen2.5-VL-7B-Instruct.yaml
@@ -1,5 +1,5 @@
model_name: "/mnt/weka/data/pytorch/Qwen/Qwen2.5-VL-7B-Instruct/"
dtype: "bfloat16"
max_model_len: 32768
max_model_len: 35840
max_num_seqs: 32
num_prompts: 4
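The only change here raises max_model_len from 32768 to 35840, i.e. 3072 extra tokens of headroom. The diff does not state the motivation; a plausible reading is that the extra room covers image placeholder tokens on top of a 32K text context. A minimal sketch of that arithmetic, where the per-image count is a pure assumption:

    # Only max_model_len is taken from the config above; the rest is illustrative.
    max_model_len = 35840
    text_context = 32768
    headroom = max_model_len - text_context    # 3072 tokens
    tokens_per_image = 768                     # assumed placeholder count, not from the diff
    print(headroom // tokens_per_image)        # images that fit in the headroom, under this assumption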
35 changes: 15 additions & 20 deletions vllm/model_executor/models/gemma3_mm.py
@@ -304,7 +304,6 @@ def _get_mm_fields_config(
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_crops = hf_inputs.get("num_crops", torch.empty(0))

return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops + 1),
@@ -534,7 +533,6 @@ def _parse_and_validate_image_input(
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
return None

if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
@@ -570,11 +568,6 @@ def _process_image_input(
pixel_values = image_input["pixel_values"]
num_patches = image_input["num_patches"]

image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)

if is_hpu:
batch_breakdown = greedy_plan(pixel_values.shape[0], \
self.vision_buckets.multimodal_buckets)
@@ -583,22 +576,24 @@

for i in batch_breakdown:
end_idx = start_idx + i
batch_sliced_image_features = \
image_features[start_idx:end_idx, ...]
if is_lazy:
image_embeds_multibatches += \
[self.multi_modal_projector(
batch_sliced_image_features,
bypass_hpu_graphs=i
not in self.graphed_multimodal_buckets
and len(self.graphed_multimodal_buckets) > 0)]
else:
image_embeds_multibatches += \
[self.multi_modal_projector( \
batch_sliced_image_features)]
indices = torch.arange(start_idx, end_idx)
batch_sliced_pixel_values = torch.index_select(pixel_values,
dim=0,
index=indices)

image_features = self._image_pixels_to_features(
self.vision_tower,
batch_sliced_pixel_values,
)
image_embeds = self.multi_modal_projector(image_features)
image_embeds_multibatches += [image_embeds.clone()]
start_idx = end_idx
image_embeds = torch.cat(image_embeds_multibatches, dim=0)
else:
image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
image_embeds = self.multi_modal_projector(image_features)
return [
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
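The net effect of the hunk above: on HPU, pixel_values is now sliced per bucket before the vision tower, so both the vision tower and the projector see one of the pre-planned batch sizes on every call, instead of the vision tower running once over the whole image batch. A minimal sketch of that loop, assuming greedy_plan yields chunk sizes that sum to the batch size (names here are illustrative, not the module's API):

    import torch

    def encode_images_in_buckets(vision_tower, projector, pixel_values, plan):
        # `plan` is assumed to be the bucket breakdown from greedy_plan(),
        # i.e. a list of chunk sizes that sum to pixel_values.shape[0].
        image_embeds, start = [], 0
        for size in plan:
            end = start + size
            chunk = pixel_values[start:end]    # one bucket-sized image batch
            features = vision_tower(chunk)     # vision tower runs on the slice
            image_embeds.append(projector(features).clone())
            start = end
        return torch.cat(image_embeds, dim=0)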
48 changes: 28 additions & 20 deletions vllm/worker/hpu_model_runner.py
@@ -371,7 +371,7 @@ def __init__(self, model, vllm_config, is_causal, sampler):
if self.is_mm_optimized:
if hasattr(self.model, 'vision_tower'):
self.model.vision_tower = htorch.hpu.wrap_in_hpu_graph(
self.model.vision_tower, disable_tensor_cache=True)
self.model.vision_tower, disable_tensor_cache=False)
if hasattr(self.model, 'multi_modal_projector'):
self.model.multi_modal_projector = \
htorch.hpu.wrap_in_hpu_graph( \
@@ -616,13 +616,19 @@ def _update_metadata(self,
device, dtype, True)
return attn_metadata

def compute_input_embeddings_for_mm_optimized(self, **kwargs):
def compute_input_embeddings_for_mm_optimized(self, warmup_mode, **kwargs):
input_ids = kwargs['input_ids']
vision_embeddings = self.model.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.model.get_input_embeddings(
input_ids, vision_embeddings)

if vision_embeddings is not None:
# TODO: During warmup the model needs to be warmed up with dummy image data
# for the multimodal prompt path. Instead of generating a dummy image, we
# only generate the attention masks for the images and pass them via
# attn_metadata, so the HPU graph can be reused without running the whole
# vision tower.
if vision_embeddings is not None or (
warmup_mode & kwargs['attn_metadata'].is_prompt):
input_ids = kwargs['input_ids']
positions = kwargs['positions']
kwargs = self.model.prepare_attn_masks(
@@ -631,14 +637,16 @@ def compute_input_embeddings_for_mm_optimized(self, **kwargs):
)
kwargs['input_ids'] = input_ids
kwargs['positions'] = positions
#input_ids = None

kwargs.update({'inputs_embeds': inputs_embeds})
# done compute the visual tokens
# done compute the visual tokens and others
kwargs.pop('pixel_values', None)
kwargs.pop("num_crops", None)
kwargs.pop("graphed_multimodal_buckets", None)
return kwargs

def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
def compute_input_embeddings_for_mrope_mm_optimized(
self, warmup_mode, **kwargs):

if 'inputs_embeds' in kwargs:
return kwargs
@@ -677,7 +685,8 @@ def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
kwargs.pop('image_grid_thw', None)
return kwargs
else:
return self.compute_input_embeddings_for_mm_optimized(**kwargs)
return self.compute_input_embeddings_for_mm_optimized(
warmup_mode, **kwargs)

def forward(self, *args, **kwargs):
kwargs = kwargs.copy()
@@ -689,9 +698,9 @@ def forward(self, *args, **kwargs):
virtual_engine = kwargs.pop('virtual_engine')

input_ids = kwargs['input_ids']
global_attn_masks = kwargs.get("global_attn_masks") \
global_attn_masks = kwargs.pop("global_attn_masks") \
if kwargs.get("global_attn_masks") else None
local_attn_masks = kwargs.get("local_attn_masks") \
local_attn_masks = kwargs.pop("local_attn_masks") \
if kwargs.get("local_attn_masks") else None

kwargs['attn_metadata'] = self._update_metadata(
@@ -1383,12 +1392,8 @@ def get_model(self) -> torch.nn.Module:
return self.model.model
return self.model

def _use_graphs(self, img_args=None):
if not img_args:
return not self.enforce_eager
#TODO: We might need to check both language bucket and multimodal bucket
# and return True only it's avialble, or return separately.
return (img_args) in self.graphed_multimodal_buckets
def _use_graphs(self):
return not self.enforce_eager

def _is_valid_bucket(self, bucket):
return bucket[0] * bucket[1] <= self.max_num_batched_tokens
@@ -2652,7 +2657,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:

def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
sampling_params,
lora_request):
lora_request, seq_len):
assert self.model_is_mrope or self.is_mm_optimized, \
("Warmup compatible with Qwen2vl/Gemma3 models")
if img_args == UNSET_IMG_ARGS:
@@ -2697,7 +2702,9 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
}

image_token_id = self.get_model().config.image_token_id
prompt_token_ids = [image_token_id] * num_image_tokens
prompt_token_ids_image = [image_token_id] * num_image_tokens
prompt_token_ids = [0] * (
seq_len - len(prompt_token_ids_image)) + prompt_token_ids_image
prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821
placeholders_by_modality = {
'image':
@@ -2741,6 +2748,7 @@ def create_dummy_seq_group_metadata(self,
img_args=img_args,
sampling_params=sampling_params,
lora_request=lora_request,
seq_len=seq_len,
)
else:
input_len = seq_len
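With seq_len now passed into create_dummy_multi_modal_seq_group_metadata, the dummy warmup prompt is left-padded with token id 0 in front of the image placeholder tokens so its total length matches the sequence bucket being warmed up. A minimal sketch with made-up sizes (in the diff, the image token id and count come from the model config and the image bucket):

    # Illustrative values only.
    image_token_id = 99999     # hypothetical placeholder id
    num_image_tokens = 256
    seq_len = 1024

    prompt_token_ids_image = [image_token_id] * num_image_tokens
    prompt_token_ids = [0] * (seq_len - len(prompt_token_ids_image)) + prompt_token_ids_image
    assert len(prompt_token_ids) == seq_len   # prompt now fills the warmed-up bucket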
@@ -2853,7 +2861,7 @@ def warmup_scenario(self,
align_worker=False,
is_dummy_run=False) -> None:
phase = 'prompt' if is_prompt else 'decode'
use_graphs = is_dummy_run or self._use_graphs(img_args)
use_graphs = is_dummy_run or self._use_graphs()

scenario_name = ("warmup_"
f"{phase}_"
@@ -3739,8 +3747,7 @@ def execute_model(
if not warmup_mode:
ctx_blocks = seq_len
seq_len = 1
img_args = self._get_img_args_from_model_input(model_input)
use_graphs = self._use_graphs(img_args=img_args)
use_graphs = self._use_graphs()
self._check_config(batch_size, seq_len, ctx_blocks, attn_metadata,
warmup_mode)
lora_mask: torch.Tensor = None
@@ -3906,6 +3913,7 @@ def try_revert_dummy_output_tokens():
# hpu graphs, hence turning it to a list
execute_model_kwargs = \
self.model.compute_input_embeddings_for_mrope_mm_optimized(
warmup_mode,
**execute_model_kwargs
)
if warmup_mode and bypass_model_exec:
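Taken together, these hunks thread warmup_mode from execute_model down into compute_input_embeddings_for_mrope_mm_optimized and compute_input_embeddings_for_mm_optimized, so a prompt warmup builds the image attention masks via prepare_attn_masks even when no real vision embeddings are produced, and the graphed multimodal prompt path can be exercised without running the vision tower. A simplified sketch of the added gating (names stand in for the real objects):

    # Simplified restatement of the condition added in
    # compute_input_embeddings_for_mm_optimized; attn_metadata is assumed to
    # expose an is_prompt flag, as in the surrounding code.
    def needs_image_attn_masks(vision_embeddings, warmup_mode, attn_metadata):
        # Real vision embeddings always need masks; during prompt warmup the
        # masks are built anyway so the warmed-up graph matches the real
        # multimodal prompt shape.
        return vision_embeddings is not None or (warmup_mode and attn_metadata.is_prompt)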