From 5c6c40fdd9c714308abc12823214ecb35da3376d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 8 Mar 2024 02:34:06 +0000 Subject: [PATCH] Revert "Separate attention backends (#3005)" This reverts commit 2daf23ab0cf00da157b1255faddcf0a269283d36. --- .gitignore | 3 - setup.py | 48 +--- tests/kernels/test_prefix_prefill.py | 2 +- vllm/__init__.py | 30 +-- .../backends/xformers.py => attention.py} | 216 +++++++++++++----- .../layers/attention/__init__.py | 5 - .../layers/attention/attention.py | 59 ----- .../layers/attention/backends/flash_attn.py | 124 ---------- .../layers/attention/ops/__init__.py | 0 .../layers/attention/ops/paged_attn.py | 138 ----------- .../backends => triton_kernel}/__init__.py | 0 .../ops => triton_kernel}/prefix_prefill.py | 0 vllm/model_executor/models/baichuan.py | 13 +- vllm/model_executor/models/bloom.py | 10 +- vllm/model_executor/models/chatglm.py | 4 +- vllm/model_executor/models/deepseek.py | 10 +- vllm/model_executor/models/falcon.py | 28 +-- vllm/model_executor/models/gemma.py | 10 +- vllm/model_executor/models/gpt2.py | 6 +- vllm/model_executor/models/gpt_bigcode.py | 10 +- vllm/model_executor/models/gpt_j.py | 4 +- vllm/model_executor/models/gpt_neox.py | 4 +- vllm/model_executor/models/internlm2.py | 10 +- vllm/model_executor/models/llama.py | 12 +- vllm/model_executor/models/mixtral.py | 4 +- vllm/model_executor/models/mixtral_quant.py | 4 +- vllm/model_executor/models/mpt.py | 12 +- vllm/model_executor/models/olmo.py | 8 +- vllm/model_executor/models/opt.py | 8 +- vllm/model_executor/models/orion.py | 10 +- vllm/model_executor/models/phi.py | 4 +- vllm/model_executor/models/qwen.py | 4 +- vllm/model_executor/models/qwen2.py | 12 +- vllm/model_executor/models/stablelm.py | 10 +- vllm/model_executor/models/starcoder2.py | 4 +- 35 files changed, 268 insertions(+), 558 deletions(-) rename vllm/model_executor/layers/{attention/backends/xformers.py => attention.py} (56%) delete mode 100644 vllm/model_executor/layers/attention/__init__.py delete mode 100644 vllm/model_executor/layers/attention/attention.py delete mode 100644 vllm/model_executor/layers/attention/backends/flash_attn.py delete mode 100644 vllm/model_executor/layers/attention/ops/__init__.py delete mode 100644 vllm/model_executor/layers/attention/ops/paged_attn.py rename vllm/model_executor/layers/{attention/backends => triton_kernel}/__init__.py (100%) rename vllm/model_executor/layers/{attention/ops => triton_kernel}/prefix_prefill.py (100%) diff --git a/.gitignore b/.gitignore index 0b14c98270c41..b5195629e5cf3 100644 --- a/.gitignore +++ b/.gitignore @@ -184,6 +184,3 @@ _build/ # Benchmark dataset *.json - -# Third-party Python packages. -vllm/thirdparty_files/ diff --git a/setup.py b/setup.py index 57d7a139e8237..745b5a9b2d02a 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,6 @@ import os import re import subprocess -import sys import warnings from pathlib import Path from typing import List, Set @@ -15,8 +14,6 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME ROOT_DIR = os.path.dirname(__file__) -# This is a temporary directory to store third-party packages. -THIRDPARTY_SUBDIR = "vllm/thirdparty_files" # If you are developing the C++ backend of vLLM, consider building vLLM with # `python setup.py develop` since it will give you incremental builds. @@ -327,46 +324,8 @@ def get_torch_arch_list() -> Set[str]: "nvcc": NVCC_FLAGS_PUNICA, }, )) - - # Download the FlashAttention package. - # Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530 - flash_attn_version = "2.5.6" - install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR) - subprocess.check_call( - [ - sys.executable, - "-m", - "pip", - "install", - "-q", - f"--target={install_dir}", - "einops", # Dependency of flash-attn. - f"flash-attn=={flash_attn_version}", - "--no-dependencies", # Required to avoid re-installing torch. - ], - env=dict(os.environ, CC="gcc"), - ) - - # Copy the FlashAttention package into the vLLM package after build. - class build_ext(BuildExtension): - - def run(self): - super().run() - target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR) - if not os.path.exists(target_dir): - os.makedirs(target_dir) - self.copy_tree(install_dir, target_dir) - - class BinaryDistribution(setuptools.Distribution): - - def has_ext_modules(self): - return True - -else: - build_ext = BuildExtension - BinaryDistribution = setuptools.Distribution - if _is_neuron(): - neuronxcc_version = get_neuronxcc_version() +elif _is_neuron(): + neuronxcc_version = get_neuronxcc_version() vllm_extension_sources = [ "csrc/cache_kernels.cu", @@ -509,7 +468,6 @@ def get_requirements() -> List[str]: python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, - cmdclass={"build_ext": build_ext} if not _is_neuron() else {}, - distclass=BinaryDistribution, + cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {}, package_data=package_data, ) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index e881cd1ec3753..c068b38a66910 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -3,7 +3,7 @@ import time import torch -from vllm.model_executor.layers.attention.ops.prefix_prefill import ( +from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( context_attention_fwd) from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask diff --git a/vllm/__init__.py b/vllm/__init__.py index 59f1345b58d42..f1e30f5eb6e6e 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -1,28 +1,12 @@ """vLLM: a high-throughput and memory-efficient inference engine for LLMs""" - -# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11 -def _configure_system(): - import os - import sys - - # Importing flash-attn. - thirdparty_files = os.path.join(os.path.abspath(os.path.dirname(__file__)), - "thirdparty_files") - sys.path.insert(0, thirdparty_files) - - -_configure_system() -# Delete configuration function. -del _configure_system - -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402 -from vllm.engine.async_llm_engine import AsyncLLMEngine # noqa: E402 -from vllm.engine.llm_engine import LLMEngine # noqa: E402 -from vllm.engine.ray_utils import initialize_cluster # noqa: E402 -from vllm.entrypoints.llm import LLM # noqa: E402 -from vllm.outputs import CompletionOutput, RequestOutput # noqa: E402 -from vllm.sampling_params import SamplingParams # noqa: E402 +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.llm_engine import LLMEngine +from vllm.engine.ray_utils import initialize_cluster +from vllm.entrypoints.llm import LLM +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.sampling_params import SamplingParams __version__ = "0.3.3" diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention.py similarity index 56% rename from vllm/model_executor/layers/attention/backends/xformers.py rename to vllm/model_executor/layers/attention.py index bad2a648b6703..2a82325b80213 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention.py @@ -1,19 +1,37 @@ -"""Attention layer with xFormers and PagedAttention.""" -import importlib +"""Multi-head attention.""" from typing import List, Optional +import importlib import torch +import torch.nn as nn from xformers import ops as xops from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) +from vllm._C import ops +from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention.ops.paged_attn import ( - PagedAttentionImpl) +from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( + context_attention_fwd) from vllm.utils import is_hip +_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +class PagedAttention(nn.Module): + """MHA/MQA/GQA layer with PagedAttention. -class XFormersBackend: + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Reshape and store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention using either + xformers or the PagedAttention custom op. + 3. Return the output tensor. + """ def __init__( self, @@ -24,6 +42,7 @@ def __init__( alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, ) -> None: + super().__init__() self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -31,17 +50,48 @@ def __init__( self.sliding_window = sliding_window if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes + self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes() - if head_size not in suppored_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") - self.use_ref_attention = _check_use_ref_attention() + if self.head_size not in _SUPPORTED_HEAD_SIZES: + raise ValueError(f"head_size ({self.head_size}) is not supported. " + f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") + + self.use_ref_attention = self.check_use_ref_attention() + + def check_use_ref_attention(self) -> bool: + if not is_hip(): + return False + # For ROCm, check whether flash attention is installed or not. + # if not, use_ref_attention needs to be True + return importlib.util.find_spec("flash_attn") is None + + def ref_masked_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + seq_len, _, _ = query.shape + attn_mask = torch.triu(torch.ones(seq_len, + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) + attn_mask = attn_mask * torch.finfo(query.dtype).min + + attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query, + key).float() + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out def forward( self, @@ -52,7 +102,7 @@ def forward( value_cache: Optional[torch.Tensor], input_metadata: InputMetadata, ) -> torch.Tensor: - """Forward pass with xFormers and PagedAttention. + """PagedAttention forward pass. Args: query: shape = [batch_size, seq_len, num_heads * head_size] @@ -77,14 +127,19 @@ def forward( # vectors will not be cached. This happens during the initial memory # profiling run. if key_cache is not None and value_cache is not None: - PagedAttentionImpl.reshape_and_cache(key, value, key_cache, - value_cache, input_metadata) + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping.flatten(), + input_metadata.kv_cache_dtype, + ) if input_metadata.is_prompt: - # Prompt run. + # normal attention if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): - # normal attention if self.num_kv_heads != self.num_heads: # As of Nov 2023, xformers only supports MHA. For MQA/GQA, # project the key and value tensors to the desired number of @@ -120,19 +175,13 @@ def forward( seq_len, query.dtype) if self.use_ref_attention: - output = _ref_masked_attention( + output = self.ref_masked_attention( query, key, value, - self.num_heads, - self.num_kv_heads, - self.head_size, - self.scale, ) - # Using view got RuntimeError: view size is not compatible - # with input tensor's size and stride (at least one - # dimension spans across two contiguous subspaces). - # Use reshape instead. + # Using view got RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use reshape instead return output.reshape(batch_size, seq_len, hidden_size) # TODO(woosuk): Too many view operations. Let's try to reduce @@ -157,21 +206,27 @@ def forward( (is_hip()) else None, ) output = out.view_as(query) - else: # prefix-enabled attention - output = PagedAttentionImpl.forward_prefix( + output = torch.empty_like(query) + context_attention_fwd( query, key, value, + output, key_cache, value_cache, - input_metadata, - self.alibi_slopes, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.prompt_lens, + input_metadata.context_lens, + input_metadata.max_seq_len, + getattr(self, "alibi_slopes", None), ) + else: # Decoding run. - output = PagedAttentionImpl.forward_decode( + output = _paged_attention( query, key_cache, value_cache, @@ -219,37 +274,76 @@ def _make_alibi_bias( return attn_bias -def _check_use_ref_attention() -> bool: - if not is_hip(): - return False - # For ROCm, check whether flash attention is installed or not. - # if not, use_ref_attention needs to be True - return importlib.util.find_spec("flash_attn") is None - - -def _ref_masked_attention( +def _paged_attention( query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, num_kv_heads: int, - head_size: int, scale: float, + alibi_slopes: Optional[torch.Tensor], ) -> torch.Tensor: - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - - seq_len, _, _ = query.shape - attn_mask = torch.triu(torch.ones(seq_len, - seq_len, - dtype=query.dtype, - device=query.device), - diagonal=1) - attn_mask = attn_mask * torch.finfo(query.dtype).min - - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out + output = torch.empty_like(query) + + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ( + (input_metadata.max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + use_v1 = input_metadata.max_context_len <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512) + if use_v1: + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + input_metadata.kv_cache_dtype, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + input_metadata.kv_cache_dtype, + ) + return output diff --git a/vllm/model_executor/layers/attention/__init__.py b/vllm/model_executor/layers/attention/__init__.py deleted file mode 100644 index 1c42a3d28f976..0000000000000 --- a/vllm/model_executor/layers/attention/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from vllm.model_executor.layers.attention.attention import Attention - -__all__ = [ - "Attention", -] diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py deleted file mode 100644 index 830e82e10f7ad..0000000000000 --- a/vllm/model_executor/layers/attention/attention.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Attention layer.""" -from typing import List, Optional - -import torch -import torch.nn as nn - -from vllm.model_executor.input_metadata import InputMetadata -from vllm.utils import is_hip - - -class Attention(nn.Module): - """Attention layer. - - This class takes query, key, and value tensors as input. The input tensors - can either contain prompt tokens or generation tokens. - The class does the following: - - 1. Store the input key and value tensors in the KV cache. - 2. Perform (multi-head/multi-query/grouped-query) attention. - 3. Return the output tensor. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - ) -> None: - super().__init__() - if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and - torch.get_default_dtype() in (torch.float16, torch.bfloat16)): - # Ampere or later NVIDIA GPUs. - # NOTE(woosuk): FlashAttention does not support FP32. - from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend - self.backend = FlashAttentionBackend(num_heads, head_size, scale, - num_kv_heads, alibi_slopes, - sliding_window) - else: - # Turing and Volta NVIDIA GPUs or AMD GPUs. - # Or FP32 on any GPU. - from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend - self.backend = XFormersBackend(num_heads, head_size, scale, - num_kv_heads, alibi_slopes, - sliding_window) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: Optional[torch.Tensor], - value_cache: Optional[torch.Tensor], - input_metadata: InputMetadata, - ) -> torch.Tensor: - return self.backend.forward(query, key, value, key_cache, value_cache, - input_metadata) diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py deleted file mode 100644 index 512f4e49c7eb2..0000000000000 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ /dev/null @@ -1,124 +0,0 @@ -"""Attention layer with Flash and PagedAttention.""" -from typing import List, Optional - -# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/. -from flash_attn import flash_attn_func -import torch - -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention.ops.paged_attn import ( - PagedAttentionImpl) - - -class FlashAttentionBackend: - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes() - if head_size not in suppored_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") - - self.sliding_window = ((self.sliding_window, self.sliding_window) if - self.sliding_window is not None else (-1, -1)) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: Optional[torch.Tensor], - value_cache: Optional[torch.Tensor], - input_metadata: InputMetadata, - ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. - - Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] - key_cache: shape = [num_blocks, num_kv_heads, head_size/x, - block_size, x] - value_cache: shape = [num_blocks, num_kv_heads, head_size, - block_size] - input_metadata: metadata for the inputs. - Returns: - shape = [batch_size, seq_len, num_heads * head_size] - """ - batch_size, seq_len, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - # Reshape the keys and values and store them in the cache. - # If key_cache and value_cache are not provided, the new key and value - # vectors will not be cached. This happens during the initial memory - # profiling run. - if key_cache is not None and value_cache is not None: - PagedAttentionImpl.reshape_and_cache(key, value, key_cache, - value_cache, input_metadata) - - if input_metadata.is_prompt: - # Prompt run. - if (key_cache is None or value_cache is None - or input_metadata.block_tables.numel() == 0): - # normal attention - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - output = flash_attn_func( - query, - key, - value, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) - else: - # prefix-enabled attention - output = PagedAttentionImpl.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - input_metadata, - self.num_heads, - self.num_kv_heads, - self.alibi_slopes, - ) - else: - # Decoding run. - output = PagedAttentionImpl.forward_decode( - query, - key_cache, - value_cache, - input_metadata, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - ) - - # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/model_executor/layers/attention/ops/__init__.py b/vllm/model_executor/layers/attention/ops/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/model_executor/layers/attention/ops/paged_attn.py b/vllm/model_executor/layers/attention/ops/paged_attn.py deleted file mode 100644 index c5a9618c2395b..0000000000000 --- a/vllm/model_executor/layers/attention/ops/paged_attn.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import List, Optional - -import torch - -from vllm._C import cache_ops -from vllm._C import ops -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention.ops.prefix_prefill import ( - context_attention_fwd) - -# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. -_PARTITION_SIZE = 512 - - -class PagedAttentionImpl: - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 128, 256] - - @staticmethod - def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - input_metadata: InputMetadata, - ) -> None: - cache_ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - input_metadata.slot_mapping.flatten(), - input_metadata.kv_cache_dtype, - ) - - @staticmethod - def forward_decode( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - input_metadata: InputMetadata, - num_kv_heads: int, - scale: float, - alibi_slopes: Optional[torch.Tensor], - ) -> torch.Tensor: - output = torch.empty_like(query) - - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - max_num_partitions = ( - (input_metadata.max_context_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - # TODO(woosuk): Tune this heuristic. - # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = input_metadata.max_context_len <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - # Run PagedAttention V1. - ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - alibi_slopes, - input_metadata.kv_cache_dtype, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=output.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - alibi_slopes, - input_metadata.kv_cache_dtype, - ) - return output - - @staticmethod - def forward_prefix( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - input_metadata: InputMetadata, - alibi_slopes: Optional[torch.Tensor], - ) -> torch.Tensor: - output = torch.empty_like(query) - context_attention_fwd( - query, - key, - value, - output, - key_cache, - value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] - input_metadata.start_loc, - input_metadata.prompt_lens, - input_metadata.context_lens, - input_metadata.max_seq_len, - alibi_slopes, - ) - return output diff --git a/vllm/model_executor/layers/attention/backends/__init__.py b/vllm/model_executor/layers/triton_kernel/__init__.py similarity index 100% rename from vllm/model_executor/layers/attention/backends/__init__.py rename to vllm/model_executor/layers/triton_kernel/__init__.py diff --git a/vllm/model_executor/layers/attention/ops/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py similarity index 100% rename from vllm/model_executor/layers/attention/ops/prefix_prefill.py rename to vllm/model_executor/layers/triton_kernel/prefix_prefill.py diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 6da0082b94285..550dec6487f9e 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -27,7 +27,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -151,10 +151,10 @@ def __init__( alibi_slopes = alibi_slopes[head_start:head_end].tolist() scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes) else: self.rotary_emb = get_rope( self.head_dim, @@ -163,7 +163,8 @@ def __init__( base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = PagedAttention(self.num_heads, self.head_dim, + self.scaling) def forward( self, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 0548b2b140b1b..4adfb6b78102f 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -25,7 +25,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -107,10 +107,10 @@ def __init__( alibi_slopes = alibi_slopes[head_start:head_end].tolist() scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes) def forward( self, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 1c5dcfacaff2b..dca8d724f976b 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -10,7 +10,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -87,7 +87,7 @@ def __init__( base=10000 * rope_ratio, is_neox_style=False, ) - self.attn = Attention( + self.attn = PagedAttention( self.num_heads, self.head_dim, self.scaling, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index f2dca3df27cfb..6dba952736921 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -29,7 +29,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -229,10 +229,10 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 3c148be5b10f4..2b5e022312e3b 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -28,7 +28,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -150,10 +150,10 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads) elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -161,16 +161,16 @@ def __init__( alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * self.inv_norm_factor) alibi_slopes = alibi_slopes[head_start:head_end].tolist() - self.attn = Attention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - alibi_slopes=alibi_slopes) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + alibi_slopes=alibi_slopes) else: - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.inv_norm_factor, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 386a36cf492d6..bf1f164ff700d 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -23,7 +23,7 @@ from vllm.config import LoRAConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import GeluAndMul -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -123,10 +123,10 @@ def __init__(self, base=self.rope_theta, is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 3f7b21e5a4133..661da0fe0434e 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -25,7 +25,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -73,7 +73,9 @@ def __init__( bias=True, linear_method=linear_method, ) - self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.scale) def forward( self, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 5c30d47d93e36..ef4c1d4143c88 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -26,7 +26,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -85,10 +85,10 @@ def __init__( bias=True, linear_method=linear_method, ) - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scale, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.scale, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index b8c6822e9825e..5bab30d9d442e 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -24,7 +24,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -86,7 +86,7 @@ def __init__( base=rope_theta, is_neox_style=False, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = PagedAttention(self.num_heads, self.head_size, scaling) def forward( self, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 98107350e60b9..8f7e1063e0c1d 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -24,7 +24,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -87,7 +87,7 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = PagedAttention(self.num_heads, self.head_size, scaling) def forward( self, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 0ae0a85643456..ebf1d8a89a022 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -7,7 +7,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -114,10 +114,10 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 4c163dfdab537..d35887cc0f6a3 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -30,7 +30,7 @@ from vllm.config import LoRAConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -139,11 +139,11 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=sliding_window) def forward( self, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index d47834e519697..0100624a44d78 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -29,7 +29,7 @@ from vllm.config import LoRAConfig from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -197,7 +197,7 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( + self.attn = PagedAttention( self.num_heads, self.head_dim, self.scaling, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 25c7f1978c0dc..a8dadce24aa1d 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -32,7 +32,7 @@ from transformers import MixtralConfig from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, ReplicatedLinear, @@ -214,7 +214,7 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( + self.attn = PagedAttention( self.num_heads, self.head_dim, self.scaling, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 16ecac3d0529a..22a876e2ef691 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -8,7 +8,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -105,11 +105,11 @@ def __init__( self.head_dim = self.d_model // self.total_num_heads scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index fa7a6d850051e..9d563039208c8 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -43,7 +43,7 @@ from torch import nn from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import ( ColumnParallelLinear, LinearMethodBase, @@ -126,9 +126,9 @@ def __init__( base=rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scaling) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.scaling) # Attention output projection. self.attn_out = RowParallelLinear( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 782f43ce265bd..393b2dcabcd5a 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -25,7 +25,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -89,9 +89,9 @@ def __init__( bias=bias, linear_method=linear_method, ) - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scaling) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.scaling) def forward( self, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 6039b1cdc3534..0b067d4fc8802 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -12,7 +12,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, @@ -118,10 +118,10 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 039dc7a9b7675..d143261968288 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -43,7 +43,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -108,7 +108,7 @@ def __init__(self, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = PagedAttention(self.num_heads, self.head_size, scaling) def forward( self, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index d4d5a4e8bb9a5..37af84c7cd53f 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -12,7 +12,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -104,7 +104,7 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling) def forward( self, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 3586a7fb82778..e823e6f8c3dbe 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -30,7 +30,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -135,11 +135,11 @@ def __init__(self, max_position=max_position, base=self.rope_theta, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window) def forward( self, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index d1a547f815616..44c57e5a6d4f9 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -25,7 +25,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, @@ -122,10 +122,10 @@ def __init__(self, max_position=self.config.max_position_embeddings, base=self.config.rope_theta, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_key_value_heads) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_key_value_heads) def forward( self, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index efa235233372f..1eda07b724cae 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -25,7 +25,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -103,7 +103,7 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( + self.attn = PagedAttention( self.num_heads, self.head_dim, self.scaling,