@@ -270,6 +270,7 @@ def __init__(
         self.temporal_patch_size = vision_config.temporal_patch_size
         self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
         self.use_data_parallel = use_data_parallel
+        self.num_grid_per_side = int(self.num_position_embeddings**0.5)

         # NOTE: This is used for creating empty tensor for all_gather for
         # DP ViT. Here out_hidden_size is enlarged due to deepstack
@@ -377,82 +378,68 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb

-    def fast_pos_embed_interpolate(self, grid_thw):
-        num_grid_per_side = int(self.num_position_embeddings**0.5)
+    def fast_pos_embed_interpolate(self,
+                                   grid_thw: list[list[int]]) -> torch.Tensor:

-        idx_list = [[] for _ in range(4)]
-        weight_list = [[] for _ in range(4)]
+        num_grid_per_side = self.num_grid_per_side
+        m_size = self.spatial_merge_size
+        hidden_dim = self.pos_embed.embedding_dim

+        outputs = []
         for t, h, w in grid_thw:
             h_idxs = torch.linspace(0,
                                     num_grid_per_side - 1,
                                     h,
-                                    dtype=torch.float32)
+                                    dtype=torch.float32,
+                                    device=self.device)
             w_idxs = torch.linspace(0,
                                     num_grid_per_side - 1,
                                     w,
-                                    dtype=torch.float32)
-
-            h_idxs_floor = h_idxs.to(torch.long)
-            w_idxs_floor = w_idxs.to(torch.long)
-            h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1,
-                                      max=num_grid_per_side - 1)
-            w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1,
-                                      max=num_grid_per_side - 1)
-
-            dh = h_idxs - h_idxs_floor
-            dw = w_idxs - w_idxs_floor
-
-            idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T +
-                                w_idxs_floor[None]).flatten().tolist() * t)
-            idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T +
-                                w_idxs_ceil[None]).flatten().tolist() * t)
-            idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
-                                w_idxs_floor[None]).flatten().tolist() * t)
-            idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
-                                w_idxs_ceil[None]).flatten().tolist() * t)
-
-            weight_list[0].extend(
-                ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t)
-            weight_list[1].extend(
-                ((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
-            weight_list[2].extend(
-                (dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
-            weight_list[3].extend(
-                (dh[None].T * dw[None]).flatten().tolist() * t)
-
-        device = self.pos_embed.weight.device
-        dtype = self.pos_embed.weight.dtype
-
-        p0 = self.pos_embed(
-            torch.tensor(
-                idx_list[0], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[0], dtype=dtype, device=device)[:, None]
-        p1 = self.pos_embed(
-            torch.tensor(
-                idx_list[1], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[1], dtype=dtype, device=device)[:, None]
-        p2 = self.pos_embed(
-            torch.tensor(
-                idx_list[2], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[2], dtype=dtype, device=device)[:, None]
-        p3 = self.pos_embed(
-            torch.tensor(
-                idx_list[3], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[3], dtype=dtype, device=device)[:, None]
-
-        patch_pos_embeds = p0 + p1 + p2 + p3
-        patch_pos_embeds = patch_pos_embeds.split(
-            [t * h * w for t, h, w in grid_thw])
-        patch_pos_embeds_permute = []
-        m_size = self.spatial_merge_size
-        for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
-            pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size,
-                                       m_size, -1).permute(0, 1, 3, 2, 4,
-                                                           5).flatten(0, 4)
-            patch_pos_embeds_permute.append(pos_embed)
-        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
-        return patch_pos_embeds
+                                    dtype=torch.float32,
+                                    device=self.device)
+
+            h_floor = h_idxs.to(torch.long)
+            w_floor = w_idxs.to(torch.long)
+            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
+            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
+
+            dh = h_idxs - h_floor
+            dw = w_idxs - w_floor
+
+            w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
+            w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
+            w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
+            w11 = (dh[:, None] * dw[None, :]).reshape(-1)
+
+            idx00 = (h_floor[:, None] * num_grid_per_side +
+                     w_floor[None, :]).reshape(-1)
+            idx01 = (h_floor[:, None] * num_grid_per_side +
+                     w_ceil[None, :]).reshape(-1)
+            idx10 = (h_ceil[:, None] * num_grid_per_side +
+                     w_floor[None, :]).reshape(-1)
+            idx11 = (h_ceil[:, None] * num_grid_per_side +
+                     w_ceil[None, :]).reshape(-1)
+
+            indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
+            weights = torch.stack([w00, w01, w10, w11],
+                                  dim=0).to(dtype=self.dtype,
+                                            device=self.device)
+            weights = weights.unsqueeze(-1)
+
+            embeds = self.pos_embed(indices)
+            weighted_embeds = embeds * weights
+            p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
+            combined = p0 + p1 + p2 + p3
+
+            combined = combined.view(h * w, hidden_dim)
+            repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
+            repeated = repeated.view(t, h // m_size, m_size, w // m_size,
+                                     m_size, hidden_dim)
+            repeated = repeated.permute(0, 1, 3, 2, 4,
+                                        5).reshape(-1, hidden_dim)
+            outputs.append(repeated)
+
+        return torch.cat(outputs, dim=0)

     def compute_attn_mask_seqlen(
         self,
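
Aside, not part of the patch: the rewritten fast_pos_embed_interpolate above bilinearly resamples a fixed num_grid_per_side x num_grid_per_side table of learned position embeddings onto an arbitrary h x w patch grid, so every target position is a weighted sum of its four nearest table entries. A minimal standalone sketch of that idea follows; the names bilinear_pos_embed and grid_side are illustrative, and the 48 x 48 table size is only an example.

import torch
import torch.nn as nn


def bilinear_pos_embed(pos_embed: nn.Embedding, grid_side: int, h: int,
                       w: int) -> torch.Tensor:
    """Resample a (grid_side * grid_side, dim) table to (h * w, dim)."""
    device = pos_embed.weight.device
    # Fractional coordinates of the h x w target grid on the source grid.
    h_idxs = torch.linspace(0, grid_side - 1, h, device=device)
    w_idxs = torch.linspace(0, grid_side - 1, w, device=device)

    h_floor, w_floor = h_idxs.long(), w_idxs.long()
    h_ceil = (h_floor + 1).clamp(max=grid_side - 1)
    w_ceil = (w_floor + 1).clamp(max=grid_side - 1)
    dh, dw = h_idxs - h_floor, w_idxs - w_floor

    # Flattened indices of the four surrounding corners and their weights.
    idx = torch.stack([
        h_floor[:, None] * grid_side + w_floor[None, :],
        h_floor[:, None] * grid_side + w_ceil[None, :],
        h_ceil[:, None] * grid_side + w_floor[None, :],
        h_ceil[:, None] * grid_side + w_ceil[None, :],
    ]).reshape(4, -1)
    wts = torch.stack([
        (1 - dh)[:, None] * (1 - dw)[None, :],
        (1 - dh)[:, None] * dw[None, :],
        dh[:, None] * (1 - dw)[None, :],
        dh[:, None] * dw[None, :],
    ]).reshape(4, -1, 1)

    # One embedding lookup covers all four corners, then a weighted sum.
    return (pos_embed(idx) * wts.to(pos_embed.weight.dtype)).sum(dim=0)


# Example: a 48 x 48 table (2304 positions) resampled for a 36 x 52 patch grid.
table = nn.Embedding(48 * 48, 64)
print(bilinear_pos_embed(table, 48, 36, 52).shape)  # torch.Size([1872, 64])

Relative to the removed version, the new method keeps indices and weights as tensors per image instead of accumulating Python lists and calling .tolist(), which avoids host-side loops and extra host-to-device transfers.
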
@@ -477,12 +464,9 @@ def forward(
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)

-        if isinstance(grid_thw, list):
-            grid_thw_tensor = torch.tensor(grid_thw,
-                                           device=hidden_states.device,
-                                           dtype=torch.int32)
-        else:
-            grid_thw_tensor = grid_thw
+        grid_thw_tensor = torch.tensor(grid_thw,
+                                       device=self.device,
+                                       dtype=torch.int32)

         cu_seqlens = torch.repeat_interleave(
             grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
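
Aside, not part of the patch: the cu_seqlens call above is cut off by the hunk boundary. In Qwen2/3-VL style vision encoders this pattern typically continues with the per-item frame counts and a cumulative sum padded with a leading zero; the sketch below assumes that continuation, so treat it as illustrative rather than the exact code.

import torch
import torch.nn.functional as F

# Hypothetical grids: one image (1 frame of 4x6 patches), one video (2 frames of 8x8).
grid_thw = torch.tensor([[1, 4, 6], [2, 8, 8]], dtype=torch.int32)

# h * w patches per frame, repeated t times per item, then turned into boundaries.
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
print(cu_seqlens)  # tensor([  0,  24,  88, 152], dtype=torch.int32)

Each consecutive pair of entries delimits one frame's h * w patches in the flattened token sequence, which is what the variable-length attention path consumes.
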
@@ -1224,7 +1208,8 @@ def _process_image_input(
                                                          grid_thw_list,
                                                          rope_type="rope_3d")
             else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+                image_embeds = self.visual(pixel_values,
+                                           grid_thw=grid_thw_list)

         # Split concatenated embeddings for each image item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
@@ -1526,4 +1511,4 @@ def get_mm_mapping(self) -> MultiModelKeys:
             language_model="language_model",
             connector="model.visual.merger",
             tower_model="model.visual.",
-        )
+        )