@@ -804,8 +804,13 @@ def compute_attn_mask_seqlen(
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: torch.Tensor,
+        grid_thw: Union[torch.Tensor, list[list[int]]],
     ) -> torch.Tensor:
+        # Convert grid_thw to tensor if it's a list (for compatibility with
+        # run_dp_sharded_mrope_vision_model)
+        if isinstance(grid_thw, list):
+            grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
+
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
         x = self.patch_embed(x)
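This guard lets the vision tower's `forward` accept `grid_thw` either as a tensor or as a nested list. A minimal, self-contained sketch of the same normalization pattern (the standalone function name is illustrative, not part of the patch):

```python
from typing import Union

import torch


def normalize_grid_thw(
    grid_thw: Union[torch.Tensor, list[list[int]]],
    device: torch.device,
) -> torch.Tensor:
    # Coerce a (num_items, 3) list of [t, h, w] triples into a long tensor,
    # mirroring the isinstance guard added to forward() above.
    if isinstance(grid_thw, list):
        return torch.tensor(grid_thw, device=device, dtype=torch.long)
    return grid_thw.to(device=device, dtype=torch.long)


# Two items: a 1x32x32 image grid and a 2x16x16 video grid.
print(normalize_grid_thw([[1, 32, 32], [2, 16, 16]], torch.device("cpu")))
```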
@@ -1467,10 +1472,11 @@ def _process_image_input(
                 # run_dp_sharded_mrope_vision_model already
                 # returns split embeddings
                 return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw)
+                    self.visual, pixel_values, grid_thw.tolist(),
+                    rope_type="rope_3d")
             else:
-                # Non-data parallel mode: self.visual expects tensor format
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+                # Non-data parallel mode: pass list format for consistency
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist())
                 merge_size = self.visual.spatial_merge_size
                 sizes = grid_thw.prod(-1) // merge_size // merge_size
                 return image_embeds.split(sizes.tolist())
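The unchanged tail of this branch computes how many merged embeddings each image contributes: the `t * h * w` patch count shrinks by `spatial_merge_size` once per spatial axis. A quick numeric check of that arithmetic, with illustrative grid values:

```python
import torch

grid_thw = torch.tensor([[1, 32, 32], [1, 16, 24]])  # per-image [t, h, w]
merge_size = 2  # spatial_merge_size: each 2x2 patch block -> one embedding

sizes = grid_thw.prod(-1) // merge_size // merge_size
print(sizes.tolist())  # [256, 96]

# Splitting a concatenated embedding tensor back into per-image chunks.
image_embeds = torch.randn(int(sizes.sum()), 8)  # hidden size 8 for brevity
print([chunk.shape[0] for chunk in image_embeds.split(sizes.tolist())])  # [256, 96]
```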
@@ -1493,11 +1499,12 @@ def _process_video_input(
                 # run_dp_sharded_mrope_vision_model already
                 # returns split embeddings
                 return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values_videos, grid_thw)
+                    self.visual, pixel_values_videos, grid_thw.tolist(),
+                    rope_type="rope_3d")
             else:
-                # Non-data parallel mode: self.visual expects tensor format
+                # Non-data parallel mode: pass list format for consistency
                 video_embeds = self.visual(pixel_values_videos,
-                                           grid_thw=grid_thw)
+                                           grid_thw=grid_thw.tolist())
                 # Split concatenated embeddings for each video item.
                 merge_size = self.visual.spatial_merge_size
                 sizes = grid_thw.prod(-1) // merge_size // merge_size
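After this change, the image and video call sites follow the same dispatch shape. A condensed sketch of that pattern as a free function (the `use_data_parallel` flag is an assumption standing in for the guard elided from the hunk, and `run_dp_sharded_mrope_vision_model` is assumed imported as in the patched module):

```python
import torch


def encode_mm_items(visual, pixel_values: torch.Tensor,
                    grid_thw: torch.Tensor,
                    use_data_parallel: bool) -> tuple[torch.Tensor, ...]:
    if use_data_parallel:
        # DP path: the helper shards items across ranks and already
        # returns one embedding tensor per item. It takes grid_thw as a
        # plain list, hence the .tolist() conversions in this patch.
        return run_dp_sharded_mrope_vision_model(  # noqa: F821 (import elided)
            visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d")
    # Non-DP path: encode all items at once, then split per item.
    embeds = visual(pixel_values, grid_thw=grid_thw.tolist())
    merge_size = visual.spatial_merge_size
    sizes = grid_thw.prod(-1) // merge_size // merge_size
    return embeds.split(sizes.tolist())
```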