Allocate more shared memory to attention kernel (vllm-project#1154)
Yard1 authored Sep 27, 2023
1 parent 03ffd0a commit cf5cb1e
Showing 7 changed files with 87 additions and 3 deletions.
5 changes: 5 additions & 0 deletions csrc/attention/attention_kernels.cu
@@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
} // namespace vllm

#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
cudaFuncSetAttribute( \
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
@@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_context_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// NOTE: the Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// mirrors this computation; keep the two in sync.
int shared_mem_size = std::max(logits_size, outputs_size);

dim3 grid(num_heads, num_seqs);
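
For context, the launcher sizes the kernel's dynamic shared memory as the larger of the logits buffer and the per-warp reduction buffer, and the new cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize, ...) call opts the kernel into more than the 48 KB default so longer contexts still fit. A minimal Python sketch of that sizing arithmetic (not part of this commit; the input values are hypothetical):

    # Hypothetical launch parameters, mirroring the arithmetic in the C++ launcher above.
    max_context_len = 16384
    block_size = 16
    head_size = 128
    num_threads = 128

    num_warps = num_threads // 32                                        # 4
    float_bytes = 4
    padded_max_context_len = (
        (max_context_len + block_size - 1) // block_size) * block_size   # 16384
    logits_size = padded_max_context_len * float_bytes                   # 65536 bytes (64 KB)
    outputs_size = (num_warps // 2) * head_size * float_bytes            # 1024 bytes
    shared_mem_size = max(logits_size, outputs_size)                     # 64 KB > 48 KB default, so the opt-in is needed
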
13 changes: 13 additions & 0 deletions csrc/cuda_utils.cpp
@@ -0,0 +1,13 @@
#include <torch/extension.h>

int get_device_attribute(
int attribute,
int device_id);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"get_device_attribute",
&get_device_attribute,
"Gets the specified device attribute.");
}

14 changes: 14 additions & 0 deletions csrc/cuda_utils_kernels.cu
@@ -0,0 +1,14 @@
int get_device_attribute(
int attribute,
int device_id)
{
int device, value;
if (device_id < 0) {
cudaGetDevice(&device);
}
else {
device = device_id;
}
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
return value;
}
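
Once the extension is built, the binding can be exercised directly from Python. A minimal sketch (not part of this commit), assuming a CUDA device is present:

    import torch  # import torch first so the extension's libtorch symbols resolve
    from vllm import cuda_utils

    # 97 == cudaDevAttrMaxSharedMemoryPerBlockOptin in the CUDA runtime API;
    # passing a negative device_id falls back to the current device.
    max_shared = cuda_utils.get_device_attribute(97, 0)
    print(f"opt-in shared memory per block: {max_shared} bytes")
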
11 changes: 11 additions & 0 deletions setup.py
@@ -195,6 +195,17 @@ def get_torch_arch_list() -> Set[str]:
)
ext_modules.append(quantization_extension)

# Misc. CUDA utils.
cuda_utils_extension = CUDAExtension(
name="vllm.cuda_utils",
sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(cuda_utils_extension)


def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)
8 changes: 7 additions & 1 deletion tests/kernels/test_attention.py
@@ -7,8 +7,12 @@
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm import attention_ops
from vllm.utils import get_max_shared_memory_bytes

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# MAX_SEQ_LEN depends on the available shared memory (and hence the compute
# capability); leave 512 elements as a safety buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 128 # Arbitrary values for testing

DTYPES = [torch.half, torch.bfloat16, torch.float]
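
A worked example of the new bound (not part of the diff), assuming a hypothetical device that reports 96 KB of opt-in shared memory per block:

    # 96 KB of opt-in shared memory per block (hypothetical device).
    max_shared_memory_bytes = 96 * 1024                              # 98304
    FLOAT32_BYTES = 4
    MAX_SEQ_LEN = max_shared_memory_bytes // FLOAT32_BYTES - 512     # 24576 - 512 = 24064
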
@@ -135,6 +139,7 @@ def test_single_query_cached_kv_attention(
device="cuda")

context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

@@ -243,6 +248,7 @@ def test_multi_query_kv_attention(
torch.cuda.manual_seed(seed)

seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
seq_lens[-1] = MAX_SEQ_LEN
num_tokens = sum(seq_lens)

scale = float(1.0 / (head_size**0.5))
13 changes: 12 additions & 1 deletion vllm/utils.py
@@ -1,10 +1,12 @@
import enum
import uuid
from platform import uname

import psutil
import torch

from vllm import cuda_utils


class Device(enum.Enum):
GPU = enum.auto()
@@ -25,6 +27,15 @@ def reset(self) -> None:
self.counter = 0


def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 # pylint: disable=invalid-name
max_shared_mem = cuda_utils.get_device_attribute(
cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
return int(max_shared_mem)


def get_gpu_memory(gpu: int = 0) -> int:
"""Returns the total memory of the GPU in bytes."""
return torch.cuda.get_device_properties(gpu).total_memory
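
Typical usage of the new helper (not part of this commit; the printed value is illustrative only):

    from vllm.utils import get_max_shared_memory_bytes

    # Queries cudaDevAttrMaxSharedMemoryPerBlockOptin for GPU 0 by default.
    print(get_max_shared_memory_bytes())       # e.g. 98304 on a device with 96 KB per block
    print(get_max_shared_memory_bytes(gpu=1))  # a specific device, if more than one is present
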
26 changes: 25 additions & 1 deletion vllm/worker/worker.py
@@ -13,7 +13,7 @@
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes


class Worker:
@@ -136,6 +140,10 @@ def profile_num_available_blocks(
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
self.block_size = cache_config.block_size

_check_if_can_support_max_seq_len(self.scheduler_config.max_model_len,
self.block_size)

self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.cache_events = self.cache_engine.events
@@ -347,3 +351,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:

def _pad_to_max(x: List[int], max_len: int) -> List[int]:
return x + [0] * (max_len - len(x))


def _check_if_can_support_max_seq_len(max_seq_len: int,
block_size: int) -> None:
# Follows the logic in
# attention_kernels.cu::single_query_cached_kv_attention_launcher
max_shared_mem = get_max_shared_memory_bytes()
float32_bytes = torch.finfo(torch.float).bits // 8
padded_max_seq_len = (
(max_seq_len + block_size - 1) // block_size) * block_size
# padded_max_seq_len + extra buffer
required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
if padded_max_seq_len * float32_bytes > max_shared_mem:
raise RuntimeError(
f"vLLM cannot currently support max_model_len={max_seq_len} "
f"with block_size={block_size} on GPU with compute "
f"capability {torch.cuda.get_device_capability()} "
f"(required shared memory {required_shared_mem} > "
f"available shared memory {max_shared_mem}). "
"This will be fixed in a future release.")
