@@ -270,6 +270,7 @@ def __init__(
         self.temporal_patch_size = vision_config.temporal_patch_size
         self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
         self.use_data_parallel = use_data_parallel
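+        # The learned position embeddings form a square grid;
+        # cache its side length once here.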
+        self.num_grid_per_side = int(self.num_position_embeddings**0.5)
 
         # NOTE: This is used for creating empty tensor for all_gather for
         # DP ViT. Here out_hidden_size is enlarged due to deepstack
@@ -377,82 +378,68 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
-    def fast_pos_embed_interpolate(self, grid_thw):
-        num_grid_per_side = int(self.num_position_embeddings**0.5)
+    def fast_pos_embed_interpolate(self,
+                                   grid_thw: list[list[int]]) -> torch.Tensor:
 
-        idx_list = [[] for _ in range(4)]
-        weight_list = [[] for _ in range(4)]
+        num_grid_per_side = self.num_grid_per_side
+        m_size = self.spatial_merge_size
+        hidden_dim = self.pos_embed.embedding_dim
 
+        outputs = []
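+        # Bilinearly interpolate the square position-embedding grid
+        # onto each image's (h, w) patch grid.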
         for t, h, w in grid_thw:
             h_idxs = torch.linspace(0,
                                     num_grid_per_side - 1,
                                     h,
-                                    dtype=torch.float32)
+                                    dtype=torch.float32,
+                                    device=self.device)
             w_idxs = torch.linspace(0,
                                     num_grid_per_side - 1,
                                     w,
-                                    dtype=torch.float32)
-
-            h_idxs_floor = h_idxs.to(torch.long)
-            w_idxs_floor = w_idxs.to(torch.long)
-            h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1,
-                                      max=num_grid_per_side - 1)
-            w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1,
-                                      max=num_grid_per_side - 1)
-
-            dh = h_idxs - h_idxs_floor
-            dw = w_idxs - w_idxs_floor
-
-            idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T +
-                                w_idxs_floor[None]).flatten().tolist() * t)
-            idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T +
-                                w_idxs_ceil[None]).flatten().tolist() * t)
-            idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
-                                w_idxs_floor[None]).flatten().tolist() * t)
-            idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
-                                w_idxs_ceil[None]).flatten().tolist() * t)
-
-            weight_list[0].extend(
-                ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t)
-            weight_list[1].extend(
-                ((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
-            weight_list[2].extend(
-                (dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
-            weight_list[3].extend(
-                (dh[None].T * dw[None]).flatten().tolist() * t)
-
-        device = self.pos_embed.weight.device
-        dtype = self.pos_embed.weight.dtype
-
-        p0 = self.pos_embed(
-            torch.tensor(
-                idx_list[0], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[0], dtype=dtype, device=device)[:, None]
-        p1 = self.pos_embed(
-            torch.tensor(
-                idx_list[1], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[1], dtype=dtype, device=device)[:, None]
-        p2 = self.pos_embed(
-            torch.tensor(
-                idx_list[2], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[2], dtype=dtype, device=device)[:, None]
-        p3 = self.pos_embed(
-            torch.tensor(
-                idx_list[3], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[3], dtype=dtype, device=device)[:, None]
-
-        patch_pos_embeds = p0 + p1 + p2 + p3
-        patch_pos_embeds = patch_pos_embeds.split(
-            [t * h * w for t, h, w in grid_thw])
-        patch_pos_embeds_permute = []
-        m_size = self.spatial_merge_size
-        for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
-            pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size,
-                                       m_size, -1).permute(0, 1, 3, 2, 4,
-                                                           5).flatten(0, 4)
-            patch_pos_embeds_permute.append(pos_embed)
-        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
-        return patch_pos_embeds
+                                    dtype=torch.float32,
+                                    device=self.device)
+
+            h_floor = h_idxs.to(torch.long)
+            w_floor = w_idxs.to(torch.long)
+            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
+            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
+
+            dh = h_idxs - h_floor
+            dw = w_idxs - w_floor
+
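+            # Bilinear weights of the four surrounding grid points.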
+            w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
+            w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
+            w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
+            w11 = (dh[:, None] * dw[None, :]).reshape(-1)
+
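+            # Flattened indices of those neighbours in the embedding table.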
+            idx00 = (h_floor[:, None] * num_grid_per_side +
+                     w_floor[None, :]).reshape(-1)
+            idx01 = (h_floor[:, None] * num_grid_per_side +
+                     w_ceil[None, :]).reshape(-1)
+            idx10 = (h_ceil[:, None] * num_grid_per_side +
+                     w_floor[None, :]).reshape(-1)
+            idx11 = (h_ceil[:, None] * num_grid_per_side +
+                     w_ceil[None, :]).reshape(-1)
+
+            indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
+            weights = torch.stack([w00, w01, w10, w11],
+                                  dim=0).to(dtype=self.dtype,
+                                            device=self.device)
+            weights = weights.unsqueeze(-1)
+
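+            # One lookup gathers all four corner embeddings; their weighted
+            # sum is the bilinearly interpolated embedding for each patch.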
+            embeds = self.pos_embed(indices)
+            weighted_embeds = embeds * weights
+            p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
+            combined = p0 + p1 + p2 + p3
+
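+            # Repeat over the t frames and reorder patches into
+            # spatial-merge windows.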
+            combined = combined.view(h * w, hidden_dim)
+            repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
+            repeated = repeated.view(t, h // m_size, m_size, w // m_size,
+                                     m_size, hidden_dim)
+            repeated = repeated.permute(0, 1, 3, 2, 4,
+                                        5).reshape(-1, hidden_dim)
+            outputs.append(repeated)
+
+        return torch.cat(outputs, dim=0)
 
     def compute_attn_mask_seqlen(
         self,
@@ -477,12 +464,9 @@ def forward(
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
-        if isinstance(grid_thw, list):
-            grid_thw_tensor = torch.tensor(grid_thw,
-                                           device=hidden_states.device,
-                                           dtype=torch.int32)
-        else:
-            grid_thw_tensor = grid_thw
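+        # grid_thw arrives as a nested Python list; build the int32 tensor
+        # on the vision tower's device.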
+        grid_thw_tensor = torch.tensor(grid_thw,
+                                       device=self.device,
+                                       dtype=torch.int32)
 
         cu_seqlens = torch.repeat_interleave(
             grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
@@ -1224,7 +1208,8 @@ def _process_image_input(
             grid_thw_list,
             rope_type="rope_3d")
         else:
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+            image_embeds = self.visual(pixel_values,
+                                       grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each image item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
@@ -1526,4 +1511,4 @@ def get_mm_mapping(self) -> MultiModelKeys:
             language_model="language_model",
             connector="model.visual.merger",
             tower_model="model.visual.",
-        )
+        )