
Commit 682f280

DarkLight1337 authored and Isotr0py committed
[Bugfix] Fix multi-modal processors for transformers 4.48 (vllm-project#12187)
Signed-off-by: Isotr0py <2037008807@qq.com>
1 parent f8a0f11 · commit 682f280

6 files changed: +199 -36 lines

vllm/model_executor/models/llava.py

Lines changed: 24 additions & 1 deletion
@@ -5,9 +5,11 @@
 import torch
 import torch.nn as nn
+from packaging.version import Version
 from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                           PixtralVisionConfig, PretrainedConfig,
                           SiglipVisionConfig)
+from transformers import __version__ as TRANSFORMERS_VERSION
 from transformers.models.llava import LlavaProcessor
 from transformers.models.pixtral import PixtralProcessor

@@ -716,6 +718,27 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loader.load_weights(weights)


+class MantisProcessingInfo(LlavaProcessingInfo):
+
+    def get_hf_processor(self):
+        hf_config = self.get_hf_config()
+        vision_info = self.get_vision_encoder_info()
+
+        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
+            # BUG: num_additional_image_tokens = 0 but treated as 1,
+            # so we set vision_feature_select_strategy to None to offset this
+            vision_feature_select_strategy = None
+        else:
+            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
+            vision_feature_select_strategy = hf_config.vision_feature_select_strategy  # noqa: E501
+
+        return self.ctx.get_hf_processor(
+            LlavaProcessor,
+            patch_size=vision_info.get_patch_size(),
+            vision_feature_select_strategy=vision_feature_select_strategy,
+        )
+
+
 class MantisMultiModalProcessor(LlavaMultiModalProcessor):

     def apply(
@@ -794,7 +817,7 @@ def get_replacement_mantis(item_idx: int):
 # To use this model, please use
 # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
 @MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
-                                        info=LlavaProcessingInfo,
+                                        info=MantisProcessingInfo,
                                         dummy_inputs=LlavaDummyInputsBuilder)
 class MantisForConditionalGeneration(LlavaForConditionalGeneration):
     pass
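
Note: the following is a minimal standalone sketch of the version gate used by MantisProcessingInfo.get_hf_processor above, not vLLM code; the helper name and the "default" argument are illustrative assumptions. On transformers older than 4.48 the Llava processor miscounts the extra image token, so the explicit strategy is withheld, while 4.48+ can take it straight from the HF config.

from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION


def pick_vision_feature_select_strategy(configured_strategy: str):
    """Illustrative helper mirroring the branch in the diff above."""
    if Version(TRANSFORMERS_VERSION) < Version("4.48"):
        # Older processors treat num_additional_image_tokens = 0 as 1,
        # so the strategy override is skipped to offset the bug.
        return None
    # From 4.48 onwards the processor counts tokens correctly.
    return configured_strategy


print(pick_vision_feature_select_strategy("default"))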

vllm/model_executor/models/qwen2_audio.py

Lines changed: 49 additions & 23 deletions
@@ -36,8 +36,9 @@
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputsV2, MultiModalKwargs,
+                                    NestedTensors, PlaceholderRange)
 from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,

@@ -153,29 +154,24 @@ def _call_hf_processor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, Any],
     ) -> BatchFeature:
-        mm_data = dict(mm_data)
-        audios = mm_data.pop("audios", [])
-
-        if audios:
-            mm_data["audios"] = audios
-
-            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
-            mm_kwargs = dict(
-                **mm_kwargs,
-                sampling_rate=feature_extractor.sampling_rate,
-            )
-        else:
-            # NOTE: WhisperFeatureExtractor cannot handle empty list of audios
-            pass
+        # Text-only input not supported in composite processor
+        if not mm_data or not mm_data.get("audios", []):
+            prompt_ids = self.info.get_tokenizer().encode(prompt)
+            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
+        mm_kwargs = dict(
+            **mm_kwargs,
+            sampling_rate=feature_extractor.sampling_rate,
+        )

-        processed_outputs = super()._call_hf_processor(
+        return super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
             mm_kwargs=mm_kwargs,
         )

-        return processed_outputs
-
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,

@@ -192,8 +188,14 @@ def _get_prompt_replacements(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_config = self.info.get_hf_config()
-        placeholder = hf_config.audio_token_index
+        processor = self.info.get_hf_processor()
+
+        # Use getattr with default to be compatible with transformers<4.48
+        audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
+        audio_bos_token = getattr(processor, "audio_bos_token",
+                                  "<|audio_bos|>")
+        audio_eos_token = getattr(processor, "audio_eos_token",
+                                  "<|audio_eos|>")

         feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
         if feature_attention_mask is None:

@@ -214,12 +216,16 @@ def get_replacement_qwen2_audio(item_idx: int):
                     f"The audio {audio} (len={len(audio)}) is too short "
                     "to be represented inside the model")

-            return [placeholder] * num_placeholders
+            return "".join([
+                audio_bos_token,
+                audio_token * num_placeholders,
+                audio_eos_token,
+            ])

         return [
             PromptReplacement(
                 modality="audio",
-                target=[placeholder],
+                target=audio_token,
                 replacement=get_replacement_qwen2_audio,
             )
         ]

@@ -234,6 +240,26 @@ def _always_apply_prompt_replacements(self) -> bool:
         # tokens than the number of audio items)
         return not hasattr(self.info.get_hf_processor(), "audio_token")

+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputsV2:
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+
+        # Only <|AUDIO|> tokens should be considered as placeholders,
+        # so we ignore the audio_bos_token and audio_eos_token
+        result["mm_placeholders"] = {
+            modality: [
+                PlaceholderRange(offset=p["offset"] + 1,
+                                 length=p["length"] - 2) for p in ps
+            ]
+            for modality, ps in result["mm_placeholders"].items()
+        }
+
+        return result
+

 @MULTIMODAL_REGISTRY.register_processor(
     Qwen2AudioMultiModalProcessor,
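
Note: the apply() override above shifts each placeholder range so that only the <|AUDIO|> tokens count, excluding the surrounding <|audio_bos|> and <|audio_eos|> added by the new replacement string. A self-contained sketch of that arithmetic follows; the helper name and example numbers are made up for illustration.

def trim_audio_placeholder(offset: int, length: int) -> tuple:
    """Drop the bos token at the start and the eos token at the end."""
    return offset + 1, length - 2


# A replacement of <|audio_bos|> + 5 * <|AUDIO|> + <|audio_eos|> starting at
# token index 10 spans 7 tokens; only the 5 audio tokens should remain.
assert trim_audio_placeholder(10, 7) == (11, 5)
print(trim_audio_placeholder(10, 7))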

vllm/model_executor/models/ultravox.py

Lines changed: 1 addition & 8 deletions
@@ -137,7 +137,7 @@ def _call_hf_processor(
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         # Text-only input not supported in composite processor
-        if not mm_data:
+        if not mm_data or not mm_data.get("audios", []):
             prompt_ids = self.info.get_tokenizer().encode(prompt)
             prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

@@ -146,13 +146,6 @@ def _call_hf_processor(
         audios = mm_data.pop("audios", [])
         assert isinstance(audios, list)

-        if not audios:
-            return super()._call_hf_processor(
-                prompt=prompt,
-                mm_data=mm_data,
-                mm_kwargs=mm_kwargs,
-            )
-
         feature_extractor = self.info.get_feature_extractor()
         mm_kwargs = dict(
             **mm_kwargs,
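
Note: the widened guard matters because a request can carry an explicitly empty audio list, which the old `if not mm_data:` check let through to the composite HF processor even though its feature extractor cannot handle an empty list. A quick illustrative check:

mm_data = {"audios": []}  # multi-modal dict present, but no actual audio

print(not mm_data)                                   # False: old guard misses it
print(not mm_data or not mm_data.get("audios", []))  # True: treated as text-only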

vllm/transformers_utils/config.py

Lines changed: 5 additions & 4 deletions
@@ -22,10 +22,10 @@
 from vllm.logger import init_logger
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
-                                             DbrxConfig, DeepseekVLV2Config,
-                                             EAGLEConfig, ExaoneConfig,
-                                             H2OVLChatConfig,
+from vllm.transformers_utils.configs import (AriaConfig, ChatGLMConfig,
+                                             Cohere2Config, DbrxConfig,
+                                             DeepseekVLV2Config, EAGLEConfig,
+                                             ExaoneConfig, H2OVLChatConfig,
                                              InternVLChatConfig, JAISConfig,
                                              MedusaConfig, MllamaConfig,
                                              MLPSpeculatorConfig, MPTConfig,

@@ -52,6 +52,7 @@
 }

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+    "aria": AriaConfig,
     "chatglm": ChatGLMConfig,
     "cohere2": Cohere2Config,
     "dbrx": DbrxConfig,

vllm/transformers_utils/configs/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+from vllm.transformers_utils.configs.aria import AriaConfig
 from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
 from vllm.transformers_utils.configs.cohere2 import Cohere2Config
 from vllm.transformers_utils.configs.dbrx import DbrxConfig

@@ -23,6 +24,7 @@
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig

 __all__ = [
+    "AriaConfig",
     "ChatGLMConfig",
     "Cohere2Config",
     "DbrxConfig",

vllm/transformers_utils/configs/aria.py

Lines changed: 118 additions & 0 deletions
@@ -1,7 +1,32 @@
+# Copyright 2024 Rhymes AI. All rights reserved.
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Mapping
+
+from transformers import PretrainedConfig
 from transformers.models.idefics2.configuration_idefics2 import (
     Idefics2VisionConfig)
 from transformers.models.llama.configuration_llama import LlamaConfig

+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+

 class AriaVisionConfig(Idefics2VisionConfig):
     model_type = "aria_vision_model"

@@ -45,3 +70,96 @@ def __init__(
         self.moe_num_experts = moe_num_experts
         self.moe_topk = moe_topk
         self.moe_num_shared_experts = moe_num_shared_experts
+
+
+class AriaConfig(PretrainedConfig):
+    """
+    Configuration class for Aria model.
+    This class handles the configuration for both vision and text components of
+    the Aria model,
+    as well as additional parameters for image token handling and projector
+    mapping.
+
+    Args:
+        vision_config (AriaVisionConfig or dict): Configuration for the vision
+            component.
+        text_config (AriaMoELMConfig or dict): Configuration for the text
+            component.
+        projector_patch_to_query_dict (dict): Mapping of patch sizes to query
+            dimensions.
+        ignore_index (int): Index to ignore in loss calculation.
+        image_token_index (int): Index used to represent image tokens.
+        **kwargs: Additional keyword arguments passed to the parent class.
+    Attributes:
+        model_type (str): Type of the model, set to "aria".
+        is_composition (bool): Whether the model is a composition of multiple
+            components.
+        ignore_index (int): Index to ignore in loss calculation.
+        image_token_index (int): Index used to represent image tokens.
+        projector_patch_to_query_dict (dict): Mapping of patch sizes to query
+            dimensions.
+        vision_config (AriaVisionConfig): Configuration for the vision
+            component.
+        text_config (AriaMoELMConfig): Configuration for the text component.
+    """
+
+    model_type = "aria"
+    is_composition = False
+
+    def __init__(
+        self,
+        vision_config: AriaVisionConfig = AriaVisionConfig(),  # noqa: B008
+        text_config: AriaMoELMConfig = AriaMoELMConfig(),  # noqa: B008
+        projector_patch_to_query_dict: Mapping[int, int] = {
+            1225: 128,
+            4900: 256,
+        },
+        ignore_index=-100,
+        image_token_index=32000,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.ignore_index = ignore_index
+        self.image_token_index = image_token_index
+        self.tie_word_embeddings = tie_word_embeddings
+        attn_implementation = kwargs.pop("attn_implementation", None)
+
+        # Set the default attention implementation to flash_attention_2 if not
+        # specified
+        self._attn_implementation = ("flash_attention_2"
+                                     if attn_implementation is None else
+                                     attn_implementation)
+
+        # Convert the keys and values of projector_patch_to_query_dict to
+        # integers
+        # This ensures consistency even if they were provided as strings
+        self.projector_patch_to_query_dict = {
+            int(k): int(v)
+            for k, v in projector_patch_to_query_dict.items()
+        }
+
+        if isinstance(vision_config, dict) and "model_type" in vision_config:
+            vision_config = AriaVisionConfig(**vision_config)
+        if attn_implementation is None:
+            vision_attn_implementation = "flash_attention_2"
+        elif attn_implementation == "sdpa":
+            logger.warning("SDPA is not supported for vit, using "
+                           "flash_attention_2 instead")
+            vision_attn_implementation = "flash_attention_2"
+        else:
+            vision_attn_implementation = attn_implementation
+        vision_config._attn_implementation = vision_attn_implementation
+
+        self.vision_config = vision_config
+
+        if isinstance(text_config, dict) and "model_type" in text_config:
+            text_attn_implementation = ("sdpa" if attn_implementation is None
+                                        else attn_implementation)
+            text_config = AriaMoELMConfig(**text_config)
+            text_config._attn_implementation = text_attn_implementation
+
+        self.text_config = text_config
+
+        # This is needed for the static kv cache
+        self.num_hidden_layers = self.text_config.num_hidden_layers
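
Note: a hedged usage sketch for the new AriaConfig, assuming a vLLM checkout that contains this commit; the dictionary values are invented for illustration. It shows the string-to-int normalisation of projector_patch_to_query_dict and the flash_attention_2 default described in the code above.

from vllm.transformers_utils.configs import AriaConfig

cfg = AriaConfig(projector_patch_to_query_dict={"1225": "128", "4900": "256"})

# Keys and values are coerced to int even when supplied as strings.
print(cfg.projector_patch_to_query_dict)  # {1225: 128, 4900: 256}

# Attention implementation defaults to flash_attention_2 when unspecified.
print(cfg._attn_implementation)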
