
Commit 30d0891

[MM][Perf] Minor Optimization on Qwen3-VL fast_pos_embed_interpolate (#25337)
Signed-off-by: Roger Wang <hey@rogerw.io>
1 parent cf56cf7 commit 30d0891

1 file changed: +60 additions, -75 deletions

vllm/model_executor/models/qwen3_vl.py

Lines changed: 60 additions & 75 deletions
@@ -270,6 +270,7 @@ def __init__(
         self.temporal_patch_size = vision_config.temporal_patch_size
         self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
         self.use_data_parallel = use_data_parallel
+        self.num_grid_per_side = int(self.num_position_embeddings**0.5)
 
         # NOTE: This is used for creating empty tensor for all_gather for
         # DP ViT. Here out_hidden_size is enlarged due to deepstack
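Note: the added line only caches the side length of the learned position-embedding table, which is treated as a square grid, so fast_pos_embed_interpolate below no longer recomputes the square root on every call. A toy illustration of the relationship, with an assumed table size rather than the real Qwen3-VL config value:

# Illustration only; num_position_embeddings is an assumed value here.
num_position_embeddings = 1024
num_grid_per_side = int(num_position_embeddings**0.5)  # 32, i.e. a 32 x 32 grid
assert num_grid_per_side * num_grid_per_side == num_position_embeddings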
@@ -377,82 +378,68 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
-    def fast_pos_embed_interpolate(self, grid_thw):
-        num_grid_per_side = int(self.num_position_embeddings**0.5)
+    def fast_pos_embed_interpolate(self,
+                                   grid_thw: list[list[int]]) -> torch.Tensor:
 
-        idx_list = [[] for _ in range(4)]
-        weight_list = [[] for _ in range(4)]
+        num_grid_per_side = self.num_grid_per_side
+        m_size = self.spatial_merge_size
+        hidden_dim = self.pos_embed.embedding_dim
 
+        outputs = []
         for t, h, w in grid_thw:
             h_idxs = torch.linspace(0,
                                     num_grid_per_side - 1,
                                     h,
-                                    dtype=torch.float32)
+                                    dtype=torch.float32,
+                                    device=self.device)
             w_idxs = torch.linspace(0,
                                     num_grid_per_side - 1,
                                     w,
-                                    dtype=torch.float32)
-
-            h_idxs_floor = h_idxs.to(torch.long)
-            w_idxs_floor = w_idxs.to(torch.long)
-            h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1,
-                                      max=num_grid_per_side - 1)
-            w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1,
-                                      max=num_grid_per_side - 1)
-
-            dh = h_idxs - h_idxs_floor
-            dw = w_idxs - w_idxs_floor
-
-            idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T +
-                                w_idxs_floor[None]).flatten().tolist() * t)
-            idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T +
-                                w_idxs_ceil[None]).flatten().tolist() * t)
-            idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
-                                w_idxs_floor[None]).flatten().tolist() * t)
-            idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
-                                w_idxs_ceil[None]).flatten().tolist() * t)
-
-            weight_list[0].extend(
-                ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t)
-            weight_list[1].extend(
-                ((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
-            weight_list[2].extend(
-                (dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
-            weight_list[3].extend(
-                (dh[None].T * dw[None]).flatten().tolist() * t)
-
-        device = self.pos_embed.weight.device
-        dtype = self.pos_embed.weight.dtype
-
-        p0 = self.pos_embed(
-            torch.tensor(
-                idx_list[0], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[0], dtype=dtype, device=device)[:, None]
-        p1 = self.pos_embed(
-            torch.tensor(
-                idx_list[1], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[1], dtype=dtype, device=device)[:, None]
-        p2 = self.pos_embed(
-            torch.tensor(
-                idx_list[2], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[2], dtype=dtype, device=device)[:, None]
-        p3 = self.pos_embed(
-            torch.tensor(
-                idx_list[3], dtype=torch.long, device=device)) * torch.tensor(
-                    weight_list[3], dtype=dtype, device=device)[:, None]
-
-        patch_pos_embeds = p0 + p1 + p2 + p3
-        patch_pos_embeds = patch_pos_embeds.split(
-            [t * h * w for t, h, w in grid_thw])
-        patch_pos_embeds_permute = []
-        m_size = self.spatial_merge_size
-        for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
-            pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size,
-                                       m_size, -1).permute(0, 1, 3, 2, 4,
-                                                           5).flatten(0, 4)
-            patch_pos_embeds_permute.append(pos_embed)
-        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
-        return patch_pos_embeds
+                                    dtype=torch.float32,
+                                    device=self.device)
+
+            h_floor = h_idxs.to(torch.long)
+            w_floor = w_idxs.to(torch.long)
+            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
+            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
+
+            dh = h_idxs - h_floor
+            dw = w_idxs - w_floor
+
+            w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
+            w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
+            w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
+            w11 = (dh[:, None] * dw[None, :]).reshape(-1)
+
+            idx00 = (h_floor[:, None] * num_grid_per_side +
+                     w_floor[None, :]).reshape(-1)
+            idx01 = (h_floor[:, None] * num_grid_per_side +
+                     w_ceil[None, :]).reshape(-1)
+            idx10 = (h_ceil[:, None] * num_grid_per_side +
+                     w_floor[None, :]).reshape(-1)
+            idx11 = (h_ceil[:, None] * num_grid_per_side +
+                     w_ceil[None, :]).reshape(-1)
+
+            indices = torch.stack([idx00, idx01, idx10, idx11], dim=0)
+            weights = torch.stack([w00, w01, w10, w11],
+                                  dim=0).to(dtype=self.dtype,
+                                            device=self.device)
+            weights = weights.unsqueeze(-1)
+
+            embeds = self.pos_embed(indices)
+            weighted_embeds = embeds * weights
+            p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
+            combined = p0 + p1 + p2 + p3
+
+            combined = combined.view(h * w, hidden_dim)
+            repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
+            repeated = repeated.view(t, h // m_size, m_size, w // m_size,
+                                     m_size, hidden_dim)
+            repeated = repeated.permute(0, 1, 3, 2, 4,
+                                        5).reshape(-1, hidden_dim)
+            outputs.append(repeated)
+
+        return torch.cat(outputs, dim=0)
 
     def compute_attn_mask_seqlen(
         self,
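For readers skimming the diff: the rewrite keeps the same bilinear interpolation over the learned square position-embedding grid, but drops the per-corner Python lists (built via .tolist() and re-uploaded with torch.tensor), keeps the index and weight computation in tensor ops, batches the four corner lookups into a single nn.Embedding call via torch.stack, and computes one (h, w) plane per image that is then expanded over the t frames instead of repeating the host-side lists t times. Below is a self-contained sketch of that access pattern; it is not the vLLM module, and the grid side and hidden size are illustrative assumptions.

import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from the vision config.
num_grid_per_side = 4                 # the learned table is a 4 x 4 grid
hidden_dim = 8
pos_embed = nn.Embedding(num_grid_per_side**2, hidden_dim)


def bilinear_pos_embed(h: int, w: int) -> torch.Tensor:
    """Interpolate the square embedding grid onto an h x w patch grid."""
    h_idxs = torch.linspace(0, num_grid_per_side - 1, h)
    w_idxs = torch.linspace(0, num_grid_per_side - 1, w)

    h_floor, w_floor = h_idxs.long(), w_idxs.long()
    h_ceil = (h_floor + 1).clamp(max=num_grid_per_side - 1)
    w_ceil = (w_floor + 1).clamp(max=num_grid_per_side - 1)
    dh, dw = h_idxs - h_floor, w_idxs - w_floor

    # Flat indices of the four neighbouring grid points for every (row, col).
    idx = torch.stack([
        (h_floor[:, None] * num_grid_per_side + w_floor[None, :]).reshape(-1),
        (h_floor[:, None] * num_grid_per_side + w_ceil[None, :]).reshape(-1),
        (h_ceil[:, None] * num_grid_per_side + w_floor[None, :]).reshape(-1),
        (h_ceil[:, None] * num_grid_per_side + w_ceil[None, :]).reshape(-1),
    ])  # shape (4, h*w)

    # Matching bilinear weights; they sum to 1 for every output position.
    wts = torch.stack([
        ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1),
        ((1 - dh)[:, None] * dw[None, :]).reshape(-1),
        (dh[:, None] * (1 - dw)[None, :]).reshape(-1),
        (dh[:, None] * dw[None, :]).reshape(-1),
    ]).unsqueeze(-1)  # shape (4, h*w, 1)

    # One embedding lookup for all four corners, then a weighted sum.
    return (pos_embed(idx) * wts).sum(dim=0)  # shape (h*w, hidden_dim)


out = bilinear_pos_embed(h=6, w=10)
print(out.shape)  # torch.Size([60, 8])

The real method additionally casts the weights to the module dtype and device, expands the (h * w, hidden_dim) plane across the t temporal frames, and permutes rows into spatial_merge_size window order, as shown in the added lines above.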
@@ -477,12 +464,9 @@ def forward(
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
-        if isinstance(grid_thw, list):
-            grid_thw_tensor = torch.tensor(grid_thw,
-                                           device=hidden_states.device,
-                                           dtype=torch.int32)
-        else:
-            grid_thw_tensor = grid_thw
+        grid_thw_tensor = torch.tensor(grid_thw,
+                                       device=self.device,
+                                       dtype=torch.int32)
 
         cu_seqlens = torch.repeat_interleave(
             grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
@@ -1224,7 +1208,8 @@ def _process_image_input(
                 grid_thw_list,
                 rope_type="rope_3d")
         else:
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+            image_embeds = self.visual(pixel_values,
+                                       grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each image item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
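The context comment just above spells out the motivation for keeping the Python list around: taking products of grid_thw_list on the host avoids touching a CUDA tensor (calling .prod() on a device tensor and then reading the result would force a device-to-host sync). Passing grid_thw_list into self.visual is also consistent with the simplified forward above, which now always builds the int32 tensor itself. A minimal host-side sketch with made-up grid values:

import math

# Example (t, h, w) grids for two images; values are illustrative only.
grid_thw_list = [[1, 32, 44], [1, 16, 20]]

# Pure-Python products: no CUDA tensor involved, hence no device sync.
sizes = [math.prod(thw) for thw in grid_thw_list]
print(sizes)  # [1408, 320] -> used to split the concatenated image embeddings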
@@ -1526,4 +1511,4 @@ def get_mm_mapping(self) -> MultiModelKeys:
             language_model="language_model",
             connector="model.visual.merger",
             tower_model="model.visual.",
-        )
+        )
