Skip to content

[Misc] Add attention mask pre-computation optimization back to Qwen2.5-VL #15273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,17 @@ def get_window_index(self, grid_thw):
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens

def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens

def forward(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -645,25 +656,27 @@ def forward(
# transformers
hidden_states = hidden_states.unsqueeze(1)

max_seqlen = None
seqlens = None
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(
cu_seqlens)
max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
cu_window_seqlens)
for layer_num, blk in enumerate(self.blocks):
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
max_seqlen_now = max_seqlen_full
seqlens_now = seqlens_full
else:
cu_seqlens_now = cu_window_seqlens
# pre-compute cu_seqlens for window attn
if self.attn_backend == _Backend.FLASH_ATTN:
max_seqlen = (cu_seqlens_now[1:] -
cu_seqlens_now[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens_now[1:] - cu_seqlens_now[:-1]).tolist()
max_seqlen_now = max_seqlen_window
seqlens_now = seqlens_window

hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
max_seqlen=max_seqlen_now,
seqlens=seqlens_now,
)

# For Qwen2.5-VL-3B, float16 will overflow at last block
Expand Down
18 changes: 12 additions & 6 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,16 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb

def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens

def forward(
self,
x: torch.Tensor,
Expand All @@ -638,12 +648,8 @@ def forward(
# transformers
x = x.unsqueeze(1)

max_seqlen = None
seqlens = None
if self.attn_backend == _Backend.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
for blk in self.blocks:
x = blk(
x,
Expand Down