@@ -177,6 +177,18 @@ def triton_mrope(
     return q, k
 
 
+def apply_interleaved_rope(x: torch.Tensor,
+                           mrope_section: list[int]) -> torch.Tensor:
+    """Apply interleaved MRoPE to 3D rotary embeddings.
+    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+    interleaved [THTHWHTHW...TT], preserving frequency continuity.
+    """
+    x_t = x[0].clone()
+    x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
+    x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
+    return x_t
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""
 
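A quick sketch of what the new helper does, using made-up values (the section sizes and tensors below are illustrative, and apply_interleaved_rope is assumed to be in scope from the hunk above): starting from the temporal channel, every third frequency at offsets 1 and 2 is overwritten with the height and width channels, so the per-frequency order becomes T, H, W, T, H, W, ... with any tail left as T.

import torch

mrope_section = [2, 2, 2]      # hypothetical T/H/W sections, sum = rotary_dim // 2 = 6
cos = torch.stack([
    torch.full((1, 6), 0.0),   # channel computed from temporal positions
    torch.full((1, 6), 1.0),   # channel computed from height positions
    torch.full((1, 6), 2.0),   # channel computed from width positions
])
print(apply_interleaved_rope(cos, mrope_section))
# tensor([[0., 1., 2., 0., 1., 2.]])  -> T, H, W repeating per frequency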
@@ -189,6 +201,7 @@ def __init__(
         is_neox_style: bool,
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
+        mrope_interleaved: Optional[bool] = False,
     ) -> None:
         # In Qwen2.5-VL, the maximum index value is related to the duration of
         # the input video. We enlarge max_position_embeddings to 4 times to get
@@ -198,6 +211,7 @@ def __init__(
                          base, is_neox_style, dtype)
 
         self.mrope_section = mrope_section
+        self.mrope_interleaved = mrope_interleaved
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
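As a quick numeric check of the assertion above (the values are illustrative, not taken from any particular checkpoint): the three mrope_section entries count temporal, height, and width frequencies and must cover exactly half of rotary_dim.

rotary_dim = 128               # illustrative rotary dimension
mrope_section = [24, 20, 20]   # hypothetical T/H/W frequency counts
assert sum(mrope_section) == rotary_dim // 2   # 24 + 20 + 20 == 64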
@@ -225,17 +239,20 @@ def forward_native(
         cos, sin = cos_sin.chunk(2, dim=-1)
         if positions.ndim == 2:
             assert self.mrope_section
-
-            cos = torch.cat([
-                m[i]
-                for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
-            ],
-                            dim=-1)
-            sin = torch.cat([
-                m[i]
-                for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
-            ],
-                            dim=-1)
+            if self.mrope_interleaved:
+                cos = apply_interleaved_rope(cos, self.mrope_section)
+                sin = apply_interleaved_rope(sin, self.mrope_section)
+            else:
+                cos = torch.cat([
+                    m[i] for i, m in enumerate(
+                        cos.split(self.mrope_section, dim=-1))
+                ],
+                                dim=-1)
+                sin = torch.cat([
+                    m[i] for i, m in enumerate(
+                        sin.split(self.mrope_section, dim=-1))
+                ],
+                                dim=-1)
 
         query_shape = query.shape
         query = query.view(num_tokens, -1, self.head_size)
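For contrast with the interleaved helper, a toy run of the unchanged chunked path (same made-up tensors as in the earlier sketch): each split section is taken from its matching 3D channel, yielding a [TT.. HH.. WW..] layout along the last dimension.

import torch

mrope_section = [2, 2, 2]      # hypothetical section sizes
cos = torch.stack([
    torch.full((1, 6), 0.0),   # temporal channel
    torch.full((1, 6), 1.0),   # height channel
    torch.full((1, 6), 2.0),   # width channel
])
chunked = torch.cat(
    [m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
print(chunked)  # tensor([[0., 0., 1., 1., 2., 2.]])  -> [TT HH WW] layout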
@@ -265,6 +282,10 @@ def forward_cuda(
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None
 
+        if self.mrope_interleaved:
+            # TODO: add triton implementation to support mrope-interleaved
+            return self.forward_native(positions, query, key)
+
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -388,6 +409,15 @@ def get_input_positions_tensor(
                 context_len=context_len,
                 seq_len=seq_len,
             )
+        elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
+            return cls._qwen3vl_get_input_positions_tensor(
+                input_tokens=input_tokens,
+                hf_config=hf_config,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                context_len=context_len,
+                seq_len=seq_len,
+            )
         elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
             return cls._ernie_get_input_positions_tensor(
                 input_tokens=input_tokens,
@@ -526,6 +556,98 @@ def _glm4v_get_input_positions_tensor(
                                 len(input_tokens)).item()
         return llm_positions, mrope_position_delta
 
+    @classmethod
+    def _qwen3vl_get_input_positions_tensor(
+        cls,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+    ) -> tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value."""
+
+        video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
+                          for _ in range(t)]
+
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        vision_start_token_id = hf_config.vision_start_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+
+        input_tokens_tensor = torch.tensor(input_tokens)
+        vision_start_indices = torch.argwhere(
+            input_tokens_tensor == vision_start_token_id).squeeze(1)
+        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
+        image_nums = (vision_tokens == image_token_id).sum()
+        video_nums = (vision_tokens == video_token_id).sum()
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_videos = image_nums, video_nums
+
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + video_nums):
+            if image_token_id in input_tokens and remain_images > 0:
+                ed_image = input_tokens.index(image_token_id, st)
+            else:
+                ed_image = len(input_tokens) + 1
+            if video_token_id in input_tokens and remain_videos > 0:
+                ed_video = input_tokens.index(video_token_id, st)
+            else:
+                ed_video = len(input_tokens) + 1
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_index += 1
+                remain_videos -= 1
+                ed = ed_video
+
+            llm_grid_t, llm_grid_h, llm_grid_w = \
+                t, h // spatial_merge_size, w // spatial_merge_size
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
+                -1, llm_grid_h * llm_grid_w).flatten()
+            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                llm_grid_t, -1, llm_grid_w).flatten()
+            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                llm_grid_t, llm_grid_h, -1).flatten()
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+        return llm_positions, mrope_position_delta
+
     @classmethod
     def _ernie_get_input_positions_tensor(
         cls,
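To make the new position-id construction concrete, here is a small worked example. The token ids, config object, and grid are made up for illustration (0 = text, 1 = image placeholder, 3 = vision start), and the classmethod is assumed to be reachable on MRotaryEmbedding as added above. A 1x4x4 image grid with spatial_merge_size 2 collapses to a 1x2x2 block of four placeholder tokens.

from types import SimpleNamespace

cfg = SimpleNamespace(
    image_token_id=1, video_token_id=2, vision_start_token_id=3,
    vision_config=SimpleNamespace(spatial_merge_size=2))
# [text, text, vision_start, 4 image placeholders, text, text]
tokens = [0, 0, 3, 1, 1, 1, 1, 0, 0]
pos, delta = MRotaryEmbedding._qwen3vl_get_input_positions_tensor(
    input_tokens=tokens,
    hf_config=cfg,
    image_grid_thw=[[1, 4, 4]],
    video_grid_thw=[],
    context_len=0,
    seq_len=len(tokens))
# pos[0] (t): [0, 1, 2, 3, 3, 3, 3, 5, 6]
# pos[1] (h): [0, 1, 2, 3, 3, 4, 4, 5, 6]
# pos[2] (w): [0, 1, 2, 3, 4, 3, 4, 5, 6]
# delta == -2  (max position + 1 - sequence length)

Note how the trailing text resumes at max(position) + 1 = 5 rather than at index 7; that gap is what mrope_position_delta captures for later decode steps.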