[Inference] Add CUDA KVCache Kernel #5406

Merged · 9 commits · Feb 28, 2024
44 changes: 32 additions & 12 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -13,6 +13,7 @@

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
decoding_fused_rotary_embedding,
@@ -22,6 +23,8 @@
)
from colossalai.logging import get_dist_logger

inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)

try:
@@ -74,6 +77,12 @@ def llama_model_forward(
sequence_lengths = batch.get_sequence_lengths()
batch_size = batch.current_batch_size
kv_seq_len = sequence_lengths.max().item()
use_cuda_kernel = True
# NOTE: After testing, this configuration gives relatively good performance. As the CUDA
# kernel implementation is updated and optimized, this selection criterion should be
# revisited with a more detailed analysis.
if batch_size >= 32 and kv_seq_len > 512:
use_cuda_kernel = False

hidden_states = self.embed_tokens(input_ids)

@@ -107,6 +116,7 @@ def llama_model_forward(
output_tensor=output_tensor,
norm_output=norm_output,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
)

if batch.is_prompts:
@@ -134,6 +144,7 @@ def llama_decoder_layer_forward(
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""This function will replace the forward function of LlamaDecoderLayer.

@@ -153,6 +164,7 @@ def llama_decoder_layer_forward(
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel (bool, optional): Whether to use the CUDA kernel. Defaults to True.
"""

hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
@@ -169,6 +181,7 @@ def llama_decoder_layer_forward(
fd_inter_tensor=fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
)

# Fully Connected
@@ -252,6 +265,7 @@ def forward(
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
@@ -268,6 +282,7 @@ def forward(
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel (bool, optional): Whether to use the CUDA kernel. Defaults to True.
"""

if self.num_heads != self.num_key_value_heads:
@@ -283,7 +298,6 @@ def forward(
)

block_size = k_cache.size(-2)

if is_prompts:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
@@ -300,17 +314,23 @@ def forward(
sm_scale=sm_scale,
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
if use_cuda_kernel:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
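A minimal sketch, not part of the diff, that consolidates the decode-stage dispatch added above; the helpers should_use_cuda_kernel and write_new_kv_token are hypothetical names that only restate the inlined heuristic and the branch between the CUDA and Triton paths.

# Sketch only: hypothetical helpers restating the logic added in nopadding_llama.py.
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import decoding_fused_rotary_embedding, rotary_embedding

inference_ops = InferenceOpsLoader().load()


def should_use_cuda_kernel(batch_size: int, kv_seq_len: int) -> bool:
    # Prefer the CUDA memcpy path, except for large batches with long KV sequences,
    # where the fused Triton path performed better in testing.
    return not (batch_size >= 32 and kv_seq_len > 512)


def write_new_kv_token(use_cuda_kernel, query_states, key_states, value_states,
                       cos_sin, k_cache, v_cache, block_tables, sequence_lengths):
    # Hypothetical wrapper around the decode-stage branch shown in the diff.
    if use_cuda_kernel:
        # Rotary embedding first, then copy the new K/V token into the blocked cache with the CUDA op.
        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
        inference_ops.decode_kv_cache_memcpy(
            key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
        )
    else:
        # Single fused Triton kernel: rotary embedding and cache copy in one launch.
        decoding_fused_rotary_embedding(
            query_states, key_states, value_states, cos_sin[0], cos_sin[1],
            k_cache, v_cache, block_tables, sequence_lengths,
        )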
6 changes: 6 additions & 0 deletions colossalai/kernel/kernel_loader.py
@@ -8,6 +8,7 @@
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
FusedOptimizerCudaExtension,
InferenceOpsCudaExtension,
LayerNormCudaExtension,
MoeCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
@@ -21,6 +22,7 @@
"LayerNormLoader",
"MoeLoader",
"FusedOptimizerLoader",
"InferenceOpsLoader",
"ScaledMaskedSoftmaxLoader",
"ScaledUpperTriangleMaskedSoftmaxLoader",
]
@@ -97,6 +99,10 @@ class FusedOptimizerLoader(KernelLoader):
REGISTRY = [FusedOptimizerCudaExtension]


class InferenceOpsLoader(KernelLoader):
REGISTRY = [InferenceOpsCudaExtension]


class ScaledMaskedSoftmaxLoader(KernelLoader):
REGISTRY = [ScaledMaskedSoftmaxCudaExtension]

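A usage sketch for the new loader, following the calls made elsewhere in this PR; the tensor shapes come from the comments in colossal_inference_C_frontend.cpp, and the toy sizes below are illustrative assumptions.

import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader

# Builds (on first use) and loads the CUDA inference extension.
inference_ops = InferenceOpsLoader().load()

bsz = 4                                   # one new token per sequence at the decode stage
num_heads, head_size = 16, 64
block_size, max_blocks_per_seq = 16, 4
num_blocks = bsz * max_blocks_per_seq

key = torch.randn(bsz, num_heads, head_size, dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, num_heads, block_size, head_size, dtype=torch.float16, device="cuda")
value_cache = torch.zeros_like(key_cache)
# Current length of each sequence, counting the token being written.
sequence_lengths = torch.randint(1, block_size * max_blocks_per_seq + 1, (bsz,), dtype=torch.int32, device="cuda")
# Toy layout: each sequence owns a disjoint, contiguous run of physical blocks.
block_tables = torch.arange(num_blocks, dtype=torch.int32, device="cuda").reshape(bsz, max_blocks_per_seq)

# Copies each sequence's newest K/V token into its (block, in-block offset) slot in the caches.
inference_ops.decode_kv_cache_memcpy(key, value, key_cache, value_cache, sequence_lengths, block_tables)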
80 changes: 80 additions & 0 deletions examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py
@@ -0,0 +1,80 @@
import torch

from colossalai.inference.modeling.layers.attention import copy_to_cache
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data

try:
import triton # noqa
except ImportError:
print("please install triton from https://github.com/openai/triton")

inference_ops = InferenceOpsLoader().load()

HEAD_DIM = 4
BATCH = 16
BLOCK_SIZE = 32
SAME_LEN = True
WARM_UPS = 10
REPS = 100
configs = [
triton.testing.Benchmark(
x_names=["KV_SEQ_LEN"],
x_vals=[2**i for i in range(8, 13)],
line_arg="provider",
line_vals=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
line_names=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
styles=[("red", "-"), ("blue", "-"), ("green", "-")],
ylabel="ms",
plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}",
args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True},
)
]


@triton.testing.perf_report(configs)
def benchmark_kvcache_copy(
provider: str,
bsz: int,
block_size: int,
max_seq_len: int,
KV_SEQ_LEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
num_kv_heads: int,
same_context_len: bool,
):
dtype = torch.float32
device = get_current_device()

assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum KV length must be less than or equal to the maximum seq len"

new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
bsz,
num_kv_heads,
HEAD_DIM,
block_size,
max_seq_len // block_size,
same_context_len,
KV_SEQ_LEN,
device=device,
dtype=dtype,
)

quantiles = [0.5, 0.2, 0.8]
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
if provider == "torch_copy_func":
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
elif provider == "triton_copy_func":
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
elif provider == "cuda_copy_func":
new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
fn = lambda: inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)

ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
return ms, min_ms, max_ms


if __name__ == "__main__":
benchmark_kvcache_copy.run(save_path=".", print_data=True)
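Usage note: running this script directly (python benchmark_kv_cache_memcopy.py on a CUDA machine with Triton installed) sweeps KV_SEQ_LEN from 256 to 4096 and reports the median latency in ms, along with the 0.2/0.8 quantiles, for the torch, Triton, and CUDA copy paths; save_path="." writes the perf_report plot next to the printed table.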
3 changes: 3 additions & 0 deletions extensions/__init__.py
@@ -4,6 +4,7 @@
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
)
from .inference import InferenceOpsCudaExtension
from .layernorm import LayerNormCudaExtension
from .moe import MoeCudaExtension
from .optimizer import FusedOptimizerCudaExtension
@@ -15,6 +16,7 @@
LayerNormCudaExtension,
MoeCudaExtension,
FusedOptimizerCudaExtension,
InferenceOpsCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
FlashAttentionDaoCudaExtension,
@@ -28,6 +30,7 @@
"LayerNormCudaExtension",
"MoeCudaExtension",
"FusedOptimizerCudaExtension",
"InferenceOpsCudaExtension",
"ScaledMaskedSoftmaxCudaExtension",
"ScaledUpperTriangleMaskedSoftmaxCudaExtension",
"FlashAttentionDaoCudaExtension",
15 changes: 15 additions & 0 deletions extensions/csrc/cuda/colossal_inference_C_frontend.cpp
@@ -0,0 +1,15 @@
#include <torch/extension.h>

void decode_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the decode stage.");
}
90 changes: 90 additions & 0 deletions extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
@@ -0,0 +1,90 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <stdio.h>

#include "type_shim.h"

template<typename scalar_t>
__global__ void decode_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache,
scalar_t* __restrict__ value_cache,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables,
const int num_heads,
const int head_size,
const int block_size,
const int key_stride,
const int value_stride,
const int block_table_stride
)
{
const int seq_id = blockIdx.x;
const int seq_len = sequence_lengths[seq_id] - 1;
const int seq_id_in_block_table = seq_len / block_size;
const int block_offset = seq_len % block_size;
const int block_id = block_tables[seq_id * block_table_stride + seq_id_in_block_table];
const int hidden_size = num_heads * head_size;

if ( block_id < 0 ) {
return ;
}

for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
const int head_id = i / head_size;
const int head_offset = i % head_size;
const int key_src_id = seq_id * key_stride + i;
const int value_src_id = seq_id * value_stride + i;
const int target_src_id = block_id * hidden_size * block_size
+ head_id * block_size * head_size
+ block_offset * head_size + head_offset;

key_cache[target_src_id] = key[key_src_id];
value_cache[target_src_id] = value[value_src_id];
}

}

void decode_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables) // [batch_size, max_seq_len]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(2);

int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_table_stride = block_tables.stride(0);

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
DISPATCH_FLOAT_HALF_AND_BFLOAT(
key.scalar_type(),
"decode_kv_cache_memcpy",
decode_kv_cache_memcpy_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
num_heads,
head_size,
block_size,
key_stride,
value_stride,
block_table_stride
);)

AT_CUDA_CHECK(cudaGetLastError());

}
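For reference, a pure-PyTorch sketch of what this kernel computes, derived from the indexing above; it is not part of the PR, and the shapes follow the comments in the C++ frontend.

import torch


def decode_kv_cache_memcpy_ref(key, value, key_cache, value_cache, sequence_lengths, block_tables):
    # key/value:             [batch_size, num_heads, head_size] (one new token per sequence)
    # key_cache/value_cache: [num_blocks, num_heads, block_size, head_size]
    # sequence_lengths:      [batch_size]; block_tables: [batch_size, max_blocks_per_seq]
    block_size = key_cache.size(2)
    for seq_id in range(key.size(0)):
        last_pos = int(sequence_lengths[seq_id]) - 1          # position of the token being written
        block_id = int(block_tables[seq_id, last_pos // block_size])
        if block_id < 0:                                      # mirrors the kernel's early return
            continue
        block_offset = last_pos % block_size
        key_cache[block_id, :, block_offset, :] = key[seq_id]
        value_cache[block_id, :, block_offset, :] = value[seq_id]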
21 changes: 21 additions & 0 deletions extensions/csrc/cuda/type_shim.h
@@ -24,6 +24,27 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}

#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}

#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
3 changes: 3 additions & 0 deletions extensions/cuda_extension.py
@@ -1,7 +1,10 @@
import os
import time
from abc import abstractmethod
from pathlib import Path
from typing import List

from .base_extension import _Extension
from .cpp_extension import _CppExtension
from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list
