[Misc] Support attention logits soft-capping with flash-attn (vllm-pr…
WoosukKwon authored Aug 1, 2024
1 parent 562e580 commit 805a8a7
Showing 14 changed files with 71 additions and 47 deletions.
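Background (not part of the diff): attention logits soft-capping, used by Gemma-2, squashes the raw pre-softmax attention scores into (-cap, cap) with a scaled tanh. A minimal PyTorch sketch of the transformation the new softcap kernel argument implements:

import torch

def soft_cap_logits(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Bound raw attention scores to (-cap, cap); Gemma-2 sets
    # attn_logit_softcapping = 50.0 in its HF config.
    return cap * torch.tanh(logits / cap)

scores = torch.randn(4, 8) * 100.0        # unbounded raw logits
capped = soft_cap_logits(scores, 50.0)    # every value now lies in (-50, 50)
assert capped.abs().max() < 50.0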
2 changes: 1 addition & 1 deletion requirements-cuda.txt
@@ -8,4 +8,4 @@ torch == 2.4.0
# These must be updated alongside torch
torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers == 0.0.27.post2 # Requires PyTorch 2.4.0
-vllm-flash-attn == 2.6.0 # Requires PyTorch 2.4.0
+vllm-flash-attn == 2.6.1 # Requires PyTorch 2.4.0
19 changes: 13 additions & 6 deletions tests/kernels/test_flash_attn.py
@@ -20,6 +20,7 @@ def ref_paged_attn(
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
@@ -53,6 +54,8 @@ def ref_paged_attn(
(query_len + sliding_window) +
1).bool().logical_not()
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(attn, dim=-1).to(v.dtype)
out = torch.einsum("hqk,khd->qhd", attn, v)
@@ -68,13 +71,15 @@ def ref_paged_attn(
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
-@torch.inference_mode
+@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@torch.inference_mode()
def test_flash_attn_with_paged_kv(
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)

ref_output = ref_paged_attn(
@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
@@ -129,14 +136,16 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES)
-@torch.inference_mode
+@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@torch.inference_mode()
def test_varlen_with_paged_kv(
seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
head_size: int,
sliding_window: Optional[int],
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
-    # Normalize the scale of the key and value caches to mitigate
-    # numerical instability.
-    key_cache /= head_size**0.5
-    value_cache /= head_size**0.5
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
)

ref_output = ref_paged_attn(
Expand All @@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
block_tables=block_tables,
scale=scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
1 change: 1 addition & 0 deletions vllm/attention/backends/abstract.py
@@ -150,6 +150,7 @@ def __init__(
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
raise NotImplementedError

3 changes: 3 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -283,12 +283,15 @@ def __init__(
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
"Alibi not support for blocksparse flash attention.")
assert sliding_window is None, ValueError(
"sliding_window is invalid for blocksparse attention.")
assert logits_soft_cap is None, ValueError(
"logits_soft_cap is invalid for blocksparse attention.")

if "num_heads" not in blocksparse_params:
blocksparse_params["num_heads"] = num_heads
21 changes: 10 additions & 11 deletions vllm/attention/backends/flash_attn.py
@@ -288,15 +288,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1

-        logits_soft_cap = getattr(self.runner.model_config.hf_config,
-                                  "attn_logit_softcapping", None)
-        if logits_soft_cap is not None:
-            raise ValueError(
-                "Please use Flashinfer backend for models with logits_soft_cap"
-                " (i.e., Gemma-2). Otherwise, the output might be wrong."
-                " Set Flashinfer backend by "
-                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
@@ -405,9 +396,11 @@ def __init__(
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
-        assert blocksparse_params is None, ValueError(
-            "FlashAttention does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "FlashAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -418,6 +411,10 @@ def __init__(
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -525,6 +522,7 @@ def forward(
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
@@ -544,6 +542,7 @@ def forward(
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
)

if decode_meta := attn_metadata.decode_metadata:
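Note on the convention used in FlashAttentionImpl above: the vllm-flash-attn kernels treat softcap == 0 as "soft-capping disabled", so a logits_soft_cap of None is normalized to 0 once in the constructor and then forwarded unconditionally to both the prefill and decode calls. A hypothetical helper (not part of the diff) illustrating the mapping:

def normalize_soft_cap(logits_soft_cap):
    # 0.0 tells flash-attn to skip the tanh capping entirely.
    return 0.0 if logits_soft_cap is None else float(logits_soft_cap)

assert normalize_soft_cap(None) == 0.0
assert normalize_soft_cap(50.0) == 50.0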
14 changes: 5 additions & 9 deletions vllm/attention/backends/flashinfer.py
@@ -116,8 +116,6 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache
data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
-    # Only used by gemma2 model
-    logits_soft_cap: Optional[float] = None

def __post_init__(self):
# Refer to
@@ -391,9 +389,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dtype=torch.long,
device=device)

-        logits_soft_cap = getattr(self.runner.model_config.hf_config,
-                                  "attn_logit_softcapping", None)

if len(self.paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu",
@@ -430,8 +425,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
query_start_loc=query_start_loc,
device=device,
data_type=kv_cache_dtype,
-            use_cuda_graph=use_captured_graph,
-            logits_soft_cap=logits_soft_cap)
+            use_cuda_graph=use_captured_graph)


class FlashInferImpl(AttentionImpl):
@@ -446,6 +440,7 @@ def __init__(
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -458,6 +453,7 @@ def __init__(
raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1)
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -532,7 +528,7 @@ def forward(
output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
-                logits_soft_cap=attn_metadata.logits_soft_cap,
+                logits_soft_cap=self.logits_soft_cap,
causal=True)
else:
assert attn_metadata.decode_metadata is not None
@@ -541,5 +537,5 @@ def forward(
query,
kv_cache,
sm_scale=self.scale,
-            logits_soft_cap=attn_metadata.logits_soft_cap)
+            logits_soft_cap=self.logits_soft_cap)
return output.view(num_tokens, hidden_size)
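Design note on the FlashInfer side: the cap is no longer re-read from the HF config into FlashInferMetadata at every scheduling step; it is now a per-layer attribute of FlashInferImpl, so forward() reads self.logits_soft_cap instead of attn_metadata.logits_soft_cap. A simplified sketch of that ownership change (toy class, not the actual vLLM implementation):

from typing import Optional

class TinyFlashInferImpl:
    def __init__(self, scale: float, logits_soft_cap: Optional[float] = None):
        # Stored once at layer construction; the value is passed through
        # unchanged, with None meaning "no capping", as in the diff above.
        self.scale = scale
        self.logits_soft_cap = logits_soft_cap

    def decode(self, wrapper, query, kv_cache):
        # Mirrors the decode path above; `wrapper` stands in for a FlashInfer
        # paged-KV decode wrapper object.
        return wrapper.forward(query, kv_cache,
                               sm_scale=self.scale,
                               logits_soft_cap=self.logits_soft_cap)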
8 changes: 6 additions & 2 deletions vllm/attention/backends/ipex_attn.py
@@ -105,9 +105,13 @@ def __init__(
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
-        assert blocksparse_params is None, ValueError(
-            "Torch SPDA does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "IPEX backend does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError("IPEX backend does not support logits_soft_cap.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
4 changes: 4 additions & 0 deletions vllm/attention/backends/pallas.py
@@ -91,6 +91,7 @@ def __init__(
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -109,6 +110,9 @@ def __init__(
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
if logits_soft_cap is not None:
raise NotImplementedError(
"Attention logits soft-capping is not supported.")

if torch_xla.tpu.version() < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
10 changes: 8 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -244,9 +244,15 @@ def __init__(
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
-        assert blocksparse_params is None, ValueError(
-            "ROCFlashAttention does not support blocksparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "ROCmFlashAttention does not support blocksparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "ROCmFlashAttention does not support attention logits soft "
+                "capping.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
8 changes: 6 additions & 2 deletions vllm/attention/backends/torch_sdpa.py
@@ -109,9 +109,13 @@ def __init__(
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
-        assert blocksparse_params is None, ValueError(
-            "Torch SPDA does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "Torch SPDA does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError("Torch SPDA does not support logits soft cap.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
9 changes: 0 additions & 9 deletions vllm/attention/backends/utils.py
@@ -165,15 +165,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1

-        logits_soft_cap = getattr(self.runner.model_config.hf_config,
-                                  "attn_logit_softcapping", None)
-        if logits_soft_cap is not None:
-            raise ValueError(
-                "Please use Flashinfer backend for models with logits_soft_cap "
-                "(i.e., Gemma-2). Otherwise, the output might be wrong. "
-                "Set Flashinfer backend by "
-                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
9 changes: 7 additions & 2 deletions vllm/attention/backends/xformers.py
@@ -408,9 +408,14 @@ def __init__(
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
-        assert blocksparse_params is None, ValueError(
-            "XFormer does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "XFormers does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "XFormers does not support attention logits soft capping.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
@@ -34,6 +34,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -82,7 +83,7 @@ def __init__(
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
-                             blocksparse_params)
+                             blocksparse_params, logits_soft_cap)

def forward(
self,
7 changes: 5 additions & 2 deletions vllm/model_executor/models/gemma2.py
@@ -90,7 +90,8 @@ def __init__(self,
max_position_embeddings: int,
rope_theta: float,
cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 attn_logits_soft_cap: Optional[float] = None) -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
@@ -150,7 +151,8 @@ def __init__(self,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
-                              quant_config=quant_config)
+                              quant_config=quant_config,
+                              logits_soft_cap=attn_logits_soft_cap)

def forward(
self,
@@ -189,6 +191,7 @@ def __init__(
rope_theta=config.rope_theta,
cache_config=cache_config,
quant_config=quant_config,
attn_logits_soft_cap=config.attn_logit_softcapping,
)
self.hidden_size = config.hidden_size
self.mlp = Gemma2MLP(
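Taken together, the Gemma-2 change threads the HF config value all the way down to the kernel: config.attn_logit_softcapping -> Gemma2Attention(attn_logits_soft_cap=...) -> Attention(logits_soft_cap=...) -> the backend impl -> the flash-attn softcap argument. A condensed toy sketch of that flow (stand-in classes, not the actual vLLM code):

from typing import Optional

class ToyAttention:
    # Stand-in for vllm.attention.layer.Attention.
    def __init__(self, scale: float, logits_soft_cap: Optional[float] = None):
        # Flash-attn backend convention: 0.0 disables capping.
        self.softcap = 0.0 if logits_soft_cap is None else float(logits_soft_cap)

class ToyGemma2Attention:
    # Stand-in for Gemma2Attention: forwards the config value to the layer.
    def __init__(self, scale: float,
                 attn_logits_soft_cap: Optional[float] = None):
        self.attn = ToyAttention(scale, logits_soft_cap=attn_logits_soft_cap)

# Gemma-2's HF config sets attn_logit_softcapping = 50.0.
layer = ToyGemma2Attention(scale=0.125, attn_logits_soft_cap=50.0)
assert layer.attn.softcap == 50.0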
