72 changes: 49 additions & 23 deletions vllm/attention/layer.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import List, Optional
from typing import Callable, List, Optional

import torch
import torch.nn as nn
@@ -68,9 +68,39 @@ def check_upstream_fa_availability(dtype: torch.dtype):
) and current_platform.has_device_capability(80):
from transformers.utils import is_flash_attn_2_available
return is_flash_attn_2_available()
if current_platform.is_rocm():
from importlib.util import find_spec
return find_spec("flash_attn") is not None
return False


def maybe_get_vit_flash_attn_backend(
attn_backend: _Backend,
use_upstream_fa: bool) -> tuple[_Backend, Callable]:
if attn_backend != _Backend.FLASH_ATTN and \
attn_backend != _Backend.ROCM_AITER_FA and \
check_upstream_fa_availability(torch.get_default_dtype()):
attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True

if current_platform.is_rocm() and \
attn_backend == _Backend.FLASH_ATTN:
use_upstream_fa = True

if (attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}):
if attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
else:
flash_attn_varlen_func = None

return attn_backend, flash_attn_varlen_func
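
For orientation, a minimal sketch of how the new helper is meant to be called from a ViT attention module (the starting backend and the printout are illustrative, not taken from this PR; it assumes a vLLM build containing this change is importable):

```python
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend

# Start from whatever backend the platform selector picked for the ViT;
# TORCH_SDPA is only a placeholder starting point here.
attn_backend = _Backend.TORCH_SDPA
use_upstream_fa = False

# The helper may upgrade the backend to FLASH_ATTN when an upstream
# flash-attn is usable for the default dtype/platform, switches to the
# upstream package on ROCm, and returns the matching varlen kernel
# (or None for non-flash backends).
attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
    attn_backend, use_upstream_fa)

is_flash_attn_backend = attn_backend in {
    _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
print(attn_backend, is_flash_attn_backend, flash_attn_varlen_func)
```
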


class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.

@@ -410,13 +440,9 @@ def __init__(
# to upstream flash attention if available.
# If vllm native fa is selected, we use it directly.
use_upstream_fa = False
if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
dtype):
backend = _Backend.FLASH_ATTN
use_upstream_fa = True

if current_platform.is_rocm() or current_platform.is_xpu():
# currently, only torch_sdpa is supported on rocm/xpu
if current_platform.is_xpu():
# currently, only torch_sdpa is supported on xpu
self.attn_backend = _Backend.TORCH_SDPA
else:

@@ -428,17 +454,25 @@ def __init__(
_Backend.FLASH_ATTN,
} else _Backend.TORCH_SDPA

self.attn_backend, self._flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa,
)

if (self.attn_backend == _Backend.XFORMERS
and not check_xformers_availability()):
self.attn_backend = _Backend.TORCH_SDPA

if self.attn_backend == _Backend.FLASH_ATTN:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}

# this condition is just to make sure that the
# use_upstream_fa in the log is correct
if current_platform.is_rocm() \
and self.attn_backend == _Backend.FLASH_ATTN:
use_upstream_fa = True

logger.info_once(
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
@@ -466,7 +500,7 @@ def forward(
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)

if self.attn_backend == _Backend.FLASH_ATTN:
if self.is_flash_attn_backend:
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
@@ -507,14 +541,6 @@ def forward(
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func

# ROCm Flash Attention expects (batch, seq, heads, head_dim)
out = flash_attn_varlen_func(query,
key,
value,
softmax_scale=self.scale)
else:
# ViT attention hasn't supported this backend yet
raise NotImplementedError(
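
The `MultiHeadAttention.forward` flash path above builds `cu_seqlens` for a batch of equal-length sequences and then calls the varlen kernel stored at init time. A standalone illustration of that bookkeeping (pure torch, made-up sizes):

```python
import torch

bsz, q_len = 3, 5  # batch of 3 sequences, 5 query tokens each

# Equal-length batches collapse cu_seqlens to a simple arange:
# [0, 5, 10, 15] marks where each packed sequence starts and ends.
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len,
                            dtype=torch.int32)
print(cu_seqlens_q)  # tensor([ 0,  5, 10, 15], dtype=torch.int32)

# self._flash_attn_varlen_func is then called with these offsets and the
# query/key/value tensors flattened to (bsz * seq_len, heads, head_dim).
```
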
41 changes: 19 additions & 22 deletions vllm/model_executor/models/dots_ocr.py
@@ -10,7 +10,8 @@
from transformers.models.qwen2_vl import Qwen2VLProcessor

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
@@ -267,10 +268,12 @@ def __init__(self,
self.attn_backend = get_vit_attn_backend(
self.hidden_size_per_attention_head, torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True

self.attn_backend, self.flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
@@ -306,25 +309,18 @@ def forward(
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
output = flash_attn_varlen_func(q_,
k_,
v_,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
output = self.flash_attn_varlen_func(q_,
k_,
v_,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
context_layer = output.view(bs, -1,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
@@ -611,7 +607,8 @@ def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN:
if (self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
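
`compute_attn_mask_seqlen` now takes the `max_seqlen` branch for both flash-attention backends; the underlying `cu_seqlens` arithmetic, shown standalone with made-up values:

```python
import torch

# Cumulative patch counts for three images of 4, 7, and 5 visual tokens.
cu_seqlens = torch.tensor([0, 4, 11, 16], dtype=torch.int32)

# FLASH_ATTN / ROCM_AITER_FA only need the longest sequence length.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 7

# XFORMERS instead needs the full list of per-image lengths.
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()  # [4, 7, 5]
```
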
49 changes: 23 additions & 26 deletions vllm/model_executor/models/ernie45_vl.py
@@ -35,7 +35,8 @@
from transformers import BatchFeature

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@@ -176,14 +177,18 @@ def __init__(
dtype=torch.get_default_dtype())

self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True

self.attn_backend, self.flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
)

if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Ernie45-VL does not support {self.attn_backend} backend now."
@@ -239,27 +244,18 @@ def forward(
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

output = flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
output = self.flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)

context_layer = rearrange(output,
"(b s) h d -> s b (h d)",
@@ -516,7 +512,8 @@ def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN:
if (self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
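
These vision encoders now call the function handle returned by `maybe_get_vit_flash_attn_backend` instead of importing a kernel inside `forward`. A hedged end-to-end sketch of that call pattern, assuming a CUDA (or ROCm) device with a flash-attention build installed; shapes and dtypes are illustrative:

```python
import torch
from einops import rearrange

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend

attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
    _Backend.FLASH_ATTN, use_upstream_fa=False)

if flash_attn_varlen_func is not None:
    b, s, h, d = 2, 8, 4, 64
    q = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Flatten the batch into one packed sequence, as the vision blocks do.
    q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in (q, k, v))
    cu_seqlens = torch.arange(0, (b + 1) * s, step=s,
                              dtype=torch.int32, device="cuda")

    out = flash_attn_varlen_func(q, k, v,
                                 cu_seqlens_q=cu_seqlens,
                                 cu_seqlens_k=cu_seqlens,
                                 max_seqlen_q=s,
                                 max_seqlen_k=s,
                                 dropout_p=0.0,
                                 causal=False)
    print(out.shape)  # (b * s, h, d)
```
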
31 changes: 17 additions & 14 deletions vllm/model_executor/models/glm4_1v.py
@@ -47,7 +47,8 @@
from transformers.video_utils import VideoMetadata

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
parallel_state)
@@ -263,19 +264,26 @@ def __init__(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True

self.attn_backend, self.flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
)

if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
}:
raise RuntimeError(
f"GLM-4V does not support {self.attn_backend} backend now.")

self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}

def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@@ -316,17 +324,11 @@ def forward(
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
if self.is_flash_attn_backend:

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

output = flash_attn_varlen_func(
output = self.flash_attn_varlen_func(
q,
k,
v,
@@ -774,7 +776,8 @@ def compute_attn_mask_seqlen(
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
if self.attn_backend == _Backend.FLASH_ATTN:
if (self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
return max_seqlen, seqlens

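
The non-flash branches of these `forward` methods are unchanged and not shown in this diff; for contrast only, the `TORCH_SDPA` fallback that the backend checks keep admitting boils down to a batched call with no `cu_seqlens` bookkeeping (illustrative, not the models' exact code):

```python
import torch
import torch.nn.functional as F

b, s, h, d = 2, 8, 4, 64
q = torch.randn(b, s, h, d)
k, v = torch.randn_like(q), torch.randn_like(q)

# SDPA expects (batch, heads, seq, head_dim), so transpose in and out.
out = F.scaled_dot_product_attention(q.transpose(1, 2),
                                     k.transpose(1, 2),
                                     v.transpose(1, 2))
out = out.transpose(1, 2)  # back to (batch, seq, heads, head_dim)
print(out.shape)
```
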