@@ -467,8 +467,6 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
467467            dh_grid , dw_grid  =  torch .meshgrid (dh , dw , indexing = "ij" )
468468            h_floor_grid , w_floor_grid  =  torch .meshgrid (h_floor , w_floor , indexing = "ij" )
469469            h_ceil_grid , w_ceil_grid  =  torch .meshgrid (h_ceil , w_ceil , indexing = "ij" )
470-             h_floor_grid_idx  =  h_floor_grid  *  num_grid_per_side 
471-             h_ceil_grid_idx  =  h_ceil_grid  *  num_grid_per_side 
472470
473471            # original computation of weights 
474472            # w00 = (1 - dh_grid) * (1 - dw_grid) 
@@ -480,30 +478,25 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
480478            w11  =  dh_grid  *  dw_grid 
481479            w10  =  dh_grid  -  w11 
482480            w01  =  dw_grid  -  w11 
483-             w00  =  1  -  dh_grid  -  dw_grid   +   w11 
481+             w00  =  1  -  dh_grid  -  w01 
484482
485-             idx00  =  h_floor_grid_idx  +  w_floor_grid 
486-             idx01  =  h_floor_grid_idx  +  w_ceil_grid 
487-             idx10  =  h_ceil_grid_idx  +  w_floor_grid 
488-             idx11  =  h_ceil_grid_idx  +  w_ceil_grid 
483+             h_grid  =  torch .stack ([h_floor_grid , h_floor_grid , h_ceil_grid , h_ceil_grid ])
484+             w_grid  =  torch .stack ([w_floor_grid , w_ceil_grid , w_floor_grid , w_ceil_grid ])
485+             h_grid_idx  =  h_grid  *  num_grid_per_side 
489486
490-             indices  =  torch . stack ([ idx00 ,  idx01 ,  idx10 ,  idx11 ],  dim = 0 ).reshape (4 , - 1 )
487+             indices  =  ( h_grid_idx   +   w_grid ).reshape (4 , - 1 )
491488            weights  =  torch .stack ([w00 , w01 , w10 , w11 ], dim = 0 ).reshape (4 , - 1 , 1 )
492-             weights  =  weights .to (
493-                 dtype = self .dtype , device = self .device , non_blocking = True 
494-             )
489+             weights  =  weights .to (dtype = self .dtype )
495490
496491            embeds  =  self .pos_embed (indices )
497492            weighted_embeds  =  embeds  *  weights 
498-             p0 , p1 , p2 , p3  =  weighted_embeds .unbind (dim = 0 )
499-             combined  =  p0  +  p1  +  p2  +  p3 
493+             combined  =  weighted_embeds .sum (dim = 0 )
500494
501-             combined  =  combined .view (h  *  w , hidden_dim )
502-             repeated  =  combined .unsqueeze (0 ).expand (t , - 1 , - 1 ).contiguous ()
503-             repeated  =  repeated .view (
504-                 t , h  //  m_size , m_size , w  //  m_size , m_size , hidden_dim 
495+             combined  =  combined .reshape (
496+                 h  //  m_size , m_size , w  //  m_size , m_size , hidden_dim 
505497            )
506-             repeated  =  repeated .permute (0 , 1 , 3 , 2 , 4 , 5 ).reshape (- 1 , hidden_dim )
498+             combined  =  combined .permute (0 , 2 , 1 , 3 , 4 ).reshape (1 , - 1 , hidden_dim )
499+             repeated  =  combined .expand (t , - 1 , - 1 ).reshape (- 1 , hidden_dim )
507500            outputs .append (repeated )
508501
509502        return  torch .cat (outputs , dim = 0 )
0 commit comments