[FIX] Fix kernel bug #1959

Merged (11 commits) on Jan 3, 2024
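The change repeated across the CUDA sources below is a device guard: each kernel launcher now constructs an `at::cuda::OptionalCUDAGuard` from the device of one of its input tensors (or from an explicitly chosen device) before querying the current CUDA stream, so the kernel is launched on the device that actually holds the tensors rather than on whatever device happens to be current. The snippet below is only a minimal sketch of that pattern; `scale_kernel` and `scale_inplace` are illustrative names, not functions from this PR.

```cuda
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Illustrative kernel, not part of the PR.
__global__ void scale_kernel(float* __restrict__ data, int n, float factor) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= factor;
  }
}

// Illustrative launcher showing the guard-then-stream pattern added in this PR.
void scale_inplace(torch::Tensor& input, float factor) {
  const int n = static_cast<int>(input.numel());
  dim3 grid((n + 255) / 256);
  dim3 block(256);
  // Without the guard, getCurrentCUDAStream() would return a stream on the
  // *current* device, which may not be the device that owns `input`.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  scale_kernel<<<grid, block, 0, stream>>>(input.data_ptr<float>(), n, factor);
}
```

`OptionalCUDAGuard` (rather than `CUDAGuard`) matches `device_of`, which returns an optional device, so the guard is simply a no-op when no device can be determined. The test changes in the second half of the diff exercise this by parameterizing the kernel tests over `cuda:0` and, when a second GPU is available, `cuda:1`.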
5 changes: 4 additions & 1 deletion csrc/activation_kernels.cu
@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -36,6 +37,7 @@ void silu_and_mul(

dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
@@ -71,6 +73,7 @@ __global__ void activation_kernel(
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
3 changes: 3 additions & 0 deletions csrc/attention/attention_kernels.cu
@@ -21,6 +21,7 @@

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "attention_dtypes.h"
#include "attention_utils.cuh"
@@ -616,6 +617,7 @@ void paged_attention_v1_launcher(

dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
@@ -784,6 +786,7 @@ void paged_attention_v2_launcher(
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
5 changes: 5 additions & 0 deletions csrc/cache_kernels.cu
@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -33,6 +34,7 @@ void swap_blocks(
char *dst_ptr = static_cast<char*>(dst.data_ptr());

const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::cuda::OptionalCUDAGuard device_guard(src_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) {
@@ -127,6 +129,7 @@ void copy_blocks(
const int numel_per_block = key_caches[0][0].numel();
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, numel_per_block));
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
@@ -207,6 +210,7 @@ void reshape_and_cache(

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
@@ -367,6 +371,7 @@ void gather_cached_kv(

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
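In `swap_blocks` above, the guard is constructed from an explicit `src_device` rather than from `device_of(...)`, and the blocks listed in `block_mapping` are then copied one by one on the guarded device's stream. Below is a minimal sketch of what a single iteration of that loop might look like, assuming the source cache lives on a CUDA device; `copy_one_block`, its arguments, and the use of `cudaMemcpyDefault` are illustrative assumptions, not code taken from the diff.

```cuda
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Illustrative helper, not part of the PR: copy one fixed-size cache block
// from `src` to `dst`, enqueuing the copy on the source device's stream.
// Assumes `src` is a CUDA tensor.
void copy_one_block(const torch::Tensor& src, torch::Tensor& dst,
                    int64_t src_offset, int64_t dst_offset,
                    int64_t block_size_in_bytes) {
  const char* src_ptr = static_cast<const char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());
  // Make the source tensor's device current so the stream fetched below
  // belongs to that device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(src));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // One asynchronous copy per block; issuing many of these back to back is
  // why the NOTE(woosuk) in swap_blocks warns about large block counts.
  cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
                  static_cast<size_t>(block_size_in_bytes),
                  cudaMemcpyDefault, stream);
}
```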
3 changes: 3 additions & 0 deletions csrc/layernorm_kernels.cu
@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "reduction_utils.cuh"
@@ -76,6 +77,7 @@ void rms_norm(

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
@@ -101,6 +103,7 @@ void fused_add_rms_norm(

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
2 changes: 2 additions & 0 deletions csrc/pos_encoding_kernels.cu
@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -94,6 +95,7 @@ void rotary_embedding(

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
3 changes: 2 additions & 1 deletion csrc/quantization/squeezellm/quant_cuda_kernel.cu
@@ -7,6 +7,7 @@
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>

#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
@@ -199,7 +200,7 @@ void squeezellm_gemm(
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);

const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
#ifndef USE_ROCM
(half2*) vec.data<at::Half>(),
5 changes: 3 additions & 2 deletions tests/kernels/conftest.py
@@ -12,6 +12,7 @@ def create_kv_caches(
head_size: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -23,7 +24,7 @@
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device='cuda')
device=device)
key_cache.uniform_(-scale, scale)
key_caches.append(key_cache)

@@ -32,7 +33,7 @@
for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device='cuda')
device=device)
value_cache.uniform_(-scale, scale)
value_caches.append(value_cache)
return key_caches, value_caches
16 changes: 13 additions & 3 deletions tests/kernels/test_activation.py
@@ -7,22 +7,26 @@
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 4096, 5120, 13824] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
layer = SiluAndMul()
out = layer(x)
ref_out = layer._forward(x)
@@ -33,16 +37,19 @@ def test_silu_and_mul(
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_gelu_new(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
layer = NewGELU()
out = layer(x)
ref_out = layer._forward(x)
@@ -53,15 +60,18 @@ def test_gelu_new(
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_gelu_fast(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
layer = FastGELU()
out = layer(x)
ref_out = layer._forward(x)
25 changes: 15 additions & 10 deletions tests/kernels/test_attention.py
@@ -24,6 +24,7 @@
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


def ref_masked_attention(
@@ -87,7 +88,7 @@ def ref_single_query_cached_kv_attention(
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len, device="cuda").int()
position_ids = torch.arange(context_len, device=query.device).int()
alibi_bias = (position_ids - context_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
@@ -105,6 +106,7 @@
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_paged_attention(
kv_cache_factory,
version: str,
@@ -115,18 +117,19 @@ def test_paged_attention(
block_size: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

gpu_id = f"cuda:{device}"
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs,
num_query_heads,
head_size,
dtype=dtype,
device="cuda")
device=gpu_id)
query.uniform_(-scale, scale)

assert num_query_heads % num_kv_heads == 0
@@ -135,12 +138,12 @@
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device="cuda")
device=gpu_id)

context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)

# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -151,12 +154,12 @@
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)

# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size, dtype,
seed)
seed, gpu_id)
key_cache, value_cache = key_caches[0], value_caches[0]

# Call the paged attention kernel.
@@ -249,7 +252,7 @@ def ref_multi_query_kv_attention(
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device="cuda")
attn_mask = attn_mask.to(dtype=dtype, device=query.device)

ref_output = ref_masked_attention(
query[start_idx:end_idx],
@@ -269,18 +272,20 @@
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

gpu_id = f"cuda:{device}"
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
@@ -294,7 +299,7 @@ def test_multi_query_kv_attention(
num_query_heads + 2 * num_kv_heads,
head_size,
dtype=dtype,
device="cuda")
device=gpu_id)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)