@@ -467,8 +467,6 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
467467 dh_grid , dw_grid = torch .meshgrid (dh , dw , indexing = "ij" )
468468 h_floor_grid , w_floor_grid = torch .meshgrid (h_floor , w_floor , indexing = "ij" )
469469 h_ceil_grid , w_ceil_grid = torch .meshgrid (h_ceil , w_ceil , indexing = "ij" )
470- h_floor_grid_idx = h_floor_grid * num_grid_per_side
471- h_ceil_grid_idx = h_ceil_grid * num_grid_per_side
472470
473471 # original computation of weights
474472 # w00 = (1 - dh_grid) * (1 - dw_grid)
@@ -480,30 +478,25 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
480478 w11 = dh_grid * dw_grid
481479 w10 = dh_grid - w11
482480 w01 = dw_grid - w11
483- w00 = 1 - dh_grid - dw_grid + w11
481+ w00 = 1 - dh_grid - w01
484482
485- idx00 = h_floor_grid_idx + w_floor_grid
486- idx01 = h_floor_grid_idx + w_ceil_grid
487- idx10 = h_ceil_grid_idx + w_floor_grid
488- idx11 = h_ceil_grid_idx + w_ceil_grid
483+ h_grid = torch .stack ([h_floor_grid , h_floor_grid , h_ceil_grid , h_ceil_grid ])
484+ w_grid = torch .stack ([w_floor_grid , w_ceil_grid , w_floor_grid , w_ceil_grid ])
485+ h_grid_idx = h_grid * num_grid_per_side
489486
490- indices = torch . stack ([ idx00 , idx01 , idx10 , idx11 ], dim = 0 ).reshape (4 , - 1 )
487+ indices = ( h_grid_idx + w_grid ).reshape (4 , - 1 )
491488 weights = torch .stack ([w00 , w01 , w10 , w11 ], dim = 0 ).reshape (4 , - 1 , 1 )
492- weights = weights .to (
493- dtype = self .dtype , device = self .device , non_blocking = True
494- )
489+ weights = weights .to (dtype = self .dtype )
495490
496491 embeds = self .pos_embed (indices )
497492 weighted_embeds = embeds * weights
498- p0 , p1 , p2 , p3 = weighted_embeds .unbind (dim = 0 )
499- combined = p0 + p1 + p2 + p3
493+ combined = weighted_embeds .sum (dim = 0 )
500494
501- combined = combined .view (h * w , hidden_dim )
502- repeated = combined .unsqueeze (0 ).expand (t , - 1 , - 1 ).contiguous ()
503- repeated = repeated .view (
504- t , h // m_size , m_size , w // m_size , m_size , hidden_dim
495+ combined = combined .reshape (
496+ h // m_size , m_size , w // m_size , m_size , hidden_dim
505497 )
506- repeated = repeated .permute (0 , 1 , 3 , 2 , 4 , 5 ).reshape (- 1 , hidden_dim )
498+ combined = combined .permute (0 , 2 , 1 , 3 , 4 ).reshape (1 , - 1 , hidden_dim )
499+ repeated = combined .expand (t , - 1 , - 1 ).reshape (- 1 , hidden_dim )
507500 outputs .append (repeated )
508501
509502 return torch .cat (outputs , dim = 0 )
0 commit comments