diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 2fdc08c5c26df..7b8ba49a79801 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -27,6 +27,7 @@ def main(args: argparse.Namespace): kv_cache_dtype=args.kv_cache_dtype, device=args.device, ray_workers_use_nsight=args.ray_workers_use_nsight, + use_flash_attn=args.use_flash_attn, ) sampling_params = SamplingParams( @@ -151,5 +152,9 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help="If specified, use nsight to profile ray workers", ) + parser.add_argument( + "--use-flash-attn", + action="store_true", + help="Use flash attention (requires flash-attn >= 2.5.0).") args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 72bdc4b3b4540..ae686626bf3bc 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -75,6 +75,7 @@ def run_vllm( device: str, enable_prefix_caching: bool, gpu_memory_utilization: float = 0.9, + use_flash_attn: Optional[bool] = False, ) -> float: from vllm import LLM, SamplingParams llm = LLM(model=model, @@ -89,7 +90,8 @@ def run_vllm( enforce_eager=enforce_eager, kv_cache_dtype=kv_cache_dtype, device=device, - enable_prefix_caching=enable_prefix_caching) + enable_prefix_caching=enable_prefix_caching, + use_flash_attn=use_flash_attn) # Add the requests to the engine. for prompt, _, output_len in requests: @@ -213,7 +215,8 @@ def main(args: argparse.Namespace): args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, args.max_model_len, args.enforce_eager, args.kv_cache_dtype, args.device, - args.enable_prefix_caching, args.gpu_memory_utilization) + args.enable_prefix_caching, args.gpu_memory_utilization, + args.use_flash_attn) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -314,6 +317,10 @@ def main(args: argparse.Namespace): "--enable-prefix-caching", action='store_true', help="enable automatic prefix caching for vLLM backend.") + parser.add_argument( + "--use-flash-attn", + action="store_true", + help="Use flash attention (requires flash-attn >= 2.5.0).") args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/kernels/benchmark_attention.py b/benchmarks/kernels/benchmark_attention.py new file mode 100644 index 0000000000000..07fabb5028986 --- /dev/null +++ b/benchmarks/kernels/benchmark_attention.py @@ -0,0 +1,258 @@ +from typing import Optional +import argparse +import random +import time + +import numpy as np +import torch + +try: + from flash_attn import flash_attn_func, flash_attn_with_kvcache +except ImportError: + flash_attn_func, flash_attn_with_kvcache = None, None + +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + +from vllm._C import cache_ops +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random + +NUM_BLOCKS = 1024 + + +@torch.inference_mode() +def main( + version: str, + num_seqs: int, + context_len: int, + num_query_heads: int, + num_kv_heads: int, + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + seed: int, + do_profile: bool, + device: str = "cuda", + kv_cache_dtype: Optional[str] = None, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + 
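For context, a hypothetical end-to-end use of the new engine flag (mirroring how the benchmarks above forward `use_flash_attn` into `LLM`); the model name is illustrative, and flash-attn >= 2.5.0 plus a fp16/bf16 dtype are assumed:

```python
from vllm import LLM, SamplingParams

# `use_flash_attn` is forwarded through EngineArgs/ModelConfig by this change.
llm = LLM(model="facebook/opt-125m", dtype="half", use_flash_attn=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)
```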
use_flash_attn = version in ["flash-attn", "flash-attn-kvcache"] + if use_flash_attn: + if dtype not in [torch.half, torch.bfloat16 + ] or kv_cache_dtype != "auto": + raise ValueError( + "skip: flash-attn requires dtype and kv_cache_dtype to be half or bfloat16" + ) + + context_lens = [context_len for _ in range(num_seqs)] + max_context_len = max(context_lens) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=device) + zero_context_lens_tensor = torch.zeros_like(context_lens_tensor) + + scale = float(1.0 / (head_size**0.5)) + qkv = torch.empty(num_seqs, + max_context_len, + num_query_heads + 2 * num_kv_heads, + head_size, + dtype=dtype, + device=device) + qkv.uniform_(-scale, scale) + query, key, value = qkv.split( + [num_query_heads, num_kv_heads, num_kv_heads], dim=2) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, + dtype=torch.float, + device=device) + + # Create the block tables. + if use_flash_attn: + block_size = ((block_size + 256 - 1) // 256) * 256 + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables, slot_mapping = [], [] + for seq_idx in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + slot_mapping.append([]) + for i in range(context_lens[seq_idx]): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping[-1].append(slot) + for _ in range(max_context_len - context_lens[seq_idx]): + slot_mapping[-1].append(-1) + block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) + + # Create the KV cache. 
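A minimal, self-contained sketch of the block-size and slot-mapping arithmetic used above (function names are hypothetical; only the integer math from the benchmark is assumed). As the diff notes elsewhere, `flash_attn_with_kvcache` requires the paged block size to be a multiple of 256, hence the rounding:

```python
def round_up_block_size(block_size: int, multiple: int = 256) -> int:
    # e.g. 16 -> 256, 257 -> 512
    return ((block_size + multiple - 1) // multiple) * multiple

def slot_to_block(slot: int, block_size: int):
    # A flat slot index decomposes into (physical block number, offset inside the block).
    return slot // block_size, slot % block_size

assert round_up_block_size(16) == 256
assert round_up_block_size(257) == 512
assert slot_to_block(300, 256) == (1, 44)
```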
+ key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + use_flash_attn=use_flash_attn) + key_cache, value_cache = key_caches[0], value_caches[0] + + if version == "xformers": + attn_bias = BlockDiagonalCausalMask.from_seqlens(context_lens) + if num_queries_per_kv > 1: + # Handle MQA and GQA + key_repeated = torch.repeat_interleave(key, + num_queries_per_kv, + dim=2) + value_repeated = torch.repeat_interleave(value, + num_queries_per_kv, + dim=2) + else: + key_repeated = key + value_repeated = value + + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() + if profile: + torch.cuda.cudart().cudaProfilerStart() + start_time = time.perf_counter() + + for _ in range(num_iters): + if version == "xformers": + cache_ops.reshape_and_cache( + key.reshape(-1, *key.shape[2:]), + value.reshape(-1, *key.shape[2:]), + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + ) + output = xops.memory_efficient_attention_forward( + query.reshape(1, -1, *query.shape[2:]), + key_repeated.reshape(1, -1, *key_repeated.shape[2:]), + value_repeated.reshape(1, -1, *value_repeated.shape[2:]), + attn_bias=attn_bias, + p=0.0, + scale=scale, + ) + output = output.reshape(query.shape) + elif version == "flash-attn": + flat_slot_mapping = slot_mapping.flatten() + slot_block_index = flat_slot_mapping // block_size + slot_block_offset = flat_slot_mapping % block_size + key_cache[slot_block_index, + slot_block_offset, :, :] = key.reshape( + -1, *key.shape[2:]) + value_cache[slot_block_index, + slot_block_offset, :, :] = value.reshape( + -1, *key.shape[2:]) + output = flash_attn_func( + q=query, + k=key, + v=value, + softmax_scale=scale, + causal=True, + alibi_slopes=alibi_slopes, + ) + elif version == "flash-attn-kvcache": + output = flash_attn_with_kvcache( + q=query, + k_cache=key_cache, + v_cache=value_cache, + k=key, + v=value, + cache_seqlens=zero_context_lens_tensor, + block_table=block_tables, + softmax_scale=scale, + causal=True, + alibi_slopes=alibi_slopes, + ) + else: + raise ValueError(f"Invalid version: {version}") + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStart() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_cuda_benchmark + run_benchmark(num_iters=3, profile=False) + + # Benchmark. 
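A small sketch of the MQA/GQA head expansion used in the xformers path above: key/value heads are repeated with `torch.repeat_interleave` so that every query head has a matching KV head (shapes are illustrative):

```python
import torch

num_query_heads, num_kv_heads, head_size, num_tokens = 8, 2, 4, 10
num_queries_per_kv = num_query_heads // num_kv_heads  # 4

key = torch.randn(num_tokens, num_kv_heads, head_size)
key_repeated = torch.repeat_interleave(key, num_queries_per_kv, dim=1)

assert key_repeated.shape == (num_tokens, num_query_heads, head_size)
# Each KV head is duplicated contiguously: heads [0,0,0,0,1,1,1,1] serve query heads 0..7.
assert torch.equal(key_repeated[:, 0], key_repeated[:, 3])
```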
+ if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=100, profile=False) + print( + f"Version: {version}, Context Length: {context_len}, Batch size: {num_seqs}, Kernel running time: {latency * 1000000:.3f} us" + ) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description="Benchmark the paged attention kernel.") + parser.add_argument( + "--version", + type=str, + choices=["xformers", "flash-attn", "flash-attn-kvcache"], + default="xformers") + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--context-len", type=int, default=4096) + parser.add_argument("--num-query-heads", type=int, default=64) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument("--head-size", + type=int, + choices=[64, 80, 96, 112, 128, 256], + default=128) + parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) + parser.add_argument("--use-alibi", action="store_true") + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8_e5m2"], + default="auto", + help= + 'Data type for kv cache storage. If "auto", will use model data type.') + parser.add_argument("--device", type=str, choices=["cuda"], default="cuda") + args = parser.parse_args() + print(args) + + if args.num_query_heads % args.num_kv_heads != 0: + raise ValueError("num_query_heads must be divisible by num_kv_heads") + main( + version=args.version, + num_seqs=args.batch_size, + context_len=args.context_len, + num_query_heads=args.num_query_heads, + num_kv_heads=args.num_kv_heads, + head_size=args.head_size, + block_size=args.block_size, + use_alibi=args.use_alibi, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + kv_cache_dtype=args.kv_cache_dtype, + ) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index d921dea1220e1..dc30d8a1dcc88 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -5,6 +5,11 @@ import torch +try: + from flash_attn import flash_attn_with_kvcache +except ImportError: + flash_attn_with_kvcache = None + from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm._C import ops @@ -33,6 +38,14 @@ def main( if torch.cuda.is_available(): torch.cuda.manual_seed(seed) + use_flash_attn = version == "flash-attn" + if use_flash_attn: + if dtype not in [torch.half, torch.bfloat16 + ] or kv_cache_dtype != "auto": + raise ValueError( + "skip: flash-attn requires dtype and kv_cache_dtype to be half or bfloat16" + ) + scale = float(1.0 / (head_size**0.5)) query = torch.empty(num_seqs, num_query_heads, @@ -53,6 +66,8 @@ def main( context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) # Create the block tables. + if use_flash_attn: + block_size = ((block_size + 256 - 1) // 256) * 256 max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): @@ -64,14 +79,16 @@ def main( block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) # Create the KV cache. 
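For reference, a shape-only sketch of the tensors that the flash-attn decode path above wires into `flash_attn_with_kvcache` (no kernel launch here; sizes are small illustrative values, the block size is already rounded to a multiple of 256, and the dtype is restricted to fp16/bf16):

```python
import torch

num_seqs, num_heads, num_kv_heads, head_size = 4, 8, 2, 64
num_blocks, block_size, max_blocks_per_seq = 8, 256, 4

q = torch.empty(num_seqs, 1, num_heads, head_size, dtype=torch.half)      # one query token per sequence
k_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size, dtype=torch.half)
v_cache = torch.empty_like(k_cache)
block_table = torch.zeros(num_seqs, max_blocks_per_seq, dtype=torch.int)  # logical -> physical block mapping
cache_seqlens = torch.full((num_seqs,), 100, dtype=torch.int)             # tokens already stored per sequence

# On a CUDA device with flash-attn installed, the call mirrors the benchmark above:
# out = flash_attn_with_kvcache(q=q, k_cache=k_cache, v_cache=v_cache,
#                               cache_seqlens=cache_seqlens, block_table=block_table,
#                               softmax_scale=head_size ** -0.5, causal=True)
# out.shape == (num_seqs, 1, num_heads, head_size)
```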
- key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, - block_size, - 1, - num_kv_heads, - head_size, - kv_cache_dtype, - dtype, - device=device) + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + use_flash_attn=use_flash_attn) key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. @@ -131,6 +148,17 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: alibi_slopes, kv_cache_dtype, ) + elif version == "flash-attn": + flash_attn_with_kvcache( + q=query.reshape(num_seqs, -1, *query.shape[1:]), + k_cache=key_cache, + v_cache=value_cache, + cache_seqlens=context_lens, + block_table=block_tables, + softmax_scale=scale, + causal=True, + alibi_slopes=alibi_slopes, + ) else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -158,7 +186,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: description="Benchmark the paged attention kernel.") parser.add_argument("--version", type=str, - choices=["v1", "v2"], + choices=["v1", "v2", "flash-attn"], default="v2") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--context-len", type=int, default=4096) diff --git a/requirements.txt b/requirements.txt index d6c33ad85da58..ce820b1770b62 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ pynvml == 11.5.0 triton >= 2.1.0 outlines == 0.0.34 cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. +flash-attn >= 2.5.0 diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index fb571de63d4e1..add9c992939f9 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -6,6 +6,11 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache +except ImportError: + flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache = None, None, None + from vllm._C import ops, cache_ops from vllm.utils import get_max_shared_memory_bytes from vllm.utils import is_hip @@ -111,7 +116,7 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -@pytest.mark.parametrize("version", ["v1", "v2"]) +@pytest.mark.parametrize("version", ["v1", "v2", "flash-attn"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -144,6 +149,16 @@ def test_paged_attention( query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) query.uniform_(-scale, scale) + use_flash_attn = version == "flash-attn" + if use_flash_attn: + if dtype not in [torch.half, torch.bfloat16 + ] or kv_cache_dtype != "auto": + pytest.skip( + "flash-attn requires dtype and kv_cache_dtype to be half or bfloat16" + ) + if head_size >= 128: + pytest.skip("flash-attn tests may OOM due to larger block size") + assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads alibi_slopes = None @@ -156,6 +171,8 @@ def test_paged_attention( context_lens = torch.tensor(context_lens, dtype=torch.int) # Create the block tables. 
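The try/except import guard added above keeps the test module importable when flash-attn is absent; one equivalent way to express the skip (the diff itself calls `pytest.skip` inside the test body, so the decorator form here is only an illustration):

```python
import pytest

try:
    from flash_attn import flash_attn_with_kvcache
except ImportError:
    flash_attn_with_kvcache = None

@pytest.mark.skipif(flash_attn_with_kvcache is None,
                    reason="flash-attn >= 2.5.0 is not installed")
def test_flash_attn_available():
    assert callable(flash_attn_with_kvcache)
```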
+ if use_flash_attn: + block_size = ((block_size + 256 - 1) // 256) * 256 max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): @@ -170,7 +187,7 @@ def test_paged_attention( key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype, dtype, seed, - device) + device, use_flash_attn) key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. @@ -221,13 +238,30 @@ def test_paged_attention( alibi_slopes, kv_cache_dtype, ) + elif version == "flash-attn": + output = flash_attn_with_kvcache( + q=query.reshape(num_seqs, -1, *query.shape[1:]), + k_cache=key_cache, + v_cache=value_cache, + cache_seqlens=context_lens, + block_table=block_tables, + softmax_scale=scale, + causal=True, + alibi_slopes=alibi_slopes, + ) + output = output.reshape_as(query) else: raise AssertionError(f"Unknown version: {version}") # Run the reference implementation. + x = 16 // torch.tensor([], dtype=dtype).element_size() + if use_flash_attn: + key_cache = key_cache.unflatten(-1, (head_size // x, x)).permute( + 0, 2, 4, 1, 3) + value_cache = value_cache.permute(0, 2, 3, 1) + if kv_cache_dtype == "fp8_e5m2": # Convert cache data back to dtype. - x = 16 // torch.tensor([], dtype=dtype).element_size() key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) dequantized_key_cache = torch.empty(size=key_cache_shape, @@ -266,7 +300,9 @@ def test_paged_attention( # so we use a relaxed tolerance for the test. if kv_cache_dtype == "fp8_e5m2": atol, rtol = 1e-2, 1e-5 - assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) + if use_flash_attn and use_alibi: + atol, rtol = 2e-1, 5e-2 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) def ref_multi_query_kv_attention( @@ -303,6 +339,8 @@ def ref_multi_query_kv_attention( # TODO(woosuk): Add tests for USE_ALIBI=True. 
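The assertions above move from `assert torch.allclose(...)` to `torch.testing.assert_close(...)`, which reports the mismatch count and the greatest absolute/relative difference instead of a bare `AssertionError`. A tiny illustration of the tolerance semantics (when `atol` is given, `rtol` must be given as well):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
torch.testing.assert_close(a, a + 1e-4, atol=1e-3, rtol=1e-5)   # within tolerance, passes
try:
    torch.testing.assert_close(a, a + 1.0, atol=1e-3, rtol=1e-5)
except AssertionError as e:
    print(e)  # detailed report of the mismatched elements
```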
+@pytest.mark.parametrize("version", + ["xformers", "flash-attn", "flash-attn-varlen"]) @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -311,6 +349,7 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_multi_query_kv_attention( + version: str, num_seqs: int, num_heads: Tuple[int, int], head_size: int, @@ -329,6 +368,17 @@ def test_multi_query_kv_attention( max_len = min(MAX_SEQ_LEN, 4096) seq_lens = random.sample(range(1, max_len), num_seqs) num_tokens = sum(seq_lens) + max_seq_len = max(seq_lens) + + use_flash_attn = version in ["flash-attn", "flash-attn-varlen"] + if use_flash_attn and dtype not in [torch.half, torch.bfloat16]: + pytest.skip( + "flash-attn requires kv_cache_dtype to be half or bfloat16") + + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device=device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -343,30 +393,81 @@ def test_multi_query_kv_attention( num_queries_per_kv = num_query_heads // num_kv_heads if num_queries_per_kv > 1: # Handle MQA and GQA - key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) - value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) - output = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - ) - output = output.squeeze(0) + key_repeated = torch.repeat_interleave(key, num_queries_per_kv, dim=1) + value_repeated = torch.repeat_interleave(value, + num_queries_per_kv, + dim=1) + else: + key_repeated = key + value_repeated = value + + if version == "xformers": + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key_repeated.unsqueeze(0), + value_repeated.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + ) + output = output.squeeze(0) + elif version == "flash-attn": + # padding the inputs, use the same logic with batched prefill + # in attention.py. + qs, ks, vs = [], [], [] + for i, seq_len in enumerate(seq_lens): + left, right = cu_seq_lens[i], cu_seq_lens[i + 1] + qs.append( + torch.nn.functional.pad( + query[left:right], (0, 0, 0, 0, 0, max_seq_len - seq_len))) + ks.append( + torch.nn.functional.pad( + key[left:right], (0, 0, 0, 0, 0, max_seq_len - seq_len))) + vs.append( + torch.nn.functional.pad( + value[left:right], (0, 0, 0, 0, 0, max_seq_len - seq_len))) + query_padded = torch.stack(qs, dim=0) + key_padded = torch.stack(ks, dim=0) + value_padded = torch.stack(vs, dim=0) + + output = flash_attn_func( + query_padded, + key_padded, + value_padded, + softmax_scale=scale, + causal=True, + ) + outputs = [] + for i, seq_len in enumerate(seq_lens): + outputs.append(output[i, :seq_len]) + output = torch.cat(outputs, dim=0) + elif version == "flash-attn-varlen": + # We test `flash_attn_varlen_func` here (which is more equalivant to + # xformers's MEAF kernel), but it is not actually used in attention.py + # for prefilling as inputs are padded in vLLM. 
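As background for the varlen call that follows, a minimal sketch of the packed "cu_seqlens" layout it consumes: sequences are concatenated along the token dimension and `cu_seq_lens[i]:cu_seq_lens[i + 1]` delimits sequence `i` (same loop as in the test setup above; numbers are illustrative):

```python
seq_lens = [3, 5, 2]
cu_seq_lens = [0]
for seq_len in seq_lens:
    cu_seq_lens.append(cu_seq_lens[-1] + seq_len)

assert cu_seq_lens == [0, 3, 8, 10]
assert cu_seq_lens[-1] == sum(seq_lens)   # total number of packed tokens
# Tokens of sequence 1 live at indices cu_seq_lens[1]:cu_seq_lens[2] == 3:8.
```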
+ output = flash_attn_varlen_func( + query, + key, + value, + cu_seqlens_q=cu_seq_lens, + cu_seqlens_k=cu_seq_lens, + max_seqlen_q=max_seq_len, + max_seqlen_k=max_seq_len, + softmax_scale=scale, + causal=True, + ) + else: + raise AssertionError(f"Unknown version: {version}") - cu_seq_lens = [0] - for seq_len in seq_lens: - cu_seq_lens.append(cu_seq_lens[-1] + seq_len) ref_output = ref_multi_query_kv_attention( cu_seq_lens, query, - key, - value, + key_repeated, + value_repeated, scale, dtype, ) atol = get_default_atol(output) if is_hip() else 1e-3 rtol = get_default_rtol(output) if is_hip() else 1e-5 - assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) diff --git a/vllm/config.py b/vllm/config.py index f792e89095246..2fd5de726867c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -7,6 +7,11 @@ import torch from transformers import PretrainedConfig +try: + import flash_attn +except ImportError: + flash_attn = None + from vllm.logger import init_logger from vllm.transformers_utils.config import get_config from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version @@ -84,6 +89,7 @@ def __init__( enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, max_logprobs: int = 5, + use_flash_attn: Optional[bool] = False, ) -> None: self.model = model self.tokenizer = tokenizer @@ -99,6 +105,7 @@ def __init__( self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture self.max_logprobs = max_logprobs + self.use_flash_attn = use_flash_attn if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, @@ -124,6 +131,7 @@ def __init__( self._verify_tokenizer_mode() self._verify_quantization() self._verify_cuda_graph() + self._verify_flash_attn() def _verify_load_format(self) -> None: load_format = self.load_format.lower() @@ -213,6 +221,18 @@ def _verify_cuda_graph(self) -> None: self.max_context_len_to_capture = min(self.max_context_len_to_capture, self.max_model_len) + def _verify_flash_attn(self) -> None: + if flash_attn is None: + raise ValueError( + "flash-attn is not installed. Please install flash-attn>=2.5.0 to use " + "the flash attention kernel.") + if Version(flash_attn.__version__) < Version("2.5.0"): + raise ValueError( + "flash-attn >= 2.5.0 is required. 
Please upgrade flash-attn to " + "the latest version.") + if is_hip(): + raise ValueError("flash-attn cannot doesn't support ROCm.") + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3e146d2e6c0c4..88641fbbad45c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -52,6 +52,7 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' ray_workers_use_nsight: bool = False + use_flash_attn: Optional[bool] = False def __post_init__(self): if self.tokenizer is None: @@ -310,6 +311,10 @@ def add_cli_args( default=EngineArgs.device, choices=["auto", "cuda", "neuron"], help='Device type for vLLM execution.') + parser.add_argument( + '--use-flash-attn', + action='store_true', + help='Use flash attention (requires flash-attn >= 2.5.0).') return parser @classmethod @@ -324,6 +329,12 @@ def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, DeviceConfig, Optional[LoRAConfig]]: + + if self.use_flash_attn: + # flash-attn's flash_attn_with_kvcache requires block size must be + # a multiple of 256. + self.block_size = ((self.block_size + 256 - 1) // 256) * 256 + device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, @@ -331,7 +342,7 @@ def create_engine_configs( self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.enforce_eager, self.max_context_len_to_capture, - self.max_logprobs) + self.max_logprobs, self.use_flash_attn) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index ebba0ba0a261a..fac0540653148 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -27,6 +27,7 @@ class InputMetadata: block_tables: Optional[torch.Tensor] use_cuda_graph: bool kv_cache_dtype: str + use_flash_attn: bool = False def __post_init__(self): # will not appear in the __repr__ and __init__ diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 4b63b9eaf59a7..3813e3d46acb6 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -36,14 +36,16 @@ def __init__( super().__init__() if _use_flash_attn(): from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501 - self.backend = FlashAttentionBackend(num_heads, head_size, scale, - num_kv_heads, alibi_slopes, - sliding_window) + self.flash_attn_backend = FlashAttentionBackend( + num_heads, head_size, scale, + num_kv_heads, alibi_slopes, + sliding_window) else: from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501 - self.backend = XFormersBackend(num_heads, head_size, scale, - num_kv_heads, alibi_slopes, - sliding_window) + self.xformer_backend = XFormersBackend( + num_heads, head_size, scale, + num_kv_heads, alibi_slopes, + sliding_window) def forward( self, @@ -54,8 +56,15 @@ def forward( value_cache: Optional[torch.Tensor], input_metadata: InputMetadata, ) -> torch.Tensor: - return self.backend.forward(query, key, value, key_cache, value_cache, - input_metadata) + if input_metadata.use_flash_attn and \ + self.flash_attn_backend is not None: + return 
self.flash_attn_backend.forward(query, key, value, + key_cache, value_cache, + input_metadata) + else: + return self.xformer_backend.forward(query, key, value, + key_cache, value_cache, + input_metadata) @lru_cache(maxsize=1) diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index 58ccd461b993e..d601f006c7a08 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -1,7 +1,7 @@ """Attention layer with Flash and PagedAttention.""" from typing import List, Optional -from flash_attn import flash_attn_func +from flash_attn import flash_attn_func, flash_attn_with_kvcache import torch from vllm.model_executor.input_metadata import InputMetadata @@ -74,17 +74,22 @@ def forward( # vectors will not be cached. This happens during the initial memory # profiling run. if key_cache is not None and value_cache is not None: - PagedAttentionImpl.reshape_and_cache(key, value, key_cache, - value_cache, input_metadata) + # Update kv-cache using tensor indexing. We don't use the kernel + # `flash_attn_with_kvcache` for kv-cache updating as it submitted + # many small kernels for each key/value and is slow. + flatten_slot_mapping = input_metadata.slot_mapping.flatten() + slot_block_index = flatten_slot_mapping // key_cache.shape[1] + slot_block_offset = flatten_slot_mapping % key_cache.shape[1] + key_cache[slot_block_index, slot_block_offset, :, :] = key + value_cache[slot_block_index, slot_block_offset, :, :] = value if input_metadata.is_prompt: - # Prompt run. + # normal attention + query = query.unflatten(0, (batch_size, seq_len)) + key = key.unflatten(0, (batch_size, seq_len)) + value = value.unflatten(0, (batch_size, seq_len)) if (key_cache is None or value_cache is None - or input_metadata.block_tables.numel() == 0): - # normal attention - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) + or not input_metadata.context_lens.any()): output = flash_attn_func( query, key, @@ -96,25 +101,52 @@ def forward( ) else: # prefix-enabled attention - output = PagedAttentionImpl.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - input_metadata, - self.alibi_slopes, + output = flash_attn_with_kvcache( + q=query, + k_cache=key_cache, + v_cache=value_cache, + cache_seqlens=input_metadata.context_lens + seq_len, + block_table=input_metadata.block_tables, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, ) else: # Decoding run. - output = PagedAttentionImpl.forward_decode( - query, - key_cache, - value_cache, - input_metadata, - self.num_kv_heads, - self.scale, - self.alibi_slopes, + + # NOTE: in `_prepare_prompt` and `_prepare_decode` is filled in + # different manner (which may needs to be fixed in the future): in + # the former `context_lens` is the length of contexts whose kv-cache + # has been stored in previous rounds, (e.g., with prefix cache). + # However, in the later `context_lens` is the length of current + # attention context (includes the token whose kv-cache will be + # computed and filled in this round). + # + # - The kernel `flash_attn_with_kvcache` expects `cache_seqlens` to + # be the length of the context whose kv-cache has been stored. 
+ # - The kernel `context_attention_fwd` expects it to be the length + # of already computed query-key-values in previous rounds. + # - The kernel `paged_attention_v1/v2` expect it to be the length of + # current attention context., same as flash-attn (without k & v). + # + # The flash-attn kernel can also be used with the k/v in current + # round as argument as they will be stored into the key-value cache + # inside the kernel. In which case, the `cache_seqlens` is expected + # to be the context length of tokens whose k/v has already been + # stored into kv-cache. However, it is found inefficient (especially + # for prompting) due to too many calls of cudaMemcpy kernels. + + # see also: https://github.com/Dao-AILab/flash-attention/commit/54e80a3829c6d2337570d01e78ebd9529c02d342 + output = flash_attn_with_kvcache( + q=query.reshape(batch_size, -1, *query.shape[1:]), + k_cache=key_cache, + v_cache=value_cache, + cache_seqlens=input_metadata.context_lens, + block_table=input_metadata.block_tables, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, ) # Reshape the output tensor. diff --git a/vllm/utils.py b/vllm/utils.py index d4a8c962c3bfc..98412e93e206b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -274,6 +274,7 @@ def create_kv_caches_with_random( model_dtype: Optional[Union[str, torch.dtype]] = None, seed: Optional[int] = 0, device: Optional[str] = "cuda", + use_flash_attn: Optional[bool] = False, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -299,8 +300,12 @@ def create_kv_caches_with_random( raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") scale = head_size**-0.5 - x = 16 // torch.tensor([], dtype=torch_dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + if use_flash_attn: + key_cache_shape = (num_blocks, block_size, num_heads, head_size) + else: + x = 16 // torch.tensor([], dtype=torch_dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, + x) key_caches = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, @@ -315,7 +320,10 @@ def create_kv_caches_with_random( f"Does not support key cache of type {cache_dtype}") key_caches.append(key_cache) - value_cache_shape = (num_blocks, num_heads, head_size, block_size) + if use_flash_attn: + value_cache_shape = (num_blocks, block_size, num_heads, head_size) + else: + value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_caches = [] for _ in range(num_layers): value_cache = torch.empty(size=value_cache_shape, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 880299783935c..b8a863431fdd4 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -34,6 +34,17 @@ def __init__( self.num_layers = model_config.get_num_layers(parallel_config) self.num_heads = model_config.get_num_kv_heads(parallel_config) + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + if model_config.use_flash_attn and self.dtype not in [ + torch.half, torch.bfloat16 + ]: + raise ValueError( + "flash-attn requires cache_dtype to be half or bfloat16") + self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks self.num_cpu_blocks = cache_config.num_cpu_blocks @@ -42,11 +53,6 @@ def __init__( if is_neuron(): return - if cache_config.cache_dtype == "auto": - self.dtype = 
model_config.dtype - else: - self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - # Initialize the cache. self.gpu_cache = self.allocate_gpu_cache() self.cpu_cache = self.allocate_cpu_cache() @@ -58,21 +64,35 @@ def __init__( self.events = [torch.cuda.Event() for _ in range(self.num_layers)] def get_key_block_shape(self) -> Tuple[int, int, int, int]: - element_size = torch.tensor([], dtype=self.dtype).element_size() - x = 16 // element_size - return ( - self.num_heads, - self.head_size // x, - self.block_size, - x, - ) + if self.model_config.use_flash_attn: + return ( + self.block_size, + self.num_heads, + self.head_size, + ) + else: + element_size = torch.tensor([], dtype=self.dtype).element_size() + x = 16 // element_size + return ( + self.num_heads, + self.head_size // x, + self.block_size, + x, + ) def get_value_block_shape(self) -> Tuple[int, int, int]: - return ( - self.num_heads, - self.head_size, - self.block_size, - ) + if self.model_config.use_flash_attn: + return ( + self.block_size, + self.num_heads, + self.head_size, + ) + else: + return ( + self.num_heads, + self.head_size, + self.block_size, + ) def allocate_gpu_cache(self) -> List[KVCache]: gpu_cache: List[KVCache] = [] @@ -159,6 +179,8 @@ def get_cache_block_size( model_config: ModelConfig, parallel_config: ParallelConfig, ) -> int: + ''' Returns the nbytes of kv cache for a single token. + ''' head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) num_layers = model_config.get_num_layers(parallel_config) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1ef783da6d08e..c6912b4ca022c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -149,18 +149,28 @@ def _prepare_prompt( prompt_tokens = seq_data.get_token_ids() prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) + computed_len = 0 # NOTE: This only works for oooooooxxx style attention. 
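A shape-only sketch contrasting the two per-block KV-cache layouts handled by `CacheEngine` above: flash-attn expects `(block_size, num_heads, head_size)` blocks, while the paged-attention kernels keep the x-packed layout. Both hold the same number of elements per block (sizes are illustrative):

```python
import torch

num_heads, head_size, block_size = 8, 128, 256
element_size = torch.tensor([], dtype=torch.half).element_size()  # 2 bytes for fp16
x = 16 // element_size                                            # 8

flash_key_block = torch.Size((block_size, num_heads, head_size))
paged_key_block = torch.Size((num_heads, head_size // x, block_size, x))

assert flash_key_block.numel() == paged_key_block.numel() == 262144
```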
computed_block_nums = seq_group_metadata.computed_block_nums - if computed_block_nums is not None and len( - computed_block_nums) > 0 and self.sliding_window is None: - # Prefix is not supported with sliding_window - computed_len = len(computed_block_nums) * self.block_size - prompt_tokens = prompt_tokens[computed_len:] - prefix_block_tables.append(computed_block_nums) + if computed_block_nums is not None: + if len(computed_block_nums) > 0 and self.sliding_window is None: + # Prefix is not supported with sliding_window + computed_len = len(computed_block_nums) * self.block_size + prompt_tokens = prompt_tokens[computed_len:] + current_block_tables = computed_block_nums else: - prefix_block_tables.append([]) + current_block_tables = [] + + # append seq groups's block table as the key-value cache + # will be updated (cached) by the flash-attn kernels + if seq_group_metadata.block_tables: + current_block_tables.extend( + seq_group_metadata.block_tables[seq_id] + [len(current_block_tables):]) + prefix_block_tables.append(current_block_tables) + # actual prompt lens context_lens.append(computed_len) subquery_lens.append(prompt_len - computed_len) @@ -265,6 +275,7 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, + use_flash_attn=getattr(self.model_config, 'use_flash_attn', False), ) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, @@ -393,6 +404,7 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, + use_flash_attn=self.model_config.use_flash_attn, ) return (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -720,6 +732,8 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: block_tables=block_tables[:batch_size], use_cuda_graph=True, kv_cache_dtype=self.kv_cache_dtype, + use_flash_attn=getattr(self.model_config, 'use_flash_attn', + False), ) if self.lora_config:
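Finally, a small sketch of the prefix-caching bookkeeping in `_prepare_prompt` above: when the first `len(computed_block_nums)` blocks of a prompt are already cached, only the remaining suffix is prefilled, and the sequence's full block table is passed along so the newly computed KV entries can be written into the cache (numbers are illustrative):

```python
block_size = 256                       # flash-attn block size after rounding
prompt_len = 600
computed_block_nums = [7, 42]          # two blocks already cached for this prompt

computed_len = len(computed_block_nums) * block_size   # 512 tokens reusable from the cache
subquery_len = prompt_len - computed_len               # 88 tokens still to prefill

assert (computed_len, subquery_len) == (512, 88)
```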