Commit 64f4906

JartX and tjtanaa committed
qwen2vl and qwen2.5vl contiguous on rocm and torch.sdpa
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: JartX <sagformas@epdcenter.es>
1 parent 46ed9c6 commit 64f4906

2 files changed, +10 -0 lines changed

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 5 additions & 0 deletions
@@ -423,6 +423,11 @@ def forward(
             ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]
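
The added lines call `.contiguous()` on `q`, `k` and `v` when running on ROCm. The commit does not say why ROCm needs this, so the following is only a simplified, pure-Python model of what "contiguous" means for a strided tensor view. It is not PyTorch's implementation. The helper names and the example shape are hypothetical and chosen for illustration.

```python
def row_major_strides(shape):
    """Strides (in elements) of a densely packed row-major tensor."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """True if the strides describe a dense row-major layout.

    Dimensions of size 1 are skipped because their stride never matters.
    """
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size == 1:
            continue
        if stride != expected:
            return False
        expected *= size
    return True


def transpose(shape, strides, dim0, dim1):
    """Swap two dimensions of a view without moving any data."""
    shape, strides = list(shape), list(strides)
    shape[dim0], shape[dim1] = shape[dim1], shape[dim0]
    strides[dim0], strides[dim1] = strides[dim1], strides[dim0]
    return tuple(shape), tuple(strides)


# Hypothetical layout (seq, batch, heads, head_dim); not taken from the commit.
shape = (6, 1, 2, 4)
strides = row_major_strides(shape)

# A transposed view shares storage with the original, so it is no longer dense.
t_shape, t_strides = transpose(shape, strides, 0, 2)

# Copying into a fresh dense buffer (what .contiguous() does conceptually)
# yields row-major strides for the new shape.
c_strides = row_major_strides(t_shape)
```

In this model the original layout and the dense copy pass `is_contiguous`, while the transposed view does not. A kernel that assumes dense inputs would need the copy in that case.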

vllm/model_executor/models/qwen2_vl.py

Lines changed: 5 additions & 0 deletions
@@ -453,6 +453,11 @@ def forward(
             ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]
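
Both hunks end at the start of a loop over `cu_seqlens`, which the comment describes as executing attention "entry by entry". A minimal sketch of how cumulative sequence lengths can be turned into per-sequence slices is shown here in plain Python. The `end_idx` line is an assumption, since the hunk is cut off before it. The function name `sequence_slices` and the sample data are hypothetical.

```python
def sequence_slices(cu_seqlens):
    """Turn cumulative sequence lengths into (start, end) index pairs."""
    slices = []
    for i in range(1, len(cu_seqlens)):
        start_idx = cu_seqlens[i - 1]
        end_idx = cu_seqlens[i]  # assumed; not visible in the diff hunk
        slices.append((start_idx, end_idx))
    return slices


# Hypothetical packed batch of three sequences with lengths 4, 2 and 3.
cu_seqlens = [0, 4, 6, 9]
tokens = list("abcdefghi")

# Each chunk would be attended to on its own in the per-entry loop.
chunks = [tokens[start:end] for start, end in sequence_slices(cu_seqlens)]
```

Processing one slice at a time keeps each attention call limited to a single sequence, which matches the "less VRAM" remark in the code comment.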
