@@ -32,7 +32,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from transformers import AutoConfig, BatchFeature
+from transformers import AutoConfig, BatchFeature, PretrainedConfig
 from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                           Qwen2VLProcessor)
 from transformers.models.qwen2_vl.configuration_qwen2_vl import (
@@ -73,7 +73,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix,
@@ -1096,7 +1096,7 @@ def _get_mm_fields_config(
                                         info=Qwen2VLProcessingInfo,
                                         dummy_inputs=Qwen2VLDummyInputsBuilder)
 class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                      SupportsLoRA, SupportsPP):
+                                      SupportsLoRA, SupportsPP, SupportsMRoPE):
 
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -1109,6 +1109,118 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         "model.": "language_model.model.",
     })
 
+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        second_per_grid_ts: Optional[list[float]] = None,
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+        audio_feature_lengths: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """Get M-RoPE input positions for Qwen2-VL model."""
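+        # Treat missing multimodal metadata as empty inputs.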
+        if image_grid_thw is None:
+            image_grid_thw = []
+        if video_grid_thw is None:
+            video_grid_thw = []
+        if second_per_grid_ts is None:
+            second_per_grid_ts = []
+
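+        # Special token ids and vision grid parameters from the HF config.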
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        vision_start_token_id = hf_config.vision_start_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+        tokens_per_second = getattr(hf_config.vision_config,
+                                    "tokens_per_second", 1.0)
+
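+        # Each vision segment is announced by a vision_start token; the token
+        # right after it tells whether the segment is an image or a video.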
+        input_tokens_tensor = torch.tensor(input_tokens)
+        vision_start_indices = torch.argwhere(
+            input_tokens_tensor == vision_start_token_id).squeeze(1)
+        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
+        image_nums = (vision_tokens == image_token_id).sum()
+        video_nums = (vision_tokens == video_token_id).sum()
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_videos = image_nums, video_nums
+
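+        # Walk through the prompt, emitting shared 1-D positions for text
+        # spans and 3-D (t, h, w) positions for each vision segment.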
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + video_nums):
+            video_second_per_grid_t = 0.0
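+            # Find the next image / video placeholder token at or after st.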
+            if remain_images > 0:
+                try:
+                    ed_image = input_tokens.index(image_token_id, st)
+                except ValueError:
+                    ed_image = len(input_tokens) + 1
+            else:
+                ed_image = len(input_tokens) + 1
+            if remain_videos > 0:
+                try:
+                    ed_video = input_tokens.index(video_token_id, st)
+                except ValueError:
+                    ed_video = len(input_tokens) + 1
+            else:
+                ed_video = len(input_tokens) + 1
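+            # Whichever placeholder comes first decides the modality of this
+            # segment and supplies its (t, h, w) grid.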
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_second_per_grid_t = 1.0
+                if second_per_grid_ts:
+                    video_second_per_grid_t = second_per_grid_ts[video_index]
+                video_index += 1
+                remain_videos -= 1
+                ed = ed_video
+
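+            # Spatial dimensions are downsampled by spatial_merge_size before
+            # the vision tokens reach the language model.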
+            llm_grid_t, llm_grid_h, llm_grid_w = \
+                t, h // spatial_merge_size, w // spatial_merge_size
+            text_len = ed - st
+
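+            # Text tokens between st and ed share the same position on all
+            # three rotary axes.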
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
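+            # Temporal index, scaled for videos by seconds-per-grid and
+            # tokens_per_second; h/w indices enumerate the merged grid.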
+            t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
+                -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
+                       tokens_per_second).long().flatten()
+
+            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                llm_grid_t, -1, llm_grid_w).flatten()
+            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                llm_grid_t, llm_grid_h, -1).flatten()
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
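+        # Any text remaining after the last vision segment.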
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
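+        # Concatenate all spans; the delta offsets position ids for tokens
+        # generated after the prompt.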
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+
+        return llm_positions, mrope_position_delta
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):