Skip to content

Commit

Permalink
[Bugfix] Fix LLaVA-NeXT (#5380)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Jun 10, 2024
1 parent 774d103 commit 2c0d933
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
24 changes: 24 additions & 0 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,30 @@ def _parse_and_validate_image_input(

return None

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features

raise ValueError(f"Unexpected select feature strategy: {strategy}")

def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)

image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]

return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)

def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
patch_embeddings: torch.Tensor, *,
strategy: str) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
"""Combine image and text prompts for vision language model depending on
the model architecture."""

if config.hf_config.model_type == "llava":
if config.hf_config.model_type in ("llava", "llava_next"):
full_prompt = f"{image_prompt}\n{text_prompt}"
else:
raise ValueError(
Expand Down

0 comments on commit 2c0d933

Please sign in to comment.