[FIX] Support non-zero CUDA devices in custom kernels (vllm-project#1959)
jeejeelee authored Jan 3, 2024
1 parent 22ec576 commit 4c86061
Showing 12 changed files with 74 additions and 30 deletions.
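All of the C++/CUDA changes below follow one pattern: before asking for the current CUDA stream, each kernel entry point now constructs an at::cuda::OptionalCUDAGuard for the device of one of its input tensors, so the stream query and the kernel launch target that tensor's device instead of whichever device happens to be current (typically cuda:0). The following is a minimal sketch of that pattern, not code taken from the vLLM sources; the launcher and kernel names are illustrative.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Illustrative no-op kernel so the launcher below has something to launch.
__global__ void touch_first_element(float* data) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    data[0] = data[0];
  }
}

void example_launcher(torch::Tensor& input) {
  // Without the guard, getCurrentCUDAStream() returns a stream on the
  // current device (usually cuda:0) even when `input` lives on cuda:1,
  // and the launch below would go to the wrong device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  touch_first_element<<<1, 1, 0, stream>>>(input.data_ptr<float>());
  // The guard restores the previously current device when it goes out of scope.
}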
5 changes: 4 additions & 1 deletion csrc/activation_kernels.cu
@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -36,6 +37,7 @@ void silu_and_mul(

dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
@@ -71,6 +73,7 @@ __global__ void activation_kernel(
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
3 changes: 3 additions & 0 deletions csrc/attention/attention_kernels.cu
@@ -21,6 +21,7 @@

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "attention_dtypes.h"
#include "attention_utils.cuh"
@@ -616,6 +617,7 @@ void paged_attention_v1_launcher(

dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
@@ -784,6 +786,7 @@ void paged_attention_v2_launcher(
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
5 changes: 5 additions & 0 deletions csrc/cache_kernels.cu
@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -33,6 +34,7 @@ void swap_blocks(
char *dst_ptr = static_cast<char*>(dst.data_ptr());

const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::cuda::OptionalCUDAGuard device_guard(src_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) {
@@ -127,6 +129,7 @@ void copy_blocks(
const int numel_per_block = key_caches[0][0].numel();
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, numel_per_block));
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
@@ -207,6 +210,7 @@ void reshape_and_cache(

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
@@ -367,6 +371,7 @@ void gather_cached_kv(

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
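One variant above is worth calling out: swap_blocks guards on an explicit src_device value rather than on device_of(tensor), since a swap can involve a CPU-side cache on one end and the guard must be given the CUDA device explicitly. Below is a hedged sketch of that style of guard, with illustrative names; it is not the actual vLLM implementation.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

void swap_blocks_sketch(torch::Tensor& src, torch::Tensor& dst) {
  // Pick whichever side is on a GPU; the other side may be CPU memory.
  const torch::Device src_device = src.device();
  const torch::Device cuda_device =
      src_device.is_cuda() ? src_device : dst.device();
  const at::cuda::OptionalCUDAGuard device_guard(cuda_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // A real implementation would issue cudaMemcpyAsync calls on `stream`
  // for each swapped block.
  (void)stream;
}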
3 changes: 3 additions & 0 deletions csrc/layernorm_kernels.cu
@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "reduction_utils.cuh"
@@ -76,6 +77,7 @@ void rms_norm(

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
@@ -101,6 +103,7 @@ void fused_add_rms_norm(

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
2 changes: 2 additions & 0 deletions csrc/pos_encoding_kernels.cu
@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -94,6 +95,7 @@ void rotary_embedding(

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
3 changes: 2 additions & 1 deletion csrc/quantization/squeezellm/quant_cuda_kernel.cu
@@ -7,6 +7,7 @@
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>

#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
@@ -199,7 +200,7 @@ void squeezellm_gemm(
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);

const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
#ifndef USE_ROCM
(half2*) vec.data<at::Half>(),
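Note that in the SqueezeLLM change above the kernel is launched with <<<blocks, threads>>> and no explicit stream argument, so the new device guard is what determines where the launch goes: it makes the tensor's device current, and the launch then uses that device's default stream rather than device 0's.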
5 changes: 3 additions & 2 deletions tests/kernels/conftest.py
@@ -12,6 +12,7 @@ def create_kv_caches(
head_size: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -23,7 +24,7 @@ for _ in range(num_layers):
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device='cuda')
device=device)
key_cache.uniform_(-scale, scale)
key_caches.append(key_cache)

@@ -32,7 +33,7 @@ for _ in range(num_layers):
for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device='cuda')
device=device)
value_cache.uniform_(-scale, scale)
value_caches.append(value_cache)
return key_caches, value_caches
16 changes: 13 additions & 3 deletions tests/kernels/test_activation.py
@@ -7,22 +7,26 @@
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 4096, 5120, 13824] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
layer = SiluAndMul()
out = layer(x)
ref_out = layer._forward(x)
@@ -33,16 +37,19 @@ def test_silu_and_mul(
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_gelu_new(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
layer = NewGELU()
out = layer(x)
ref_out = layer._forward(x)
@@ -53,15 +60,18 @@ def test_gelu_new(
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_gelu_fast(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
layer = FastGELU()
out = layer(x)
ref_out = layer._forward(x)
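On the test side, the new DEVICES list evaluates to [0] on a single-GPU machine and to [0, 1] when two or more GPUs are visible, so every parametrized test also runs against cuda:1 where the hardware allows it; the integer index is turned into a device string with gpu_id = f"cuda:{device}".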
25 changes: 15 additions & 10 deletions tests/kernels/test_attention.py
@@ -24,6 +24,7 @@
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


def ref_masked_attention(
@@ -87,7 +88,7 @@ def ref_single_query_cached_kv_attention(
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len, device="cuda").int()
position_ids = torch.arange(context_len, device=query.device).int()
alibi_bias = (position_ids - context_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
@@ -105,6 +106,7 @@
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_paged_attention(
kv_cache_factory,
version: str,
@@ -115,18 +117,19 @@
block_size: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

gpu_id = f"cuda:{device}"
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs,
num_query_heads,
head_size,
dtype=dtype,
device="cuda")
device=gpu_id)
query.uniform_(-scale, scale)

assert num_query_heads % num_kv_heads == 0
@@ -135,12 +138,12 @@
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device="cuda")
device=gpu_id)

context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)

# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -151,12 +154,12 @@
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)

# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size, dtype,
seed)
seed, gpu_id)
key_cache, value_cache = key_caches[0], value_caches[0]

# Call the paged attention kernel.
@@ -249,7 +252,7 @@ def ref_multi_query_kv_attention(
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device="cuda")
attn_mask = attn_mask.to(dtype=dtype, device=query.device)

ref_output = ref_masked_attention(
query[start_idx:end_idx],
@@ -269,18 +272,20 @@
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

gpu_id = f"cuda:{device}"
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
@@ -294,7 +299,7 @@ def test_multi_query_kv_attention(
num_query_heads + 2 * num_kv_heads,
head_size,
dtype=dtype,
device="cuda")
device=gpu_id)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)