
Commit 7708532

[feat]: Create interface for model-specific M-RoPE (vllm-project#24194)

Authored by AzizCode92 (with co-authors gemini-code-assist[bot] and Cyrus Leung) and committed by DarkLight1337.

Signed-off-by: AzizCode92 <azizbenothman76@gmail.com>
Signed-off-by: Aziz <azizbenothman76@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

1 parent dbf13c2 · commit 7708532

File tree

5 files changed: +242 additions, -30 deletions

vllm/model_executor/models/__init__.py
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/qwen2_vl.py
vllm/v1/worker/gpu_model_runner.py
vllm/worker/model_runner.py


vllm/model_executor/models/__init__.py
Lines changed: 7 additions & 4 deletions

@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
-                         SupportsPP, SupportsTranscription, SupportsV0Only,
-                         has_inner_state, supports_lora, supports_multimodal,
-                         supports_pp, supports_transcription, supports_v0_only)
+from .interfaces import (HasInnerState, SupportsLoRA, SupportsMRoPE,
+                         SupportsMultiModal, SupportsPP, SupportsTranscription,
+                         SupportsV0Only, has_inner_state, supports_lora,
+                         supports_mrope, supports_multimodal, supports_pp,
+                         supports_transcription, supports_v0_only)
 from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
                               is_pooling_model, is_text_generation_model)
 from .registry import ModelRegistry
@@ -21,6 +22,8 @@
     "supports_lora",
     "SupportsMultiModal",
     "supports_multimodal",
+    "SupportsMRoPE",
+    "supports_mrope",
     "SupportsPP",
     "supports_pp",
     "SupportsTranscription",

vllm/model_executor/models/interfaces.py
Lines changed: 68 additions & 0 deletions

@@ -8,6 +8,7 @@
 import numpy as np
 import torch
 from torch import Tensor
+from transformers import PretrainedConfig
 from transformers.models.whisper.tokenization_whisper import LANGUAGES
 from typing_extensions import Self, TypeIs

@@ -852,3 +853,70 @@ def supports_eagle3(
     model: Union[type[object], object],
 ) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]:
     return isinstance(model, SupportsEagle3)
+
+
+@runtime_checkable
+class SupportsMRoPE(Protocol):
+    """The interface required for all models that support M-RoPE."""
+
+    supports_mrope: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports M-RoPE.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        second_per_grid_ts: Optional[list[float]] = None,
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+        audio_feature_lengths: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """
+        Get M-RoPE input positions and delta value for this specific model.
+
+        This method should be implemented by each model that supports M-RoPE
+        to provide model-specific logic for computing input positions.
+
+        Args:
+            input_tokens: List of input token IDs
+            hf_config: HuggingFace model configuration
+            image_grid_thw: Image grid dimensions (t, h, w)
+            video_grid_thw: Video grid dimensions (t, h, w)
+            second_per_grid_ts: Seconds per grid timestep for videos
+            context_len: Context length
+            seq_len: Sequence length
+            audio_feature_lengths: Audio feature lengths for multimodal models
+            use_audio_in_video: Whether to use audio in video for interleaving
+
+        Returns:
+            Tuple of (llm_positions, mrope_position_delta)
+            - llm_positions: Tensor of shape [3, num_tokens]
+              with T/H/W positions
+            - mrope_position_delta: Delta for position calculations
+        """
+        ...
+
+
+@overload
+def supports_mrope(model: type[object]) -> TypeIs[type[SupportsMRoPE]]:
+    ...
+
+
+@overload
+def supports_mrope(model: object) -> TypeIs[SupportsMRoPE]:
+    ...
+
+
+def supports_mrope(
+    model: Union[type[object], object],
+) -> Union[TypeIs[type[SupportsMRoPE]], TypeIs[SupportsMRoPE]]:
+    return isinstance(model, SupportsMRoPE)
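
Sketch (an illustration, not part of the diff) of how a model class satisfies this protocol: it only needs the supports_mrope flag and a get_mrope_input_positions method, and the runtime_checkable check is purely structural. MRoPECapable is a local stand-in for SupportsMRoPE so the snippet runs without vLLM installed; ToyTextOnlyModel and its trivial position logic are hypothetical and far simpler than any real implementation.

from typing import ClassVar, Literal, Protocol, runtime_checkable

import torch


@runtime_checkable
class MRoPECapable(Protocol):
    # Local stand-in mirroring the SupportsMRoPE members checked at runtime.
    supports_mrope: ClassVar[Literal[True]] = True

    def get_mrope_input_positions(self, input_tokens: list[int],
                                  **kwargs) -> tuple[torch.Tensor, int]:
        ...


class ToyTextOnlyModel:
    supports_mrope: ClassVar[Literal[True]] = True

    def get_mrope_input_positions(self, input_tokens: list[int],
                                  **kwargs) -> tuple[torch.Tensor, int]:
        # Text-only toy case: the T/H/W rows are identical, so the delta is 0.
        positions = torch.arange(len(input_tokens)).view(1, -1).expand(3, -1)
        return positions, 0


model = ToyTextOnlyModel()
print(isinstance(model, MRoPECapable))              # True: structural match
positions, delta = model.get_mrope_input_positions(list(range(5)))
print(positions.shape, delta)                       # torch.Size([3, 5]) 0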

vllm/model_executor/models/qwen2_vl.py
Lines changed: 115 additions & 3 deletions

@@ -32,7 +32,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from transformers import AutoConfig, BatchFeature
+from transformers import AutoConfig, BatchFeature, PretrainedConfig
 from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                           Qwen2VLProcessor)
 from transformers.models.qwen2_vl.configuration_qwen2_vl import (
@@ -73,7 +73,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix,
@@ -1096,7 +1096,7 @@ def _get_mm_fields_config(
                                         info=Qwen2VLProcessingInfo,
                                         dummy_inputs=Qwen2VLDummyInputsBuilder)
 class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                      SupportsLoRA, SupportsPP):
+                                      SupportsLoRA, SupportsPP, SupportsMRoPE):

     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -1109,6 +1109,118 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         "model.": "language_model.model.",
     })

+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        second_per_grid_ts: Optional[list[float]] = None,
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+        audio_feature_lengths: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """Get M-RoPE input positions for Qwen2-VL model."""
+        if image_grid_thw is None:
+            image_grid_thw = []
+        if video_grid_thw is None:
+            video_grid_thw = []
+        if second_per_grid_ts is None:
+            second_per_grid_ts = []
+
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        vision_start_token_id = hf_config.vision_start_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+        tokens_per_second = getattr(hf_config.vision_config,
+                                    "tokens_per_second", 1.0)
+
+        input_tokens_tensor = torch.tensor(input_tokens)
+        vision_start_indices = torch.argwhere(
+            input_tokens_tensor == vision_start_token_id).squeeze(1)
+        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
+        image_nums = (vision_tokens == image_token_id).sum()
+        video_nums = (vision_tokens == video_token_id).sum()
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_videos = image_nums, video_nums
+
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + video_nums):
+            video_second_per_grid_t = 0.0
+            if remain_images > 0:
+                try:
+                    ed_image = input_tokens.index(image_token_id, st)
+                except ValueError:
+                    ed_image = len(input_tokens) + 1
+            else:
+                ed_image = len(input_tokens) + 1
+            if remain_videos > 0:
+                try:
+                    ed_video = input_tokens.index(video_token_id, st)
+                except ValueError:
+                    ed_video = len(input_tokens) + 1
+            else:
+                ed_video = len(input_tokens) + 1
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_second_per_grid_t = 1.0
+                if second_per_grid_ts:
+                    video_second_per_grid_t = second_per_grid_ts[video_index]
+                video_index += 1
+                remain_videos -= 1
+                ed = ed_video
+
+            llm_grid_t, llm_grid_h, llm_grid_w = \
+                t, h // spatial_merge_size, w // spatial_merge_size
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+            t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
+                -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
+                       tokens_per_second).long().flatten()
+
+            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                llm_grid_t, -1, llm_grid_w).flatten()
+            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                llm_grid_t, llm_grid_h, -1).flatten()
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+
+        return llm_positions, mrope_position_delta
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):

vllm/v1/worker/gpu_model_runner.py
Lines changed: 23 additions & 10 deletions

@@ -42,6 +42,7 @@
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
                                                    supports_eagle3,
+                                                   supports_mrope,
                                                    supports_transcription)
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
@@ -730,16 +731,28 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
             if mm_input.get("use_audio_in_video") is True:
                 use_audio_in_video = True

-        req_state.mrope_positions, req_state.mrope_position_delta = \
-            MRotaryEmbedding.get_input_positions_tensor(
-                req_state.prompt_token_ids,
-                hf_config=self.model_config.hf_config,
-                image_grid_thw=image_grid_thw,
-                video_grid_thw=video_grid_thw,
-                second_per_grid_ts=second_per_grid_ts,
-                audio_feature_lengths=audio_feature_lengths,
-                use_audio_in_video=use_audio_in_video,
-            )
+        if supports_mrope(self.model):
+            req_state.mrope_positions, req_state.mrope_position_delta = \
+                self.model.get_mrope_input_positions(
+                    req_state.prompt_token_ids,
+                    hf_config=self.model_config.hf_config,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    second_per_grid_ts=second_per_grid_ts,
+                    audio_feature_lengths=audio_feature_lengths,
+                    use_audio_in_video=use_audio_in_video,
+                )
+        else:
+            req_state.mrope_positions, req_state.mrope_position_delta = \
+                MRotaryEmbedding.get_input_positions_tensor(
+                    req_state.prompt_token_ids,
+                    hf_config=self.model_config.hf_config,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    second_per_grid_ts=second_per_grid_ts,
+                    audio_feature_lengths=audio_feature_lengths,
+                    use_audio_in_video=use_audio_in_video,
+                )

     def _extract_mm_kwargs(
         self,
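
Side note (illustrative sketch, not from the diff): the direct call on self.model is well-typed here because supports_mrope is declared with TypeIs, so a type checker narrows self.model to SupportsMRoPE inside the if branch, while models without the capability fall back to the generic MRotaryEmbedding path at runtime. The names HasHook, has_hook, and run below are hypothetical and only mirror the pattern.

from typing import Protocol, runtime_checkable

from typing_extensions import TypeIs


@runtime_checkable
class HasHook(Protocol):
    def hook(self) -> str: ...


def has_hook(obj: object) -> TypeIs[HasHook]:
    return isinstance(obj, HasHook)


class ModelWithHook:
    def hook(self) -> str:
        return "model-specific path"


def run(model: object) -> str:
    if has_hook(model):
        # A type checker treats `model` as HasHook here, so the call is well-typed.
        return model.hook()
    # Models without the capability take the generic fallback path.
    return "generic fallback"


print(run(ModelWithHook()))  # model-specific path
print(run(object()))         # generic fallback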

vllm/worker/model_runner.py
Lines changed: 29 additions & 13 deletions

@@ -41,7 +41,8 @@
                                               get_sampler)
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-from vllm.model_executor.models import supports_lora, supports_multimodal
+from vllm.model_executor.models import (supports_lora, supports_mrope,
+                                        supports_multimodal)
 from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs, MultiModalPlaceholderMap,
@@ -670,18 +671,33 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
             inter_data.seq_ids[seq_idx]]
         token_ids = seq_data.get_token_ids()

-        mrope_input_positions, mrope_position_delta = \
-            MRotaryEmbedding.get_input_positions(
-                token_ids,
-                hf_config=hf_config,
-                image_grid_thw=image_grid_thw,
-                video_grid_thw=video_grid_thw,
-                second_per_grid_ts=second_per_grid_ts,
-                context_len=inter_data.context_lens[seq_idx],
-                seq_len=inter_data.seq_lens[seq_idx],
-                audio_feature_lengths=audio_feature_lengths,
-                use_audio_in_video=use_audio_in_video,
-            )
+        if supports_mrope(self.runner.model):
+            mrope_input_positions, mrope_position_delta = \
+                self.runner.model.get_mrope_input_positions(
+                    token_ids,
+                    hf_config=hf_config,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    second_per_grid_ts=second_per_grid_ts,
+                    context_len=inter_data.context_lens[seq_idx],
+                    seq_len=inter_data.seq_lens[seq_idx],
+                    audio_feature_lengths=audio_feature_lengths,
+                    use_audio_in_video=use_audio_in_video,
+                )
+            mrope_input_positions = mrope_input_positions.tolist()
+        else:
+            mrope_input_positions, mrope_position_delta = \
+                MRotaryEmbedding.get_input_positions(
+                    token_ids,
+                    hf_config=hf_config,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    second_per_grid_ts=second_per_grid_ts,
+                    context_len=inter_data.context_lens[seq_idx],
+                    seq_len=inter_data.seq_lens[seq_idx],
+                    audio_feature_lengths=audio_feature_lengths,
+                    use_audio_in_video=use_audio_in_video,
+                )

         seq_data.mrope_position_delta = mrope_position_delta
         inter_data.mrope_input_positions[
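
Note on the extra .tolist() in the new branch (illustration, not from the diff): the interface method returns positions as a [3, num_tokens] tensor, while this runner stores them as nested Python lists; the diff adds the conversion only on the new branch, which suggests the generic MRotaryEmbedding.get_input_positions path already yields lists, so downstream code stays unchanged.

import torch

# [3, num_tokens] tensor as returned by get_mrope_input_positions(...)
llm_positions = torch.arange(4).view(1, -1).expand(3, -1)
print(llm_positions.tolist())  # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]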
