Commit 1bd9d1c
fix qwen2vl vision eager-attention (#33213)
* fix-qwen2vl-vision-eager-attention

* code-quality

* Update src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* code-quality

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent 51d15eb · commit 1bd9d1c

1 file changed (+5, -2)

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 5 additions & 2 deletions
@@ -275,6 +275,7 @@ class VisionAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
         self.num_heads = num_heads
+        self.head_dim = dim // num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias=True)
         self.proj = nn.Linear(dim, dim)

@@ -286,9 +287,11 @@ def forward(
         q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
         k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

-        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
+        attention_mask = torch.full(
+            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+        )
         for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

         q = q.transpose(0, 1)
         k = k.transpose(0, 1)
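Why the change works, in brief: the eager attention path adds the mask tensor directly to the raw q·kᵀ scores before the softmax, so it needs an additive float mask (dtype-minimum outside each image's token block, 0 inside it) rather than a boolean one; the new self.head_dim presumably supplies the per-head dimension the eager path uses for score scaling. Below is a minimal, self-contained sketch (not the library code) of the block-diagonal additive mask the patch builds and its effect under score addition; the shapes and toy cu_seqlens values are illustrative assumptions.

    import torch

    # Toy setup: two images whose patch tokens occupy [0, 2) and [2, 6),
    # delimited by cumulative sequence lengths as in the patch.
    seq_length, head_dim = 6, 8
    cu_seqlens = torch.tensor([0, 2, 6])

    q = torch.randn(seq_length, head_dim)
    k = torch.randn(seq_length, head_dim)

    # Additive mask as built by the fix: dtype-min everywhere, 0 inside
    # each image's block, so attention cannot cross image boundaries.
    attention_mask = torch.full(
        [1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
    )
    for i in range(1, len(cu_seqlens)):
        attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

    # Eager attention adds the mask to the raw scores before softmax.
    scores = (q @ k.transpose(-2, -1)) / head_dim**0.5
    probs = (scores + attention_mask).softmax(dim=-1)

    # Cross-image attention underflows to zero after the softmax.
    assert probs[0, :2, 2:].max().item() < 1e-6
    print(probs[0])

With the old boolean mask, the same addition would contribute 0/1 to the scores instead of blocking cross-image positions, which appears to be the eager-attention bug the commit title refers to.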
