Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions vllm/model_executor/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,25 +406,39 @@ def fast_pos_embed_interpolate(self,
dh = h_idxs - h_floor
dw = w_idxs - w_floor

w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
w11 = (dh[:, None] * dw[None, :]).reshape(-1)

idx00 = (h_floor[:, None] * num_grid_per_side +
w_floor[None, :]).reshape(-1)
idx01 = (h_floor[:, None] * num_grid_per_side +
w_ceil[None, :]).reshape(-1)
idx10 = (h_ceil[:, None] * num_grid_per_side +
w_floor[None, :]).reshape(-1)
idx11 = (h_ceil[:, None] * num_grid_per_side +
w_ceil[None, :]).reshape(-1)

indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
# Create meshgrid view for all h, w vars
dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')
h_floor_grid, w_floor_grid = torch.meshgrid(h_floor,
w_floor,
indexing='ij')
h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil,
w_ceil,
indexing='ij')
h_floor_grid_idx = h_floor_grid * num_grid_per_side
h_ceil_grid_idx = h_ceil_grid * num_grid_per_side

# original computation of weights
# w00 = (1 - dh_grid) * (1 - dw_grid)
# w01 = (1 - dh_grid) * dw_grid
# w10 = dh_grid * (1 - dw_grid)
# w11 = dh_grid * dw_grid
# we reuse w11 here to avoid duplicate
# dh_grid * dw_grid computation
w11 = dh_grid * dw_grid
w10 = dh_grid - w11
w01 = dw_grid - w11
w00 = 1 - dh_grid - dw_grid + w11

idx00 = h_floor_grid_idx + w_floor_grid
idx01 = h_floor_grid_idx + w_ceil_grid
idx10 = h_ceil_grid_idx + w_floor_grid
idx11 = h_ceil_grid_idx + w_ceil_grid

indices = torch.stack([idx00, idx01, idx10, idx11],
dim=0).reshape(4, -1)
weights = torch.stack([w00, w01, w10, w11],
dim=0).to(dtype=self.dtype,
device=self.device)
weights = weights.unsqueeze(-1)
dim=0).reshape(4, -1, 1)
weights = weights.to(dtype=self.dtype, device=self.device)

embeds = self.pos_embed(indices)
weighted_embeds = embeds * weights
Expand Down