From 4c86061fbf36a7fb088a379a8643fcb1745924f7 Mon Sep 17 00:00:00 2001
From: Jee Li
Date: Wed, 3 Jan 2024 11:09:59 +0800
Subject: [PATCH] [FIX] Support non-zero CUDA devices in custom kernels (#1959)

---
 csrc/activation_kernels.cu                    |  5 +++-
 csrc/attention/attention_kernels.cu           |  3 +++
 csrc/cache_kernels.cu                         |  5 ++++
 csrc/layernorm_kernels.cu                     |  3 +++
 csrc/pos_encoding_kernels.cu                  |  2 ++
 .../squeezellm/quant_cuda_kernel.cu           |  3 ++-
 tests/kernels/conftest.py                     |  5 ++--
 tests/kernels/test_activation.py              | 16 +++++++++---
 tests/kernels/test_attention.py               | 25 +++++++++++--------
 tests/kernels/test_cache.py                   | 17 ++++++++-----
 tests/kernels/test_layernorm.py               |  9 ++++---
 tests/kernels/test_pos_encoding.py            | 11 +++++---
 12 files changed, 74 insertions(+), 30 deletions(-)

diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 1cca2c5fccc18..5ba9ab178d5a4 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -1,5 +1,6 @@
-#include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
@@ -36,6 +37,7 @@ void silu_and_mul(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
@@ -71,6 +73,7 @@ __global__ void activation_kernel(
   int64_t num_tokens = input.numel() / d;                                  \
   dim3 grid(num_tokens);                                                   \
   dim3 block(std::min(d, 1024));                                           \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));        \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();            \
   VLLM_DISPATCH_FLOATING_TYPES(                                            \
     input.scalar_type(),                                                   \
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index eff28d3dacd0e..9dcacfbe47d48 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -21,6 +21,7 @@
 
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
@@ -616,6 +617,7 @@ void paged_attention_v1_launcher(
 
   dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
@@ -784,6 +786,7 @@ void paged_attention_v2_launcher(
   int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
 
   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 59bacffdf4642..9f173534070a6 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -1,5 +1,6 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
@@ -33,6 +34,7 @@ void swap_blocks(
   char *dst_ptr = static_cast<char*>(dst.data_ptr());
 
   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const at::cuda::OptionalCUDAGuard device_guard(src_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE(woosuk): This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {
@@ -127,6 +129,7 @@ void copy_blocks(
   const int numel_per_block = key_caches[0][0].numel();
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
@@ -207,6 +210,7 @@ void reshape_and_cache(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
@@ -367,6 +371,7 @@ void gather_cached_kv(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 7434f4fd7998e..6d34d014c858e 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -1,5 +1,6 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "dispatch_utils.h"
 #include "reduction_utils.cuh"
@@ -76,6 +77,7 @@ void rms_norm(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
@@ -101,6 +103,7 @@ void fused_add_rms_norm(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu
index 486ebe1d464c8..5f522795619e1 100644
--- a/csrc/pos_encoding_kernels.cu
+++ b/csrc/pos_encoding_kernels.cu
@@ -1,5 +1,6 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
@@ -94,6 +95,7 @@ void rotary_embedding(
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu
index 2c37d01e0ae5c..b17ced6fce79b 100644
--- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu
+++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu
@@ -7,6 +7,7 @@
 // half-tensor
 #include <c10/cuda/CUDAStream.h>
 #include <ATen/cuda/CUDATensorMethods.cuh>
+#include <c10/cuda/CUDAGuard.h>
 
 #define BLOCKWIDTH 128
 #define BLOCKHEIGHT4 16
@@ -199,7 +200,7 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
-
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
   vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
 #ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),
diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py
index 97516bd3052cf..fca97ab76bf09 100644
--- a/tests/kernels/conftest.py
+++ b/tests/kernels/conftest.py
@@ -12,6 +12,7 @@ def create_kv_caches(
     head_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: str,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -23,7 +24,7 @@ def create_kv_caches(
     for _ in range(num_layers):
         key_cache = 
torch.empty(size=key_cache_shape, dtype=dtype, - device='cuda') + device=device) key_cache.uniform_(-scale, scale) key_caches.append(key_cache) @@ -32,7 +33,7 @@ def create_kv_caches( for _ in range(num_layers): value_cache = torch.empty(size=value_cache_shape, dtype=dtype, - device='cuda') + device=device) value_cache.uniform_(-scale, scale) value_caches.append(value_cache) return key_caches, value_caches diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index ba062054bf406..826bf8350af17 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -7,22 +7,26 @@ NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 4096, 5120, 13824] # Arbitrary values for testing SEEDS = [0] +DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() def test_silu_and_mul( num_tokens: int, d: int, dtype: torch.dtype, seed: int, + device: int, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda") + gpu_id = f"cuda:{device}" + x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id) layer = SiluAndMul() out = layer(x) ref_out = layer._forward(x) @@ -33,16 +37,19 @@ def test_silu_and_mul( @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() def test_gelu_new( num_tokens: int, d: int, dtype: torch.dtype, seed: int, + device: int, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") + gpu_id = f"cuda:{device}" + x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) layer = NewGELU() out = layer(x) ref_out = layer._forward(x) @@ -53,15 +60,18 @@ def test_gelu_new( @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) def test_gelu_fast( num_tokens: int, d: int, dtype: torch.dtype, seed: int, + device: int, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") + gpu_id = f"cuda:{device}" + x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) layer = FastGELU() out = layer(x) ref_out = layer._forward(x) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 614b65f82ccbd..814d40f56def0 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -24,6 +24,7 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] SEEDS = [0] +DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] def ref_masked_attention( @@ -87,7 +88,7 @@ def ref_single_query_cached_kv_attention( alibi_bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. 
- position_ids = torch.arange(context_len, device="cuda").int() + position_ids = torch.arange(context_len, device=query.device).int() alibi_bias = (position_ids - context_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) @@ -105,6 +106,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) def test_paged_attention( kv_cache_factory, version: str, @@ -115,18 +117,19 @@ def test_paged_attention( block_size: int, dtype: torch.dtype, seed: int, + device: int, ) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - + gpu_id = f"cuda:{device}" scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype, - device="cuda") + device=gpu_id) query.uniform_(-scale, scale) assert num_query_heads % num_kv_heads == 0 @@ -135,12 +138,12 @@ def test_paged_attention( if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, - device="cuda") + device=gpu_id) context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] context_lens[-1] = MAX_SEQ_LEN max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id) # Create the block tables. max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size @@ -151,12 +154,12 @@ def test_paged_attention( for _ in range(max_num_blocks_per_seq) ] block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id) # Create the KV caches. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, dtype, - seed) + seed, gpu_id) key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. @@ -249,7 +252,7 @@ def ref_multi_query_kv_attention( attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1) attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype, device="cuda") + attn_mask = attn_mask.to(dtype=dtype, device=query.device) ref_output = ref_masked_attention( query[start_idx:end_idx], @@ -269,6 +272,7 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, @@ -276,11 +280,12 @@ def test_multi_query_kv_attention( head_size: int, dtype: torch.dtype, seed: int, + device: int, ) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - + gpu_id = f"cuda:{device}" # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # As the xformers library is already tested with its own tests, we can use # a smaller MAX_SEQ_LEN here. 
@@ -294,7 +299,7 @@ def test_multi_query_kv_attention( num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype, - device="cuda") + device=gpu_id) qkv.uniform_(-scale, scale) query, key, value = qkv.split( [num_query_heads, num_kv_heads, num_kv_heads], dim=1) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 9b5d7687a3fec..1d8d41e013b03 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -14,6 +14,7 @@ NUM_BLOCKS = [1024, 36000] # Arbitrary values for testing NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] +DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -24,6 +25,7 @@ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() def test_copy_blocks( kv_cache_factory, @@ -35,11 +37,12 @@ def test_copy_blocks( num_blocks: int, dtype: torch.dtype, seed: int, + device: int, ) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - + gpu_id = f"cuda:{device}" # Generate random block mappings where each source block is mapped to two # destination blocks. assert 2 * num_mappings <= num_blocks @@ -56,7 +59,7 @@ def test_copy_blocks( # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, num_layers, num_heads, - head_size, dtype, seed) + head_size, dtype, seed, gpu_id) # Clone the KV caches. cloned_key_caches = [key_cache.clone() for key_cache in key_caches] @@ -88,6 +91,7 @@ def test_copy_blocks( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() def test_reshape_and_cache( kv_cache_factory, @@ -98,28 +102,29 @@ def test_reshape_and_cache( num_blocks: int, dtype: torch.dtype, seed: int, + device: int, ) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - + gpu_id = f"cuda:{device}" # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda") + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id) qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, - device="cuda") + device=gpu_id) _, key, value = qkv.unbind(dim=1) # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, num_heads, head_size, dtype, - seed) + seed, gpu_id) key_cache, value_cache = key_caches[0], value_caches[0] # Clone the KV caches. 
diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index b362e2c43f0da..8a06b3aa268be 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -8,6 +8,7 @@ HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing ADD_RESIDUAL = [False, True] SEEDS = [0] +DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -15,6 +16,7 @@ @pytest.mark.parametrize("add_residual", ADD_RESIDUAL) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() def test_rms_norm( num_tokens: int, @@ -22,14 +24,15 @@ def test_rms_norm( add_residual: bool, dtype: torch.dtype, seed: int, + device: int, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - - layer = RMSNorm(hidden_size).to(dtype).cuda() + gpu_id = f"cuda:{device}" + layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id) layer.weight.data.normal_(mean=1.0, std=0.1) scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda") + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id) x *= scale residual = torch.randn_like(x) * scale if add_residual else None diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 25d6bf2378cad..aad310e2bc6d2 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -13,6 +13,7 @@ BATCH_SIZES = [1, 5] # Arbitrary values for testing SEQ_LENS = [11, 8192] # Arbitrary values for testing SEEDS = [0] +DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -23,6 +24,7 @@ @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() def test_rotary_embedding( is_neox_style: bool, @@ -33,6 +35,7 @@ def test_rotary_embedding( rotary_dim: Optional[int], dtype: torch.dtype, seed: int, + device: int, max_position: int = 8192, base: int = 10000, ) -> None: @@ -40,20 +43,20 @@ def test_rotary_embedding( rotary_dim = head_size torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - + gpu_id = f"cuda:{device}" if rotary_dim is None: rotary_dim = head_size rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) - rope = rope.to(dtype).cuda() + rope = rope.to(dtype=dtype, device=gpu_id) positions = torch.randint(0, max_position, (batch_size, seq_len), - device="cuda") + device=gpu_id) query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype, - device="cuda") + device=gpu_id) key = torch.randn_like(query) # NOTE(woosuk): The reference implementation should be executed first