[Model]: get aria to work with the latest transformers impl #12207

Closed
wants to merge 1 commit into from
144 changes: 49 additions & 95 deletions vllm/model_executor/models/aria.py
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 from transformers import BatchFeature, PretrainedConfig
+from transformers.models.aria import AriaTextConfig

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
@@ -26,8 +27,6 @@
                                         BaseProcessingInfo, PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
-                                                  AriaVisionConfig)

 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import SupportsMultiModal
@@ -39,89 +38,14 @@

 class AriaImagePixelInputs(TypedDict):
     pixel_values: torch.Tensor
-    pixel_mask: Optional[torch.Tensor]
+    patch_attention_mask: Optional[torch.Tensor]
     """
     Shape:
         pixel_values: `(batch_size * num_images, num_channels, height, width)`
         pixel_mask: `(batch_size * num_images, height, width)`
     """


-class AriaVisionTransformer(Idefics2VisionTransformer):
-    """
-    AriaVisionTransformer is a modified version of Idefics2VisionTransformer
-    that replaces the post-layernorm with an identity layer.
-    """
-
-    def __init__(
-        self,
-        config: AriaVisionConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__(config, quant_config, prefix)
-        self.post_layernorm = nn.Identity()
-
-
-class AriaVisionModel(nn.Module):
-    config_class = AriaVisionConfig
-
-    def __init__(
-        self,
-        config: AriaVisionConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        *,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-
-        self.vision_model = AriaVisionTransformer(
-            config,
-            quant_config,
-            prefix=f"{prefix}.vision_model",
-        )
-
-    def forward(
-        self,
-        pixel_values: torch.Tensor,
-        pixel_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
-
-        vit_oup = self.vision_model(
-            pixel_values=pixel_values,
-            patch_attention_mask=patch_attention_mask,
-        )
-
-        image_atts = self._create_image_attention_mask(patch_attention_mask)
-
-        return vit_oup, image_atts
-
-    def _create_patch_attention_mask(
-            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
-        if pixel_mask is None:
-            return None
-
-        patches_subgrid = pixel_mask.unfold(
-            dimension=1,
-            size=self.vision_model.config.patch_size,
-            step=self.vision_model.config.patch_size,
-        ).unfold(
-            dimension=2,
-            size=self.vision_model.config.patch_size,
-            step=self.vision_model.config.patch_size,
-        )
-        return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
-
-    def _create_image_attention_mask(
-            self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
-        if patch_attention_mask is None:
-            return None
-
-        flattened_mask = patch_attention_mask.flatten(1)
-        return torch.logical_not(flattened_mask)
-
-
 class FFN(nn.Module):

     def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
@@ -150,7 +74,7 @@ def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
         self.linear = nn.Linear(embed_dim, embed_dim)

         self.layer_norm = nn.LayerNorm(embed_dim)
-        self.ln_kv = nn.LayerNorm(kv_dim)
+        self.layer_norm_kv = nn.LayerNorm(kv_dim)

     def forward(
         self,
@@ -161,7 +85,7 @@ def forward(
         normed_hidden_states = self.layer_norm(hidden_states)
         query = self.q_proj(normed_hidden_states).permute(1, 0, 2)

-        x = self.ln_kv(x)
+        x = self.layer_norm_kv(x)
         key = self.k_proj(x).permute(1, 0, 2)
         value = self.v_proj(x).permute(1, 0, 2)

@@ -218,8 +142,8 @@ def __init__(

         self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)

-        self.ln_ffn = norm_layer(embed_dim)
-        self.ffn = FFN(embed_dim, ff_dim, output_dim)
+        self.layer_norm = norm_layer(embed_dim)
+        self.feed_forward = FFN(embed_dim, ff_dim, output_dim)

     def forward(
         self,
@@ -241,7 +165,7 @@ def forward(

         attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)

-        out = self.ffn(self.ln_ffn(attention_out))
+        out = self.feed_forward(self.layer_norm(attention_out))

         return out

@@ -289,7 +213,7 @@ class MoELayer(nn.Module):

     def __init__(
         self,
-        config: AriaMoELMConfig,
+        config: AriaTextConfig,
         quant_config: Optional[QuantizationConfig],
     ) -> None:
         super().__init__()
@@ -303,13 +227,13 @@ def __init__(
             num_experts=config.moe_num_experts,
             top_k=config.moe_topk,
             hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
+            intermediate_size=config.intermediate_size,
             quant_config=quant_config,
             reduce_results=True,
         )
         self.shared_experts = LlamaMLP(
             config.hidden_size,
-            config.moe_intermediate_size * config.moe_num_shared_experts,
+            config.intermediate_size * config.moe_num_shared_experts,
             "silu",
             quant_config=quant_config,
         )
@@ -344,7 +268,7 @@ class MoEDecoderLayer(LlamaDecoderLayer):

     def __init__(
         self,
-        config: AriaMoELMConfig,
+        config: AriaTextConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -450,7 +374,7 @@ class AriaProcessingInfo(BaseProcessingInfo):
     def get_hf_config(self):
         return self.ctx.get_hf_config()

-    def get_vision_config(self) -> AriaVisionConfig:
+    def get_vision_config(self):
         return self.get_hf_config().vision_config

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
@@ -483,8 +407,8 @@ def get_dummy_processor_inputs(
                                    num_images=num_images)
         }

-        hf_processor = self.info.get_hf_processor()
-        image_token: str = hf_processor.image_token  # type: ignore
+        # hf_processor = self.info.get_hf_processor()
+        image_token: str = '<|img|>'

         return ProcessorInputs(
             prompt_text=image_token * num_images,
@@ -554,7 +478,7 @@ def __init__(
         quant_config = vllm_config.quant_config

         self.config = config
-        self.vision_tower = AriaVisionModel(config.vision_config)
+        self.vision_tower = Idefics2VisionTransformer(config.vision_config)
         self.multi_modal_projector = build_mm_projector(config)
         self.vocab_size = config.text_config.vocab_size
         self.language_model = AriaMoELMModel(
@@ -581,6 +505,30 @@ def _validate_image_sizes(
             raise ValueError("All images must be the same size")
         return images

+    def _create_patch_attention_mask(
+            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
+        if pixel_mask is None:
+            return None
+
+        patches_subgrid = pixel_mask.unfold(
+            dimension=1,
+            size=self.config.vision_config.patch_size,
+            step=self.config.vision_config.patch_size,
+        ).unfold(
+            dimension=2,
+            size=self.config.vision_config.patch_size,
+            step=self.config.vision_config.patch_size,
+        )
+        return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+    def _create_image_attention_mask(
+            self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
+        if patch_attention_mask is None:
+            return None
+
+        flattened_mask = patch_attention_mask.flatten(1)
+        return torch.logical_not(flattened_mask)
+
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[AriaImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -596,16 +544,20 @@ def _parse_and_validate_image_input(
         pixel_values = self._validate_image_sizes(pixel_values)
         pixel_values = flatten_bn(pixel_values, concat=True)

+        patch_attention_mask = None
         if pixel_mask is not None:
             if not isinstance(pixel_mask, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel mask. "
                                  f"Got type: {type(pixel_mask)}")

             pixel_mask = flatten_bn(pixel_mask, concat=True)

+            patch_attention_mask = self._create_patch_attention_mask(
+                pixel_mask)
+
         return AriaImagePixelInputs(
             pixel_values=pixel_values,
-            pixel_mask=pixel_mask,
+            patch_attention_mask=patch_attention_mask,
         )

     def _process_image_input(
@@ -614,10 +566,12 @@ def _process_image_input(
         assert self.vision_tower is not None

         pixel_values = image_input['pixel_values']
-        pixel_mask = image_input['pixel_mask']
+        patch_attention_mask = image_input['patch_attention_mask']

-        image_feature, image_attn_mask = self.vision_tower(
-            pixel_values, pixel_mask=pixel_mask)
+        image_feature = self.vision_tower(
+            pixel_values, patch_attention_mask=patch_attention_mask)
+        image_attn_mask = self._create_image_attention_mask(
+            patch_attention_mask)
         return self.multi_modal_projector(image_feature, image_attn_mask)

     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
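For context, the pixel-mask handling that previously lived in the removed `AriaVisionModel` now happens in the top-level model: `_create_patch_attention_mask` turns the per-pixel mask into a per-patch mask before calling the Idefics2 vision tower, and `_create_image_attention_mask` inverts and flattens it for the projector. A minimal standalone sketch of that unfold-based computation (the `patch_size`, image size, and batch shape below are illustrative assumptions, not values taken from this diff):

```python
import torch

# Assumed values for illustration only; in the model they come from
# config.vision_config and the processor outputs.
patch_size = 14
pixel_mask = torch.ones(2, 980, 980, dtype=torch.bool)  # (num_images, H, W)

# Tile the per-pixel mask into non-overlapping patch_size x patch_size blocks:
# shape becomes (num_images, H // patch_size, W // patch_size, patch_size, patch_size).
patches_subgrid = (pixel_mask
                   .unfold(dimension=1, size=patch_size, step=patch_size)
                   .unfold(dimension=2, size=patch_size, step=patch_size))

# A patch participates in attention if any of its pixels is unmasked.
patch_attention_mask = patches_subgrid.sum(dim=(-1, -2)) > 0

# The projector consumes the inverted mask, flattened over the patch grid.
image_attn_mask = torch.logical_not(patch_attention_mask.flatten(1))

print(patch_attention_mask.shape)  # torch.Size([2, 70, 70])
print(image_attn_mask.shape)       # torch.Size([2, 4900])
```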
9 changes: 4 additions & 5 deletions vllm/transformers_utils/config.py
@@ -22,10 +22,10 @@
 from vllm.logger import init_logger
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.transformers_utils.configs import (AriaConfig, ChatGLMConfig,
-                                             Cohere2Config, DbrxConfig,
-                                             DeepseekVLV2Config, EAGLEConfig,
-                                             ExaoneConfig, H2OVLChatConfig,
+from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
+                                             DbrxConfig, DeepseekVLV2Config,
+                                             EAGLEConfig, ExaoneConfig,
+                                             H2OVLChatConfig,
                                              InternVLChatConfig, JAISConfig,
                                              MedusaConfig, MllamaConfig,
                                              MLPSpeculatorConfig, MPTConfig,
@@ -52,7 +52,6 @@
 }

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
-    "aria": AriaConfig,
     "chatglm": ChatGLMConfig,
     "cohere2": Cohere2Config,
     "dbrx": DbrxConfig,
1 change: 0 additions & 1 deletion vllm/transformers_utils/configs/__init__.py
@@ -1,4 +1,3 @@
-from vllm.transformers_utils.configs.aria import AriaConfig
 from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
 from vllm.transformers_utils.configs.cohere2 import Cohere2Config
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
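With `aria` dropped from `_CONFIG_REGISTRY` and `AriaConfig` no longer exported by `vllm.transformers_utils.configs`, Aria checkpoints are expected to resolve their configuration through the classes that ship with transformers (the same package that provides the `AriaTextConfig` imported at the top of `aria.py`). A rough sketch of that fallback, with a placeholder checkpoint id; the exact class names depend on the installed transformers version:

```python
from transformers import AutoConfig

# Placeholder checkpoint id for illustration.
config = AutoConfig.from_pretrained("rhymes-ai/Aria")

# With no vLLM-side override registered for "aria", this should be the upstream
# transformers AriaConfig, whose text_config is an AriaTextConfig and whose
# vision_config is what feeds Idefics2VisionTransformer above.
print(type(config).__name__)
print(type(config.text_config).__name__)
```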