-
-
Notifications
You must be signed in to change notification settings - Fork 11k
[Refactor]: Use M-RoPE interface directly while defining model class instead of maintaining model specific M-RoPE implementation in mrope.py #24172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
08ae032
5944c15
875cb93
52bf0f0
f111410
302eb16
5d50b52
50da7bf
29c894a
7eb8753
5a88b92
c069204
f70f1d0
24b5db6
79b366a
377cbb0
6bfbe26
680dc58
0849f74
fb27644
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| # limitations under the License. | ||
| """Inference-only Erine VL model compatible with HuggingFace weights.""" | ||
|
|
||
| import itertools | ||
| import math | ||
| from collections.abc import Iterable, Mapping, Sequence | ||
| from functools import partial | ||
|
|
@@ -33,7 +34,7 @@ | |
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
| from einops import rearrange, repeat | ||
| from transformers import BatchFeature | ||
| from transformers import BatchFeature, PretrainedConfig | ||
|
|
||
| from vllm.attention.backends.registry import _Backend | ||
| from vllm.attention.layer import ( | ||
|
|
@@ -76,6 +77,7 @@ | |
| from .interfaces import ( | ||
| MultiModalEmbeddings, | ||
| SupportsLoRA, | ||
| SupportsMRoPE, | ||
| SupportsMultiModal, | ||
| SupportsPP, | ||
| ) | ||
|
|
@@ -1271,7 +1273,7 @@ def get_dummy_mm_data( | |
| dummy_inputs=Ernie4_5_VLDummyInputsBuilder, | ||
| ) | ||
| class Ernie4_5_VLMoeForConditionalGeneration( | ||
| nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP | ||
| nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE | ||
| ): | ||
| merge_by_field_config = True | ||
|
|
||
|
|
@@ -1388,6 +1390,151 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: | |
| else: | ||
| self.visual_token_mask = None | ||
|
|
||
| @classmethod | ||
| def get_mrope_input_positions( | ||
| cls, | ||
| input_tokens: list[int], | ||
| hf_config: PretrainedConfig, | ||
| image_grid_thw: Union[list[list[int]], torch.Tensor], | ||
| video_grid_thw: Union[list[list[int]], torch.Tensor], | ||
| context_len: int = 0, | ||
| seq_len: Optional[int] = None, | ||
| second_per_grid_ts: Optional[list[float]] = None, | ||
| audio_feature_lengths: Optional[torch.Tensor] = None, | ||
| use_audio_in_video: bool = False, | ||
| ) -> tuple[torch.Tensor, int]: | ||
| """Get mrope input positions and delta value for Ernie VL.""" | ||
|
|
||
| image_token_id = hf_config.im_patch_id | ||
| video_start_token_id = hf_config.video_start_token_id | ||
| video_end_token_id = hf_config.video_end_token_id | ||
| spatial_conv_size = hf_config.spatial_conv_size | ||
| temporal_conv_size = hf_config.temporal_conv_size | ||
| llm_pos_ids_list: list = [] | ||
|
|
||
| if not (image_grid_thw is None and video_grid_thw is None): | ||
| if isinstance(image_grid_thw, torch.Tensor): | ||
| image_grid_thw = image_grid_thw.tolist() | ||
|
|
||
| input_token_type: list[str] = [] | ||
| video_check_flg = False | ||
| for token in input_tokens: | ||
| if token == video_start_token_id: | ||
| video_check_flg = True | ||
| elif token == video_end_token_id: | ||
| video_check_flg = False | ||
|
|
||
| if (token == image_token_id) and (video_check_flg is False): | ||
| input_token_type.append("image") | ||
| elif (token == image_token_id) and (video_check_flg is True): | ||
| input_token_type.append("video") | ||
| else: | ||
| input_token_type.append("text") | ||
|
|
||
| input_type_group: list[tuple[str, int, int]] = [] | ||
| for key, group_iter in itertools.groupby( | ||
| enumerate(input_token_type), lambda x: x[1] | ||
| ): | ||
| group_list = list(group_iter) | ||
| start_index = group_list[0][0] | ||
| end_index = group_list[-1][0] + 1 | ||
| input_type_group.append((key, start_index, end_index)) | ||
|
|
||
| video_frame_num = 1 | ||
| mm_data_idx = 0 | ||
| for modality_type, start_idx, end_idx in input_type_group: | ||
| st_idx = ( | ||
| llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 | ||
| ) | ||
| if modality_type == "image": | ||
| t, h, w = ( | ||
| image_grid_thw[mm_data_idx][0], | ||
| image_grid_thw[mm_data_idx][1], | ||
| image_grid_thw[mm_data_idx][2], | ||
| ) | ||
| llm_grid_t, llm_grid_h, llm_grid_w = ( | ||
| t, | ||
| h // spatial_conv_size, | ||
| w // spatial_conv_size, | ||
| ) | ||
|
|
||
| t_index = ( | ||
| torch.arange(llm_grid_t) | ||
| .view(-1, 1) | ||
| .expand(-1, llm_grid_h * llm_grid_w) | ||
| .flatten() | ||
| ) | ||
| h_index = ( | ||
| torch.arange(llm_grid_h) | ||
| .view(1, -1, 1) | ||
| .expand(llm_grid_t, -1, llm_grid_w) | ||
| .flatten() | ||
| ) | ||
| w_index = ( | ||
| torch.arange(llm_grid_w) | ||
| .view(1, 1, -1) | ||
| .expand(llm_grid_t, llm_grid_h, -1) | ||
| .flatten() | ||
| ) | ||
| llm_pos_ids_list.append( | ||
| torch.stack([t_index, h_index, w_index]) + st_idx | ||
| ) | ||
| mm_data_idx += 1 | ||
|
|
||
| elif modality_type == "video": | ||
| t, h, w = ( | ||
| video_grid_thw[mm_data_idx][0], | ||
| video_grid_thw[mm_data_idx][1], | ||
| video_grid_thw[mm_data_idx][2], | ||
| ) | ||
| llm_grid_t, llm_grid_h, llm_grid_w = ( | ||
| t // temporal_conv_size, | ||
| h // spatial_conv_size, | ||
| w // spatial_conv_size, | ||
| ) | ||
|
|
||
| for t_idx in range(llm_grid_t): | ||
| t_index = ( | ||
| torch.tensor(t_idx) | ||
| .view(-1, 1) | ||
| .expand(-1, llm_grid_h * llm_grid_w) | ||
| .flatten() | ||
| ) | ||
| h_index = ( | ||
| torch.arange(llm_grid_h) | ||
| .view(1, -1, 1) | ||
| .expand(1, -1, llm_grid_w) | ||
| .flatten() | ||
| ) | ||
| w_index = ( | ||
| torch.arange(llm_grid_w) | ||
| .view(1, 1, -1) | ||
| .expand(1, llm_grid_h, -1) | ||
| .flatten() | ||
| ) | ||
| llm_pos_ids_list.append( | ||
| torch.stack([t_index, h_index, w_index]) + st_idx | ||
| ) | ||
|
|
||
| mm_data_idx += 1 | ||
| video_frame_num += 1 | ||
|
|
||
| else: | ||
| text_len = end_idx - start_idx | ||
| llm_pos_ids_list.append( | ||
| torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx | ||
| ) | ||
| video_frame_num = 1 | ||
|
|
||
| else: | ||
| text_len = len(input_tokens) | ||
| llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) | ||
|
|
||
| llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) | ||
| llm_positions = llm_positions[:, context_len:seq_len] | ||
| mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() | ||
| return llm_positions, mrope_position_delta | ||
|
|
||
|
Comment on lines
+1393
to
+1537
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This large block of code for |
||
| def get_language_model(self) -> torch.nn.Module: | ||
| return self.language_model | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |||||||||||||||||
| # https://github.com/zai-org/CogAgent | ||||||||||||||||||
| """Inference-only CogAgent model compatible with THUDM weights.""" | ||||||||||||||||||
|
|
||||||||||||||||||
| import itertools | ||||||||||||||||||
| from argparse import Namespace | ||||||||||||||||||
| from collections.abc import Mapping, Sequence | ||||||||||||||||||
| from typing import Annotated, Literal, Optional, Union | ||||||||||||||||||
|
|
@@ -14,7 +15,7 @@ | |||||||||||||||||
| from torch.nn import LayerNorm | ||||||||||||||||||
| from torchvision import transforms | ||||||||||||||||||
| from torchvision.transforms import InterpolationMode | ||||||||||||||||||
| from transformers import BatchFeature, PreTrainedTokenizer, TensorType | ||||||||||||||||||
| from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType | ||||||||||||||||||
| from transformers.image_utils import ImageInput | ||||||||||||||||||
| from transformers.tokenization_utils_base import TextInput | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -54,6 +55,7 @@ | |||||||||||||||||
| from .interfaces import ( | ||||||||||||||||||
| MultiModalEmbeddings, | ||||||||||||||||||
| SupportsLoRA, | ||||||||||||||||||
| SupportsMRoPE, | ||||||||||||||||||
| SupportsMultiModal, | ||||||||||||||||||
| SupportsPP, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
@@ -554,7 +556,9 @@ def get_replacement(item_idx: int): | |||||||||||||||||
| info=GLM4VProcessingInfo, | ||||||||||||||||||
| dummy_inputs=GLM4VDummyInputsBuilder, | ||||||||||||||||||
| ) | ||||||||||||||||||
| class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP): | ||||||||||||||||||
| class GLM4VForCausalLM( | ||||||||||||||||||
| ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE | ||||||||||||||||||
| ): | ||||||||||||||||||
| merge_by_field_config = True | ||||||||||||||||||
|
|
||||||||||||||||||
| packed_modules_mapping = { | ||||||||||||||||||
|
|
@@ -615,6 +619,150 @@ def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tenso | |||||||||||||||||
|
|
||||||||||||||||||
| return self.transformer.vision(pixel_values) | ||||||||||||||||||
|
|
||||||||||||||||||
| @classmethod | ||||||||||||||||||
| def get_mrope_input_positions( | ||||||||||||||||||
| cls, | ||||||||||||||||||
| input_tokens: list[int], | ||||||||||||||||||
| hf_config: PretrainedConfig, | ||||||||||||||||||
| image_grid_thw: Union[list[list[int]], torch.Tensor], | ||||||||||||||||||
| video_grid_thw: Union[list[list[int]], torch.Tensor], | ||||||||||||||||||
| context_len: int = 0, | ||||||||||||||||||
| seq_len: Optional[int] = None, | ||||||||||||||||||
| second_per_grid_ts: Optional[list[float]] = None, | ||||||||||||||||||
| audio_feature_lengths: Optional[torch.Tensor] = None, | ||||||||||||||||||
| use_audio_in_video: bool = False, | ||||||||||||||||||
| ) -> tuple[torch.Tensor, int]: | ||||||||||||||||||
| """Get mrope input positions and delta value for GLM4V.""" | ||||||||||||||||||
|
|
||||||||||||||||||
| image_token_id = hf_config.image_token_id | ||||||||||||||||||
| video_start_token_id = hf_config.video_start_token_id | ||||||||||||||||||
| video_end_token_id = hf_config.video_end_token_id | ||||||||||||||||||
| spatial_merge_size = hf_config.vision_config.spatial_merge_size | ||||||||||||||||||
| llm_pos_ids_list: list = [] | ||||||||||||||||||
|
|
||||||||||||||||||
| if not (image_grid_thw is None and video_grid_thw is None): | ||||||||||||||||||
| if isinstance(image_grid_thw, torch.Tensor): | ||||||||||||||||||
| image_grid_thw = image_grid_thw.tolist() | ||||||||||||||||||
|
|
||||||||||||||||||
| input_token_type: list[str] = [] | ||||||||||||||||||
| video_check_flg = False | ||||||||||||||||||
| for token in input_tokens: | ||||||||||||||||||
| if token == video_start_token_id: | ||||||||||||||||||
| video_check_flg = True | ||||||||||||||||||
| elif token == video_end_token_id: | ||||||||||||||||||
| video_check_flg = False | ||||||||||||||||||
|
|
||||||||||||||||||
| if (token == image_token_id) and (video_check_flg is False): | ||||||||||||||||||
| input_token_type.append("image") | ||||||||||||||||||
| elif (token == image_token_id) and (video_check_flg is True): | ||||||||||||||||||
| input_token_type.append("video") | ||||||||||||||||||
| else: | ||||||||||||||||||
| input_token_type.append("text") | ||||||||||||||||||
|
|
||||||||||||||||||
| input_type_group: list[tuple[str, int, int]] = [] | ||||||||||||||||||
| for key, group_iter in itertools.groupby( | ||||||||||||||||||
| enumerate(input_token_type), lambda x: x[1] | ||||||||||||||||||
| ): | ||||||||||||||||||
| group_list = list(group_iter) | ||||||||||||||||||
| start_index = group_list[0][0] | ||||||||||||||||||
| end_index = group_list[-1][0] + 1 | ||||||||||||||||||
| input_type_group.append((key, start_index, end_index)) | ||||||||||||||||||
|
|
||||||||||||||||||
| video_frame_num = 1 | ||||||||||||||||||
| mm_data_idx = 0 | ||||||||||||||||||
| for modality_type, start_idx, end_idx in input_type_group: | ||||||||||||||||||
| st_idx = ( | ||||||||||||||||||
| llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 | ||||||||||||||||||
| ) | ||||||||||||||||||
| if modality_type == "image": | ||||||||||||||||||
| t, h, w = ( | ||||||||||||||||||
| image_grid_thw[mm_data_idx][0], | ||||||||||||||||||
| image_grid_thw[mm_data_idx][1], | ||||||||||||||||||
| image_grid_thw[mm_data_idx][2], | ||||||||||||||||||
| ) | ||||||||||||||||||
| llm_grid_t, llm_grid_h, llm_grid_w = ( | ||||||||||||||||||
| t, | ||||||||||||||||||
| h // spatial_merge_size, | ||||||||||||||||||
| w // spatial_merge_size, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| t_index = ( | ||||||||||||||||||
| torch.arange(llm_grid_t) | ||||||||||||||||||
| .view(-1, 1) | ||||||||||||||||||
| .expand(-1, llm_grid_h * llm_grid_w) | ||||||||||||||||||
| .flatten() | ||||||||||||||||||
| ) | ||||||||||||||||||
| h_index = ( | ||||||||||||||||||
| torch.arange(llm_grid_h) | ||||||||||||||||||
| .view(1, -1, 1) | ||||||||||||||||||
| .expand(llm_grid_t, -1, llm_grid_w) | ||||||||||||||||||
| .flatten() | ||||||||||||||||||
| ) | ||||||||||||||||||
| w_index = ( | ||||||||||||||||||
| torch.arange(llm_grid_w) | ||||||||||||||||||
| .view(1, 1, -1) | ||||||||||||||||||
| .expand(llm_grid_t, llm_grid_h, -1) | ||||||||||||||||||
| .flatten() | ||||||||||||||||||
| ) | ||||||||||||||||||
| llm_pos_ids_list.append( | ||||||||||||||||||
| torch.stack([t_index, h_index, w_index]) + st_idx | ||||||||||||||||||
| ) | ||||||||||||||||||
| mm_data_idx += 1 | ||||||||||||||||||
|
|
||||||||||||||||||
| elif modality_type == "video": | ||||||||||||||||||
| t, h, w = ( | ||||||||||||||||||
| video_frame_num, | ||||||||||||||||||
| image_grid_thw[mm_data_idx][1], | ||||||||||||||||||
| image_grid_thw[mm_data_idx][2], | ||||||||||||||||||
| ) | ||||||||||||||||||
| llm_grid_t, llm_grid_h, llm_grid_w = ( | ||||||||||||||||||
| t, | ||||||||||||||||||
| h // spatial_merge_size, | ||||||||||||||||||
| w // spatial_merge_size, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| for t_idx in range(llm_grid_t): | ||||||||||||||||||
| t_index = ( | ||||||||||||||||||
| torch.tensor(t_idx) | ||||||||||||||||||
| .view(-1, 1) | ||||||||||||||||||
| .expand(-1, llm_grid_h * llm_grid_w) | ||||||||||||||||||
| .flatten() | ||||||||||||||||||
| ) | ||||||||||||||||||
| h_index = ( | ||||||||||||||||||
| torch.arange(llm_grid_h) | ||||||||||||||||||
| .view(1, -1, 1) | ||||||||||||||||||
| .expand(1, -1, llm_grid_w) | ||||||||||||||||||
| .flatten() | ||||||||||||||||||
| ) | ||||||||||||||||||
| w_index = ( | ||||||||||||||||||
| torch.arange(llm_grid_w) | ||||||||||||||||||
| .view(1, 1, -1) | ||||||||||||||||||
| .expand(1, llm_grid_h, -1) | ||||||||||||||||||
| .flatten() | ||||||||||||||||||
| ) | ||||||||||||||||||
| llm_pos_ids_list.append( | ||||||||||||||||||
| torch.stack([t_index, h_index, w_index]) + st_idx | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| mm_data_idx += 1 | ||||||||||||||||||
| video_frame_num += 1 | ||||||||||||||||||
|
|
||||||||||||||||||
| else: | ||||||||||||||||||
| text_len = end_idx - start_idx | ||||||||||||||||||
| llm_pos_ids_list.append( | ||||||||||||||||||
| torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx | ||||||||||||||||||
| ) | ||||||||||||||||||
| video_frame_num = 1 | ||||||||||||||||||
|
|
||||||||||||||||||
| else: | ||||||||||||||||||
| text_len = len(input_tokens) | ||||||||||||||||||
| llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) | ||||||||||||||||||
|
|
||||||||||||||||||
| llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) | ||||||||||||||||||
| llm_positions = llm_positions[:, context_len:seq_len] | ||||||||||||||||||
| mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() | ||||||||||||||||||
| return llm_positions, mrope_position_delta | ||||||||||||||||||
|
Comment on lines
+761
to
+764
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The calculation of
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| def get_language_model(self) -> torch.nn.Module: | ||||||||||||||||||
| return self.transformer | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The calculation of
mrope_position_deltaappears to be incorrect. It's computed afterllm_positionsis sliced withcontext_lenandseq_len. Ifcontext_len > 0,llm_positions.max()will be smaller than the true maximum position, leading to an incorrect delta. This can cause issues with positional embeddings in subsequent steps. The delta should be calculated from the maximum position of the full sequence before slicing.