Add LLaVA-NeXT architecture
- Note that LLaVA-1.5 has been refactored to facilitate this
DarkLight1337 committed Apr 19, 2024
1 parent adf2b94 commit ea4f8ed
Showing 6 changed files with 266 additions and 13 deletions.
4 changes: 3 additions & 1 deletion tests/conftest.py
@@ -7,7 +7,8 @@
import torch
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoProcessor,
                          LlavaForConditionalGeneration)
                          LlavaForConditionalGeneration,
                          LlavaNextForConditionalGeneration)

from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
@@ -127,6 +128,7 @@ def example_long_prompts() -> List[str]:

_VISION_LANGUAGE_MODELS = {
"llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
"llava-hf/llava-v1.6-34b-hf": LlavaNextForConditionalGeneration,
}


24 changes: 23 additions & 1 deletion tests/models/test_llava.py
@@ -29,10 +29,32 @@ def iter_llava_configs(model_name: str):
                                        image_processor_revision=None))


def iter_llava_next_configs(model_name: str):
    image_hw_to_feature_size = {
        (336, 336): 1176,
        (672, 672): 2928,
        (1344, 336): 1944,
        (336, 1344): 1890,
    }

    for (h, w), f in image_hw_to_feature_size.items():
        for input_type, input_shape in [
            (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
            (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)),
        ]:
            yield (model_name,
                   VisionLanguageConfig(image_input_type=input_type,
                                        image_feature_size=f,
                                        image_token_id=64000,
                                        image_input_shape=input_shape,
                                        image_processor=model_name,
                                        image_processor_revision=None))


model_and_vl_config = [
    *iter_llava_configs("llava-hf/llava-1.5-7b-hf"),
    # Not enough memory
    # *iter_llava_configs("llava-hf/llava-1.5-13b-hf"),
    # *iter_llava_next_configs("llava-hf/llava-v1.6-34b-hf"),
]


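As a rough sanity check on the feature sizes hard-coded above, here is a short back-of-the-envelope computation. It is a sketch rather than part of the commit, and it assumes the 336-pixel CLIP-ViT-L/14 vision tower and the default LLaVA-NeXT image_grid_pinpoints, neither of which is spelled out in this diff. Under the spatial_unpad strategy implemented in llava_next.py below, a square 336x336 input contributes the base image's 24x24 patch features, one extra 336px tile whose features are unpadded back to 24x24, and one image_newline embedding per feature row:

patch_grid = 336 // 14             # 24 patches per side for ViT-L/14 at 336px

base = patch_grid * patch_grid     # 576 features from the resized base image
spatial = patch_grid * patch_grid  # 576 features left after unpadding the extra tile
newline = patch_grid               # one image_newline embedding per feature row

print(base + spatial + newline)    # 1176, matching the (336, 336) entry

# Same bookkeeping for a square 672x672 input: a 2x2 grid of tiles, no unpadding.
print(base + (2 * patch_grid) ** 2 + 2 * patch_grid)  # 2928, matching (672, 672)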
7 changes: 1 addition & 6 deletions vllm/model_executor/model_loader/loader.py
@@ -22,15 +22,10 @@
    download_weights_from_hf, filter_files_not_needed_for_inference,
    get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
    pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.llava import LlavaForConditionalGeneration

if TYPE_CHECKING:
    from vllm.model_executor.layers.linear import LinearMethodBase

_VISION_MODEL_CLASSES = [
    LlavaForConditionalGeneration,
]

logger = init_logger(__name__)


@@ -74,7 +69,7 @@ def _get_model_initialization_kwargs(
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
elif model_class in _VISION_MODEL_CLASSES:
elif getattr(model_class, "is_vlm", False):
extra_kwargs["vision_language_config"] = vision_language_config
return extra_kwargs

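The getattr check above replaces the module-level _VISION_MODEL_CLASSES list, so the loader no longer has to import concrete vision model classes. A minimal sketch of the dispatch it performs, using toy classes rather than real vLLM models:

class ToyTextModel:
    pass


class ToyVisionLanguageModel:
    # Mirrors the new class attribute set on LlavaForConditionalGeneration.
    is_vlm = True


def wants_vision_config(model_class) -> bool:
    # Same check as in _get_model_initialization_kwargs: a missing attribute
    # defaults to False, so text-only models keep their current behaviour.
    return getattr(model_class, "is_vlm", False)


assert not wants_vision_config(ToyTextModel)
assert wants_vision_config(ToyVisionLanguageModel)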
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
@@ -33,6 +33,8 @@
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration":
("llava_next", "LlavaNextForConditionalGeneration"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
24 changes: 19 additions & 5 deletions vllm/model_executor/models/llava.py
@@ -1,4 +1,5 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (ClassVar, Iterable, List, Literal, Optional, Tuple,
                    TypedDict, Union)

import torch
from torch import nn
@@ -52,7 +53,13 @@ def _merge_vision_embeddings(input_ids: torch.Tensor,
                             image_token_id: int) -> torch.Tensor:
    """In place merges in vision_embeddings with inputs_embeds."""
    mask = (input_ids == image_token_id)
    inputs_embeds[mask] = vision_embeddings.view(-1,

    image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
    if mask.sum() != image_feature_size:
        raise ValueError(f"image_feature_size should be {image_feature_size}, "
                         f"but found: {mask.sum()}")

    inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
                                                 vision_embeddings.shape[-1])
    return inputs_embeds

@@ -74,11 +81,17 @@ class LlavaImageFeatureInputs(TypedDict):

class LlavaForConditionalGeneration(nn.Module):

    is_vlm: ClassVar[bool] = True
    """Indicates that the model is a vision-language model and thus accepts
    the `vision_language_config` parameter.
    """

    def __init__(self,
                 config: "LlavaConfig",
                 config: LlavaConfig,
                 vision_language_config: VisionLanguageConfig,
                 linear_method: Optional["LinearMethodBase"] = None) -> None:
                 linear_method: Optional[LinearMethodBase] = None) -> None:
        super().__init__()

        self.config = config

        self.vision_language_config = vision_language_config
@@ -213,7 +226,7 @@ def forward(
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> SamplerOutput:  # noqa: E501
    ) -> SamplerOutput:
        """Run forward pass for Llava 1.5.
        One key thing to understand is the `input_ids` already accounts for the
@@ -264,6 +277,7 @@ def forward(
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
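To illustrate the contract that the new check in _merge_vision_embeddings enforces, here is a toy run with hypothetical sizes (the token id and hidden size below are illustrative, not taken from the commit): the number of image-token positions in input_ids must now equal the number of flattened vision embeddings, otherwise the new ValueError is raised.

import torch

hidden_size = 8
image_token_id = 32000  # illustrative value only

# One image producing 3 feature vectors, and exactly 3 image-token slots.
input_ids = torch.tensor([1, image_token_id, image_token_id, image_token_id, 9])
inputs_embeds = torch.zeros(input_ids.shape[0], hidden_size)
vision_embeddings = torch.randn(1, 3, hidden_size)

mask = (input_ids == image_token_id)
image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
assert mask.sum() == image_feature_size  # a mismatch now raises ValueError
inputs_embeds[mask] = vision_embeddings.view(image_feature_size, hidden_size)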
218 changes: 218 additions & 0 deletions vllm/model_executor/models/llava_next.py
@@ -0,0 +1,218 @@
from typing import Optional, TypedDict, Union

import torch
from torch import nn
from transformers import LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)

from vllm.config import VisionLanguageConfig
from vllm.model_executor.layers.linear import LinearMethodBase

from .llava import (LlavaForConditionalGeneration, LlavaImageFeatureInputs,
                    LlavaImagePixelInputs)


class ImageSizesMixin(TypedDict, total=False):
    image_sizes: torch.Tensor
    """Shape: (batch_size, 2)"""


class LlavaNextImagePixelInputs(ImageSizesMixin, LlavaImagePixelInputs):
    data: torch.Tensor
    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""


class LlavaNextImageFeatureInputs(ImageSizesMixin, LlavaImageFeatureInputs):
    data: torch.Tensor
    """Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)"""


LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageFeatureInputs]


class LlavaNextForConditionalGeneration(LlavaForConditionalGeneration):
    """
    Args to `forward()`:
        input_ids: Flattened (concatenated) input_ids corresponding to a
            batch.
        pixel_values: For PIXEL_VALUES, expects a batch with shape
            [1, num_patches, 3, 336, 336].
        image_features: For IMAGE_FEATURES, expects a batch with shape
            [1, num_patches, 1176, 1024].
    """

    def __init__(self,
                 config: LlavaNextConfig,
                 vision_language_config: VisionLanguageConfig,
                 linear_method: Optional[LinearMethodBase] = None) -> None:
        super().__init__(config, vision_language_config, linear_method)

        # Update the type annotation from that of its superclass
        self.config = config

        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
        _, num_channels, _, _ = self.vision_language_config.image_input_shape

        # Note that this is different from that of vLLM vision_language_config
        # since the image is resized by the HuggingFace preprocessor
        height = width = self.config.vision_config.image_size

        if list(data.shape[2:]) != [num_channels, height, width]:
            raise ValueError(
                f"The expected image tensor shape is batch dimension plus "
                f"num_patches plus {[num_channels, height, width]}. "
                f"You supplied {data.shape}. "
                f"If you are using vLLM's entrypoint, make sure your "
                f"supplied image input is consistent with "
                f"image_input_shape in engine args.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_features = kwargs.pop("image_features", None)

        expected_input_type = self.vision_language_config.image_input_type
        ImageInputType = VisionLanguageConfig.ImageInputType

        if expected_input_type == ImageInputType.PIXEL_VALUES:
            if image_features is not None:
                raise ValueError(
                    "Expected pixel values but got image features")
            if pixel_values is None:
                return None

            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values")

            if not isinstance(image_sizes, torch.Tensor):
                raise ValueError("Incorrect type of image sizes")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_image_pixels(pixel_values),
                image_sizes=image_sizes,
            )

        if expected_input_type == ImageInputType.IMAGE_FEATURES:
            if pixel_values is not None:
                raise ValueError(
                    "Expected image features but got pixel values")
            if image_features is None:
                return None

            if not isinstance(image_features, torch.Tensor):
                raise ValueError("Incorrect type of image features")

            return LlavaNextImageFeatureInputs(
                type="image_features",
                data=self._validate_image_data(image_features),
            )

        return None

    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            orig_width, orig_height = image_size
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # image_aspect_ratio == "anyres"
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    (orig_width, orig_height),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                other_patch_embeds = other_patch_embeds \
                    .view(num_patch_width, num_patch_height, height, width, -1)

                if "unpad" in strategy:
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     image_size)
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                             .to(base_patch_embeds.device)
                         ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
            self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        b, num_patches, c, h, w = pixel_values.shape
        stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)

        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return stacked_image_features.view(b, num_patches,
                                           *stacked_image_features.shape[-2:])

    def _process_image_input(
            self, image_input: LlavaNextImageInputs) -> torch.Tensor:
        patch_embeddings = super()._process_image_input(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = image_input["data"].shape[0]
            default_width = default_height = \
                self.config.vision_config.image_size
            image_sizes = torch.as_tensor([[default_width, default_height]
                                           for _ in range(batch_size)])

        merged_patch_embeddings = [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features,
                                               strategy="spatial_unpad")
            for i, patch_features in enumerate(patch_embeddings)
        ]

        return torch.stack(merged_patch_embeddings, dim=0)
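For reference, a small self-contained check of the unpad_image helper that the spatial_unpad strategy relies on. The shapes below are illustrative and this snippet is not part of the commit: patch features laid out as a 24x48 grid for an originally square image are cropped back to 24x24 before the image_newline embeddings are appended.

import torch
from transformers.models.llava_next.modeling_llava_next import unpad_image

# Channels-first grid of patch features: 8 channels, 24 rows, 48 columns,
# as produced when the best-fit anyres resolution is twice as wide as tall.
features = torch.randn(8, 24, 48)

# For a square original image, the horizontal padding is stripped again.
print(unpad_image(features, (336, 336)).shape)  # torch.Size([8, 24, 24])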
