Commit 2daf23a

Separate attention backends (#3005)
1 parent cbf4c05 commit 2daf23a

35 files changed: +558 -268 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -184,3 +184,6 @@ _build/
 
 # Benchmark dataset
 *.json
+
+# Third-party Python packages.
+vllm/thirdparty_files/

setup.py

Lines changed: 45 additions & 3 deletions
@@ -3,6 +3,7 @@
 import os
 import re
 import subprocess
+import sys
 import warnings
 from pathlib import Path
 from typing import List, Set
@@ -14,6 +15,8 @@
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
 
 ROOT_DIR = os.path.dirname(__file__)
+# This is a temporary directory to store third-party packages.
+THIRDPARTY_SUBDIR = "vllm/thirdparty_files"
 
 # If you are developing the C++ backend of vLLM, consider building vLLM with
 # `python setup.py develop` since it will give you incremental builds.
@@ -324,8 +327,46 @@ def get_torch_arch_list() -> Set[str]:
                     "nvcc": NVCC_FLAGS_PUNICA,
                 },
             ))
-elif _is_neuron():
-    neuronxcc_version = get_neuronxcc_version()
+
+    # Download the FlashAttention package.
+    # Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
+    flash_attn_version = "2.5.6"
+    install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR)
+    subprocess.check_call(
+        [
+            sys.executable,
+            "-m",
+            "pip",
+            "install",
+            "-q",
+            f"--target={install_dir}",
+            "einops",  # Dependency of flash-attn.
+            f"flash-attn=={flash_attn_version}",
+            "--no-dependencies",  # Required to avoid re-installing torch.
+        ],
+        env=dict(os.environ, CC="gcc"),
+    )
+
+    # Copy the FlashAttention package into the vLLM package after build.
+    class build_ext(BuildExtension):
+
+        def run(self):
+            super().run()
+            target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR)
+            if not os.path.exists(target_dir):
+                os.makedirs(target_dir)
+            self.copy_tree(install_dir, target_dir)
+
+    class BinaryDistribution(setuptools.Distribution):
+
+        def has_ext_modules(self):
+            return True
+
+else:
+    build_ext = BuildExtension
+    BinaryDistribution = setuptools.Distribution
+    if _is_neuron():
+        neuronxcc_version = get_neuronxcc_version()
 
 vllm_extension_sources = [
     "csrc/cache_kernels.cu",
@@ -468,6 +509,7 @@ def get_requirements() -> List[str]:
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
-    cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
+    cmdclass={"build_ext": build_ext} if not _is_neuron() else {},
+    distclass=BinaryDistribution,
     package_data=package_data,
 )
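
A minimal standalone sketch of the vendoring step above (illustrative only, not part of the commit; vendor_flash_attn is a hypothetical helper name, and pip plus a working gcc are assumed):

# Illustrative sketch, not part of the commit: the same "pip install --target"
# vendoring that setup.py performs, wrapped in a hypothetical helper.
import os
import subprocess
import sys


def vendor_flash_attn(root_dir: str, version: str = "2.5.6") -> str:
    """Install flash-attn (plus einops) into <root_dir>/vllm/thirdparty_files."""
    install_dir = os.path.join(root_dir, "vllm/thirdparty_files")
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install", "-q",
            f"--target={install_dir}",
            "einops",  # runtime dependency of flash-attn
            f"flash-attn=={version}",
            "--no-dependencies",  # keep the already-installed torch
        ],
        env=dict(os.environ, CC="gcc"),  # mirrors the CC=gcc override in setup.py
    )
    return install_dir


if __name__ == "__main__":
    print(vendor_flash_attn(os.path.dirname(os.path.abspath(__file__))))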

tests/kernels/test_prefix_prefill.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 import time
 
 import torch
-from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
+from vllm.model_executor.layers.attention.ops.prefix_prefill import (
     context_attention_fwd)
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

vllm/__init__.py

Lines changed: 23 additions & 7 deletions
@@ -1,12 +1,28 @@
 """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
 
-from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.llm_engine import LLMEngine
-from vllm.engine.ray_utils import initialize_cluster
-from vllm.entrypoints.llm import LLM
-from vllm.outputs import CompletionOutput, RequestOutput
-from vllm.sampling_params import SamplingParams
+
+# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11
+def _configure_system():
+    import os
+    import sys
+
+    # Importing flash-attn.
+    thirdparty_files = os.path.join(os.path.abspath(os.path.dirname(__file__)),
+                                    "thirdparty_files")
+    sys.path.insert(0, thirdparty_files)
+
+
+_configure_system()
+# Delete configuration function.
+del _configure_system
+
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs  # noqa: E402
+from vllm.engine.async_llm_engine import AsyncLLMEngine  # noqa: E402
+from vllm.engine.llm_engine import LLMEngine  # noqa: E402
+from vllm.engine.ray_utils import initialize_cluster  # noqa: E402
+from vllm.entrypoints.llm import LLM  # noqa: E402
+from vllm.outputs import CompletionOutput, RequestOutput  # noqa: E402
+from vllm.sampling_params import SamplingParams  # noqa: E402
 
 __version__ = "0.3.3"
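
A quick, illustrative way to check the effect of _configure_system() (not part of the commit; assumes vLLM was built with the new setup.py so that vllm/thirdparty_files/ exists):

# Illustrative check, not part of the commit: after importing vllm, the vendored
# flash_attn should resolve from vllm/thirdparty_files/ rather than site-packages.
import os

import vllm  # runs _configure_system(), which prepends thirdparty_files to sys.path
import flash_attn

thirdparty_dir = os.path.join(os.path.dirname(vllm.__file__), "thirdparty_files")
assert flash_attn.__file__.startswith(thirdparty_dir), flash_attn.__file__
print("flash_attn loaded from:", flash_attn.__file__)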

vllm/model_executor/layers/attention/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from vllm.model_executor.layers.attention.attention import Attention
+
+__all__ = [
+    "Attention",
+]

vllm/model_executor/layers/attention/attention.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+"""Attention layer."""
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.utils import is_hip
+
+
+class Attention(nn.Module):
+    """Attention layer.
+
+    This class takes query, key, and value tensors as input. The input tensors
+    can either contain prompt tokens or generation tokens.
+    The class does the following:
+
+    1. Store the input key and value tensors in the KV cache.
+    2. Perform (multi-head/multi-query/grouped-query) attention.
+    3. Return the output tensor.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and
+                torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
+            # Ampere or later NVIDIA GPUs.
+            # NOTE(woosuk): FlashAttention does not support FP32.
+            from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
+            self.backend = FlashAttentionBackend(num_heads, head_size, scale,
+                                                 num_kv_heads, alibi_slopes,
+                                                 sliding_window)
+        else:
+            # Turing and Volta NVIDIA GPUs or AMD GPUs.
+            # Or FP32 on any GPU.
+            from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
+            self.backend = XFormersBackend(num_heads, head_size, scale,
+                                           num_kv_heads, alibi_slopes,
+                                           sliding_window)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        return self.backend.forward(query, key, value, key_cache, value_cache,
+                                    input_metadata)
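
A hypothetical usage sketch of the new dispatching layer (not from the commit; the sizes are illustrative, a CUDA device is needed for the backend probe in __init__, and scale = head_size**-0.5 is the conventional choice rather than anything mandated here):

# Hypothetical usage sketch, not part of the commit: how a model's attention
# module might construct the new Attention layer. Sizes are illustrative.
import torch

from vllm.model_executor.layers.attention import Attention

num_heads, num_kv_heads, head_size = 32, 8, 128
attn = Attention(
    num_heads=num_heads,
    head_size=head_size,
    scale=head_size**-0.5,      # conventional 1/sqrt(head_size) scaling
    num_kv_heads=num_kv_heads,  # grouped-query attention: 4 query heads per KV head
)

# q, k, v come from the model's QKV projection, with the shapes documented in the
# backend below: q = [batch, seq, num_heads * head_size], k/v use num_kv_heads.
# key_cache/value_cache may be None during the initial memory-profiling run.
# out = attn(q, k, v, key_cache, value_cache, input_metadata)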

vllm/model_executor/layers/attention/backends/flash_attn.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+"""Attention layer with Flash and PagedAttention."""
+from typing import List, Optional
+
+# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
+from flash_attn import flash_attn_func
+import torch
+
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.attention.ops.paged_attn import (
+    PagedAttentionImpl)
+
+
+class FlashAttentionBackend:
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = sliding_window
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
+        if head_size not in suppored_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {suppored_head_sizes}.")
+
+        self.sliding_window = ((self.sliding_window, self.sliding_window) if
+                               self.sliding_window is not None else (-1, -1))
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention and PagedAttention.
+
+        Args:
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
+            input_metadata: metadata for the inputs.
+        Returns:
+            shape = [batch_size, seq_len, num_heads * head_size]
+        """
+        batch_size, seq_len, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        # Reshape the keys and values and store them in the cache.
+        # If key_cache and value_cache are not provided, the new key and value
+        # vectors will not be cached. This happens during the initial memory
+        # profiling run.
+        if key_cache is not None and value_cache is not None:
+            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
+                                                 value_cache, input_metadata)
+
+        if input_metadata.is_prompt:
+            # Prompt run.
+            if (key_cache is None or value_cache is None
+                    or input_metadata.block_tables.numel() == 0):
+                # normal attention
+                query = query.unflatten(0, (batch_size, seq_len))
+                key = key.unflatten(0, (batch_size, seq_len))
+                value = value.unflatten(0, (batch_size, seq_len))
+                output = flash_attn_func(
+                    query,
+                    key,
+                    value,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    window_size=self.sliding_window,
+                    alibi_slopes=self.alibi_slopes,
+                )
+            else:
+                # prefix-enabled attention
+                output = PagedAttentionImpl.forward_prefix(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    input_metadata,
+                    self.num_heads,
+                    self.num_kv_heads,
+                    self.alibi_slopes,
+                )
+        else:
+            # Decoding run.
+            output = PagedAttentionImpl.forward_decode(
+                query,
+                key_cache,
+                value_cache,
+                input_metadata,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+            )
+
+        # Reshape the output tensor.
+        return output.view(batch_size, seq_len, hidden_size)
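
The reshape logic in forward can be traced on CPU with dummy tensors (an illustrative walkthrough, not part of the commit; shapes follow the docstring above, and the sizes are arbitrary examples):

# Illustrative shape walkthrough, not part of the commit: the reshapes that
# FlashAttentionBackend.forward applies around flash_attn_func for a prompt run.
import torch

batch_size, seq_len = 2, 16
num_heads, num_kv_heads, head_size = 32, 8, 128

query = torch.randn(batch_size, seq_len, num_heads * head_size)
key = torch.randn(batch_size, seq_len, num_kv_heads * head_size)

# Flatten (batch, seq) into tokens and split the hidden dim into heads ...
q = query.view(-1, num_heads, head_size)    # [batch*seq, num_heads, head_size]
k = key.view(-1, num_kv_heads, head_size)   # [batch*seq, num_kv_heads, head_size]

# ... then restore the batch dimension, since flash_attn_func expects
# [batch_size, seq_len, num_heads, head_size] inputs.
q = q.unflatten(0, (batch_size, seq_len))
k = k.unflatten(0, (batch_size, seq_len))

# flash_attn_func returns [batch, seq, num_heads, head_size]; the backend then
# folds the heads back into the layer's hidden size before returning.
dummy_out = torch.randn(batch_size, seq_len, num_heads, head_size)
print(dummy_out.view(batch_size, seq_len, num_heads * head_size).shape)
# torch.Size([2, 16, 4096])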
