Skip to content

Commit

Permalink
[Test] Test multiple attn backend for chunked prefill. (vllm-project#…
Browse files Browse the repository at this point in the history
  • Loading branch information
rkooo567 authored Apr 12, 2024
1 parent 7fd3949 commit 36729ba
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 23 deletions.
8 changes: 7 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ steps:
command: pytest -v -s async_engine

- label: Basic Correctness Test
command: pytest -v -s basic_correctness
commands:
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py

- label: Core Test
command: pytest -v -s core
Expand Down
6 changes: 0 additions & 6 deletions tests/basic_correctness/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
"""
import pytest

from vllm.attention.selector import VLLM_ATTENTION_BACKEND

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
Expand All @@ -16,7 +14,6 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
def test_models(
hf_runner,
vllm_runner,
Expand All @@ -25,10 +22,7 @@ def test_models(
dtype: str,
max_tokens: int,
enforce_eager: bool,
attn_backend: str,
monkeypatch,
) -> None:
monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
Expand Down
4 changes: 0 additions & 4 deletions tests/basic_correctness/test_chunked_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ def test_models(
enforce_eager: bool,
tensor_parallel_size: int,
) -> None:
if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
and not enforce_eager):
pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
"for high TP to save testing time.")
max_num_seqs = min(chunked_prefill_token_size, 256)
enable_chunked_prefill = False
max_num_batched_tokens = None
Expand Down
18 changes: 6 additions & 12 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def __init__(
# AMD Radeon 7900 series (gfx1100) currently does not support
# xFormers nor FlashAttention. As a temporary workaround, we use
# naive PyTorch implementation of attention.
self.attn_fuc = _naive_attention()
self.attn_fuc = _naive_attention
logger.debug("Using naive attention in ROCmBackend")
elif self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
Expand Down Expand Up @@ -334,26 +334,21 @@ def _naive_attention(
prompt_lens: List[int],
scale: float,
) -> torch.Tensor:
num_tokens = query.shape[0]
output = torch.empty_like(query)
start = 0
for _, prompt_len in enumerate(prompt_lens):
end = start + prompt_len
out = _naive_masked_attention(
query[None, start:end],
key[None, start:end],
value[None, start:end],
query[start:end],
key[start:end],
value[start:end],
scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += prompt_len

# Using view got RuntimeError: view size is not compatible
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return output.reshape(num_tokens, -1)
return output


def _naive_masked_attention(
Expand All @@ -362,14 +357,13 @@ def _naive_masked_attention(
value: torch.Tensor,
scale: float,
) -> torch.Tensor:
seq_len, _, _ = query.shape
seq_len, head_size, head_dim = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
dtype=query.dtype,
device=query.device),
diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min

attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
Expand Down

0 comments on commit 36729ba

Please sign in to comment.