[Performance] Support MQA/GQA in decode phase by using FlashAttention #2744

Open: wants to merge 11 commits into base main
7 changes: 6 additions & 1 deletion benchmarks/benchmark_latency.py
@@ -25,6 +25,7 @@ def main(args: argparse.Namespace):
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
use_flash_attn=args.use_flash_attn,
device=args.device,
)

@@ -120,12 +121,16 @@ def run_to_completion(profile_dir: Optional[str] = None):
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
"--kv-cache-dtype",
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8_e5m2'],
default='auto',
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument('--use-flash-attn',
action='store_true',
help='Use paged kv cache flash attention kernel. '
'Note this will rewrite block_size of kv cache.')
parser.add_argument(
'--profile',
action='store_true',
11 changes: 9 additions & 2 deletions benchmarks/benchmark_throughput.py
@@ -72,6 +72,7 @@ def run_vllm(
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
use_flash_attn: bool,
device: str,
) -> float:
from vllm import LLM, SamplingParams
@@ -86,6 +87,7 @@ def run_vllm(
max_model_len=max_model_len,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
use_flash_attn=use_flash_attn,
device=device,
)

@@ -211,7 +213,8 @@ def main(args: argparse.Namespace):
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype, args.device)
args.kv_cache_dtype, args.use_flash_attn,
args.device)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -290,12 +293,16 @@ def main(args: argparse.Namespace):
action="store_true",
help="enforce eager execution")
parser.add_argument(
"--kv-cache-dtype",
'--kv-cache-dtype',
type=str,
choices=["auto", "fp8_e5m2"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument('--use-flash-attn',
action='store_true',
help='Use paged kv cache flash attention kernel. '
'Note this will rewrite block_size of kv cache.')
parser.add_argument(
"--device",
type=str,
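The new flag is simply threaded through to the LLM constructor, as in run_vllm above. A minimal usage sketch (assuming the engine accepts the use_flash_attn keyword added in this PR; the model choice and sampling settings are illustrative):

from vllm import LLM, SamplingParams

# use_flash_attn switches decode-phase attention to flash_attn_with_kvcache.
# Per the flag's help text, it also overrides the configured KV cache block_size,
# since the paged FlashAttention kernel requires blocks divisible by 256.
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",  # illustrative GQA model: 64 query heads, 8 KV heads
    kv_cache_dtype="auto",
    use_flash_attn=True,
)
outputs = llm.generate(["San Francisco is a"],
                       SamplingParams(temperature=0.8, max_tokens=32))
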
41 changes: 30 additions & 11 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -4,6 +4,7 @@
import time

import torch
from flash_attn.flash_attn_interface import flash_attn_with_kvcache

from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm._C import ops
@@ -64,14 +65,17 @@ def main(
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)

# Create the KV cache.
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
block_size,
1,
num_kv_heads,
head_size,
kv_cache_dtype,
dtype,
device=device)
use_flash_attn = version == "flash-attn"
key_caches, value_caches = create_kv_caches_with_random(
NUM_BLOCKS,
block_size,
1,
num_kv_heads,
head_size,
kv_cache_dtype,
dtype,
use_flash_attn=use_flash_attn,
device=device)
key_cache, value_cache = key_caches[0], value_caches[0]

# Prepare for the paged attention kernel.
@@ -97,6 +101,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter()

nonlocal output
for _ in range(num_iters):
if version == "v1":
ops.paged_attention_v1(
@@ -131,6 +136,17 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
alibi_slopes,
kv_cache_dtype,
)
elif version == "flash-attn":
output = flash_attn_with_kvcache(
query.unsqueeze(1),
key_cache,
value_cache,
cache_seqlens=context_lens,
block_table=block_tables,
softmax_scale=scale,
causal=True,
alibi_slopes=alibi_slopes,
)
else:
raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize()
@@ -158,10 +174,10 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
description="Benchmark the paged attention kernel.")
parser.add_argument("--version",
type=str,
choices=["v1", "v2"],
default="v2")
choices=["v1", "v2", "flash-attn"],
default="flash-attn")
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--context-len", type=int, default=4096)
parser.add_argument("--context-len", type=int, default=1024)
parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size",
@@ -185,6 +201,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
args = parser.parse_args()
if args.version == "flash-attn":
# Paged KV cache block size in Flash Attention must be divisible by 256.
args.block_size = 256
print(args)

if args.num_query_heads % args.num_kv_heads != 0:
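The flash-attn branch above is the heart of the change: flash_attn_with_kvcache reads K/V from a paged cache laid out as [num_blocks, block_size, num_kv_heads, head_size] and resolves MQA/GQA inside the kernel, so 64 query heads can attend over 8 KV heads without replicating cache entries. A self-contained decode-step sketch of that call (shapes are illustrative; the float16 dtype and int32 index tensors follow the flash-attn paged-KV convention):

import torch
from flash_attn.flash_attn_interface import flash_attn_with_kvcache

batch, num_query_heads, num_kv_heads, head_size = 8, 64, 8, 128
block_size, num_blocks, max_blocks_per_seq = 256, 128, 4  # block_size divisible by 256

# One new query token per sequence (decode phase).
query = torch.randn(batch, 1, num_query_heads, head_size,
                    dtype=torch.float16, device="cuda")
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.randn_like(key_cache)
block_tables = torch.randint(0, num_blocks, (batch, max_blocks_per_seq),
                             dtype=torch.int32, device="cuda")
context_lens = torch.randint(1, block_size * max_blocks_per_seq, (batch,),
                             dtype=torch.int32, device="cuda")

# Each KV head serves num_query_heads // num_kv_heads query heads; no KV
# replication is needed in the cache.
out = flash_attn_with_kvcache(
    query, key_cache, value_cache,
    cache_seqlens=context_lens,
    block_table=block_tables,
    softmax_scale=head_size**-0.5,
    causal=True,
)
# out: [batch, 1, num_query_heads, head_size]
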
8 changes: 8 additions & 0 deletions csrc/cache.h
@@ -15,6 +15,14 @@ void copy_blocks(
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping);

void cache(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);

void reshape_and_cache(
torch::Tensor& key,
torch::Tensor& value,
68 changes: 68 additions & 0 deletions csrc/cache_kernels.cu
@@ -151,6 +151,74 @@ void copy_blocks(

namespace vllm {

template<typename scalar_t>
__global__ void cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, head_size]
scalar_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int stride,
const int num_heads,
const int head_size) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
// Padding token that should be ignored.
return;
}
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_idx = token_idx * stride + i;
const int64_t tgt_idx = slot_idx * n + i;
key_cache[tgt_idx] = key[src_idx];
value_cache[tgt_idx] = value[src_idx];
}
}

} // namespace vllm

void cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype)
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int stride = key.stride(0);

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (kv_cache_dtype == "auto") {
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
"cache_kernel",
[&] {
vllm::cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(),
stride,
num_heads,
head_size);
});
} else if (kv_cache_dtype == "fp8_e5m2") {
TORCH_CHECK(false, "Cache kernel does not support kv cache data type: ", kv_cache_dtype);
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
}

namespace vllm {

template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
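The new cache kernel is much simpler than reshape_and_cache because the flash-attn cache layout keeps each token's heads contiguous: every token's K/V row is scattered to the flat slot given by slot_mapping, and negative slots (padding) are skipped. A pure-PyTorch reference of the same indexing, offered as a sketch for clarity rather than as the CUDA code itself:

import torch

def cache_reference(key, value, key_cache, value_cache, slot_mapping):
    # key, value:              [num_tokens, num_heads, head_size]
    # key_cache, value_cache:  [num_blocks, block_size, num_heads, head_size]
    # slot_mapping:            [num_tokens], negative entries mark padding tokens
    num_blocks, block_size, num_heads, head_size = key_cache.shape
    flat_key = key_cache.view(num_blocks * block_size, num_heads, head_size)
    flat_value = value_cache.view(num_blocks * block_size, num_heads, head_size)
    valid = slot_mapping >= 0
    slots = slot_mapping[valid]
    # Matches tgt_idx = slot_idx * num_heads * head_size + i in the kernel.
    flat_key[slots] = key[valid]
    flat_value[slots] = value[valid]
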
4 changes: 4 additions & 0 deletions csrc/pybind.cpp
@@ -71,6 +71,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"cache",
&cache,
"Cache the key and value tensors");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
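With the binding registered, the op can be called from Python through the cache_ops extension module (the vllm._C.cache_ops path is assumed from how the existing cache ops are exposed; this is a sketch, not the PR's integration code):

import torch
from vllm._C import cache_ops

num_tokens, num_heads, head_size = 16, 8, 128
num_blocks, block_size = 128, 256

key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.zeros_like(key_cache)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")

# Writes each token's K/V into its slot; "auto" keeps the model dtype
# (fp8_e5m2 is rejected by this kernel, per the TORCH_CHECK above).
cache_ops.cache(key, value, key_cache, value_cache, slot_mapping, "auto")
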
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,6 +6,7 @@ numpy
torch == 2.1.2
transformers >= 4.37.0 # Required for Qwen2
xformers == 0.0.23.post1 # Required for CUDA 12.1.
flash-attn >= 2.5.0
fastapi
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.