
Commit d93f53c

lgeiger authored and Dhruvilbhatt committed
[Models][Qwen3VL] Speedup fast_pos_embed_interpolate (vllm-project#26647)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
1 parent 18ff3dc commit d93f53c

File tree

1 file changed (+11 −18 lines)

vllm/model_executor/models/qwen3_vl.py

Lines changed: 11 additions & 18 deletions
@@ -467,8 +467,6 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
         dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
         h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
         h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")
-        h_floor_grid_idx = h_floor_grid * num_grid_per_side
-        h_ceil_grid_idx = h_ceil_grid * num_grid_per_side
 
         # original computation of weights
         # w00 = (1 - dh_grid) * (1 - dw_grid)
@@ -480,30 +478,25 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
         w11 = dh_grid * dw_grid
         w10 = dh_grid - w11
         w01 = dw_grid - w11
-        w00 = 1 - dh_grid - dw_grid + w11
+        w00 = 1 - dh_grid - w01
 
-        idx00 = h_floor_grid_idx + w_floor_grid
-        idx01 = h_floor_grid_idx + w_ceil_grid
-        idx10 = h_ceil_grid_idx + w_floor_grid
-        idx11 = h_ceil_grid_idx + w_ceil_grid
+        h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
+        w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
+        h_grid_idx = h_grid * num_grid_per_side
 
-        indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1)
+        indices = (h_grid_idx + w_grid).reshape(4, -1)
         weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
-        weights = weights.to(
-            dtype=self.dtype, device=self.device, non_blocking=True
-        )
+        weights = weights.to(dtype=self.dtype)
 
         embeds = self.pos_embed(indices)
         weighted_embeds = embeds * weights
-        p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
-        combined = p0 + p1 + p2 + p3
+        combined = weighted_embeds.sum(dim=0)
 
-        combined = combined.view(h * w, hidden_dim)
-        repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
-        repeated = repeated.view(
-            t, h // m_size, m_size, w // m_size, m_size, hidden_dim
+        combined = combined.reshape(
+            h // m_size, m_size, w // m_size, m_size, hidden_dim
         )
-        repeated = repeated.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim)
+        combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
+        repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
         outputs.append(repeated)
 
         return torch.cat(outputs, dim=0)
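
The index and weight refactor is easy to sanity-check in isolation. Below is a minimal standalone sketch of that check; the tensor sizes and the num_grid_per_side value are made up for illustration and are not taken from the model. It verifies that the stacked index computation and the rewritten w00 match the original four-grid formulation:

import torch

# Illustrative sizes only; num_grid_per_side stands in for the model's
# positional-embedding grid width.
num_grid_per_side = 16
h, w = 7, 5

dh = torch.rand(h)
dw = torch.rand(w)
h_floor = torch.randint(0, num_grid_per_side - 1, (h,))
w_floor = torch.randint(0, num_grid_per_side - 1, (w,))
h_ceil, w_ceil = h_floor + 1, w_floor + 1

dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")

# Old path: four explicit index grids, stacked at the end.
h_floor_grid_idx = h_floor_grid * num_grid_per_side
h_ceil_grid_idx = h_ceil_grid * num_grid_per_side
idx00 = h_floor_grid_idx + w_floor_grid
idx01 = h_floor_grid_idx + w_ceil_grid
idx10 = h_ceil_grid_idx + w_floor_grid
idx11 = h_ceil_grid_idx + w_ceil_grid
indices_old = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1)

# New path: one stack, one multiply, one add.
h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
indices_new = (h_grid * num_grid_per_side + w_grid).reshape(4, -1)
assert torch.equal(indices_old, indices_new)

# Weight algebra: (1 - dh)(1 - dw) = 1 - dh - dw + dh*dw
#                                  = 1 - dh - (dw - dh*dw) = 1 - dh - w01.
w11 = dh_grid * dw_grid
w01 = dw_grid - w11
w00_old = 1 - dh_grid - dw_grid + w11
w00_new = 1 - dh_grid - w01
assert torch.allclose(w00_old, w00_new)

The new w00 reuses the already-computed w01, saving one elementwise op, and the stacked index form replaces four separate grid additions with a single batched multiply-add.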

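The tail of the function is reordered so the permute happens once on a single (h, w) grid and the expand over t comes last, which removes the .contiguous() copy of t replicas. A small equivalence sketch under assumed toy shapes (all names and sizes here are illustrative, not from the repo):

import torch

t, h, w, m_size, hidden_dim = 2, 4, 6, 2, 8
combined = torch.randn(h * w, hidden_dim)

# Old path: expand over t first, materialize with .contiguous(),
# then permute a 6-D view.
old = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
old = old.view(t, h // m_size, m_size, w // m_size, m_size, hidden_dim)
old = old.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim)

# New path: permute a 5-D view of the single grid once,
# expand over t last so only the final reshape copies.
new = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim)
new = new.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
new = new.expand(t, -1, -1).reshape(-1, hidden_dim)

assert torch.equal(old, new)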
Comments (0)