1,015 changes: 0 additions & 1,015 deletions vllm/model_executor/layers/rotary_embedding/mrope.py

Large diffs are not rendered by default.

151 changes: 149 additions & 2 deletions vllm/model_executor/models/ernie45_vl.py
@@ -23,6 +23,7 @@
# limitations under the License.
"""Inference-only Erine VL model compatible with HuggingFace weights."""

import itertools
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
@@ -33,7 +34,7 @@
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
-from transformers import BatchFeature
+from transformers import BatchFeature, PretrainedConfig

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import (
@@ -76,6 +77,7 @@
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
SupportsPP,
)
@@ -1271,7 +1273,7 @@ def get_dummy_mm_data(
dummy_inputs=Ernie4_5_VLDummyInputsBuilder,
)
class Ernie4_5_VLMoeForConditionalGeneration(
-    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
+    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
merge_by_field_config = True

@@ -1388,6 +1390,151 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
else:
self.visual_token_mask = None

@classmethod
def get_mrope_input_positions(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
second_per_grid_ts: Optional[list[float]] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for Ernie VL."""

image_token_id = hf_config.im_patch_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size
llm_pos_ids_list: list = []

if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()

input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False

if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")

input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))

video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_conv_size,
w // spatial_conv_size,
)

t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1

elif modality_type == "video":
t, h, w = (
video_grid_thw[mm_data_idx][0],
video_grid_thw[mm_data_idx][1],
video_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t // temporal_conv_size,
h // spatial_conv_size,
w // spatial_conv_size,
)

for t_idx in range(llm_grid_t):
t_index = (
torch.tensor(t_idx)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)

mm_data_idx += 1
video_frame_num += 1

else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
video_frame_num = 1

else:
text_len = len(input_tokens)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta
Comment on lines +1533 to +1536
critical

The calculation of mrope_position_delta appears to be incorrect. It's computed after llm_positions is sliced with context_len and seq_len. If context_len > 0, llm_positions.max() will be smaller than the true maximum position, leading to an incorrect delta. This can cause issues with positional embeddings in subsequent steps. The delta should be calculated from the maximum position of the full sequence before slicing.

Suggested change:
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        llm_positions = llm_positions[:, context_len:seq_len]
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        return llm_positions, mrope_position_delta
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+        return llm_positions, mrope_position_delta
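
To make the ordering issue concrete, here is a toy, self-contained illustration (made-up numbers, not taken from this PR). It assumes a text-only prompt, for which the correct delta is 0:

import torch

# Toy text-only prompt of 10 tokens: mrope positions are 0..9 on all three axes.
input_tokens = list(range(10))
llm_positions = torch.arange(len(input_tokens)).view(1, -1).expand(3, -1)
context_len, seq_len = 4, 7  # e.g. a chunked-prefill window

# Ordering in the PR: slice first, then derive the delta from the window only.
window = llm_positions[:, context_len:seq_len]  # positions 4, 5, 6
delta_from_window = (window.max() + 1 - len(input_tokens)).item()  # 7 - 10 = -3

# Suggested ordering: derive the delta from the full sequence, then slice.
delta_from_full = (llm_positions.max() + 1 - len(input_tokens)).item()  # 10 - 10 = 0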


Comment on lines +1393 to +1537
high

This large block of code for get_mrope_input_positions is nearly identical to the implementation in vllm/model_executor/models/glm4v.py. This introduces significant code duplication, making future maintenance more difficult and error-prone. Consider refactoring this logic into a shared helper function or a base class method to consolidate the common parts. The model-specific differences (like config keys and some calculations) could be passed as parameters or handled through a configuration object.
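
One possible shape for such a helper is sketched below (hypothetical name and signature, not code from this PR): the token classification and run grouping are identical in both models, so they could be hoisted out, with each model passing in its own token ids and keeping only its grid-to-(t, h, w) mapping.

import itertools


def group_mm_token_runs(
    input_tokens: list[int],
    *,
    image_token_id: int,
    video_start_token_id: int,
    video_end_token_id: int,
) -> list[tuple[str, int, int]]:
    """Return (modality, start_idx, end_idx) runs for a token stream.

    Hypothetical shared helper: this part is the same in the Ernie and GLM4V
    implementations above; only the per-run grid handling differs.
    """
    token_types: list[str] = []
    in_video = False
    for token in input_tokens:
        if token == video_start_token_id:
            in_video = True
        elif token == video_end_token_id:
            in_video = False
        if token == image_token_id:
            token_types.append("video" if in_video else "image")
        else:
            token_types.append("text")

    runs: list[tuple[str, int, int]] = []
    for key, group in itertools.groupby(enumerate(token_types), key=lambda x: x[1]):
        items = list(group)
        runs.append((key, items[0][0], items[-1][0] + 1))
    return runs

Each model's get_mrope_input_positions would then reduce to iterating over these runs and emitting its own t/h/w position blocks.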

def get_language_model(self) -> torch.nn.Module:
return self.language_model

152 changes: 150 additions & 2 deletions vllm/model_executor/models/glm4v.py
@@ -5,6 +5,7 @@
# https://github.com/zai-org/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""

import itertools
from argparse import Namespace
from collections.abc import Mapping, Sequence
from typing import Annotated, Literal, Optional, Union
@@ -14,7 +15,7 @@
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms import InterpolationMode
-from transformers import BatchFeature, PreTrainedTokenizer, TensorType
+from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput

@@ -54,6 +55,7 @@
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
SupportsPP,
)
@@ -554,7 +556,9 @@ def get_replacement(item_idx: int):
info=GLM4VProcessingInfo,
dummy_inputs=GLM4VDummyInputsBuilder,
)
-class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP):
+class GLM4VForCausalLM(
+    ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
+):
merge_by_field_config = True

packed_modules_mapping = {
@@ -615,6 +619,150 @@ def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tensor:

return self.transformer.vision(pixel_values)

@classmethod
def get_mrope_input_positions(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
second_per_grid_ts: Optional[list[float]] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for GLM4V."""

image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = []

if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()

input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False

if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")

input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))

video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)

t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1

elif modality_type == "video":
t, h, w = (
video_frame_num,
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)

for t_idx in range(llm_grid_t):
t_index = (
torch.tensor(t_idx)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)

mm_data_idx += 1
video_frame_num += 1

else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
video_frame_num = 1

else:
text_len = len(input_tokens)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta
Comment on lines +761 to +764
critical

The calculation of mrope_position_delta appears to be incorrect. It's computed after llm_positions is sliced with context_len and seq_len. If context_len > 0, llm_positions.max() will be smaller than the true maximum position, leading to an incorrect delta. This can cause issues with positional embeddings in subsequent steps. The delta should be calculated from the maximum position of the full sequence before slicing.

Suggested change:
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        llm_positions = llm_positions[:, context_len:seq_len]
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        return llm_positions, mrope_position_delta
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+        return llm_positions, mrope_position_delta
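
If the reordering is applied in both files, a minimal text-only sanity check could look like the following (hypothetical test, not part of this PR): with no images or videos the delta must come out as 0 no matter how the positions are later sliced.

import torch


def test_text_only_mrope_delta_is_zero():
    # For a text-only prompt, positions are 0..n-1 on all three axes.
    n, context_len, seq_len = 12, 5, 9
    positions = torch.arange(n).view(1, -1).expand(3, -1)

    # Delta computed on the full sequence (as in the suggested change),
    # slicing applied afterwards.
    delta = (positions.max() + 1 - n).item()
    window = positions[:, context_len:seq_len]

    assert delta == 0
    assert window.shape == (3, seq_len - context_len)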


def get_language_model(self) -> torch.nn.Module:
return self.transformer
