Skip to content

Commit 04aca9e

Browse files
authored
[Inference/Kernel]Add get_cos_and_sin Kernel (#5528)
* Add get_cos_and_sin kernel * fix code comments * fix code typos * merge common codes of get_cos_and_sin kernel. * Fixed a typo * Changed 'asset allclose' to 'assert equal'.
1 parent 934e31a commit 04aca9e

File tree

5 files changed

+295
-6
lines changed

5 files changed

+295
-6
lines changed

colossalai/inference/modeling/models/nopadding_llama.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,22 @@ def llama_model_forward(
101101
use_cuda_kernel = False
102102

103103
hidden_states = self.embed_tokens(input_tokens_ids)
104-
if use_cuda_kernel and inputmetadata != torch.float32 and use_flash_attn2:
105-
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
104+
if use_cuda_kernel:
105+
if inputmetadata != torch.float32 and use_flash_attn2:
106+
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
107+
108+
hidden_dim = self._cos_cached.size(-1)
109+
total_length = hidden_states.size(0)
110+
cos = torch.empty((total_length, hidden_dim), dtype=self._cos_cached.dtype, device=self._cos_cached.device)
111+
sin = torch.empty((total_length, hidden_dim), dtype=self._sin_cached.dtype, device=self._sin_cached.device)
112+
inference_ops.get_cos_and_sin(
113+
self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts
114+
)
115+
cos_sin = (cos, sin)
116+
106117
else:
107118
cu_seqlens = None
108-
109-
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
119+
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
110120

111121
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
112122

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
#include <ATen/cuda/CUDAContext.h>
2+
#include <torch/extension.h>
3+
4+
#include "utils/vector_copy_utils.h"
5+
#include "../common/micros.h"
6+
#include "stdio.h"
7+
8+
template <typename scalar_t, bool Aligned, int VecSize>
9+
__device__ void apply_cos_and_sin_memcopy(
10+
scalar_t* __restrict__ cos,
11+
scalar_t* __restrict__ sin,
12+
const scalar_t* __restrict__ cos_cache_ptr,
13+
const scalar_t* __restrict__ sin_cache_ptr,
14+
const int* __restrict__ sequence_lengths,
15+
const int head_dim,
16+
const int dest_offset_id,
17+
const int src_offset_id
18+
) {
19+
20+
int begin_id = threadIdx.x * VecSize;
21+
22+
for (; begin_id <= head_dim - VecSize; begin_id += blockDim.x){
23+
copy_vector<scalar_t, VecSize>(cos + dest_offset_id + begin_id, cos_cache_ptr + src_offset_id + begin_id);
24+
copy_vector<scalar_t, VecSize>(sin + dest_offset_id + begin_id, sin_cache_ptr + src_offset_id + begin_id);
25+
}
26+
27+
if (!Aligned) {
28+
for (; begin_id < head_dim; ++begin_id ) {
29+
cos[dest_offset_id + begin_id] = cos_cache_ptr[src_offset_id + begin_id];
30+
sin[dest_offset_id + begin_id] = sin_cache_ptr[src_offset_id + begin_id];
31+
}
32+
}
33+
}
34+
35+
template <typename scalar_t, bool Aligned, int VecSize>
36+
__global__ void apply_get_context_cos_and_sin_kernel(
37+
scalar_t* __restrict__ cos,
38+
scalar_t* __restrict__ sin,
39+
const scalar_t* __restrict__ cos_cache_ptr,
40+
const scalar_t* __restrict__ sin_cache_ptr,
41+
const int* __restrict__ sequence_lengths,
42+
const int* __restrict__ cumsum_lengths,
43+
const int batch_size,
44+
const int head_dim
45+
) {
46+
int token_id = blockIdx.x;
47+
if ( token_id >= sequence_lengths[blockIdx.y] ) {
48+
return ;
49+
}
50+
51+
int src_offset_id = token_id * head_dim;
52+
int dest_offset_id = src_offset_id;
53+
54+
if (blockIdx.y > 0) {
55+
dest_offset_id += cumsum_lengths[blockIdx.y - 1] * head_dim;
56+
}
57+
58+
apply_cos_and_sin_memcopy<scalar_t, Aligned, VecSize>(
59+
cos,
60+
sin,
61+
cos_cache_ptr,
62+
sin_cache_ptr,
63+
sequence_lengths,
64+
head_dim,
65+
dest_offset_id,
66+
src_offset_id
67+
);
68+
69+
}
70+
71+
template <typename scalar_t, bool Aligned, int VecSize>
72+
__global__ void apply_get_decode_cos_and_sin_kernel(
73+
scalar_t* __restrict__ cos,
74+
scalar_t* __restrict__ sin,
75+
const scalar_t* __restrict__ cos_cache_ptr,
76+
const scalar_t* __restrict__ sin_cache_ptr,
77+
const int* __restrict__ sequence_lengths,
78+
const int batch_size,
79+
const int head_dim
80+
) {
81+
int src_offset_id = ( sequence_lengths[blockIdx.y] - 1 ) * head_dim;
82+
int dest_offset_id = blockIdx.y * head_dim;
83+
84+
apply_cos_and_sin_memcopy<scalar_t, Aligned, VecSize>(
85+
cos,
86+
sin,
87+
cos_cache_ptr,
88+
sin_cache_ptr,
89+
sequence_lengths,
90+
head_dim,
91+
dest_offset_id,
92+
src_offset_id
93+
);
94+
}
95+
96+
template<typename scalar_t>
97+
void apply_get_cos_and_sin(
98+
at::Tensor& cos_cache, // [max_rotary_position, head_dim]
99+
at::Tensor& sin_cache, // [max_rotary_position, head_dim]
100+
at::Tensor& cos, // [num_tokens, head_dim]
101+
at::Tensor& sin, // [num_tokens, head_dim]
102+
at::Tensor& sequence_lengths, // [batch_size]
103+
int max_seq_len_in_batch,
104+
bool is_prompts
105+
) {
106+
int token_num = cos.size(0);
107+
int head_dim = cos.size(1);
108+
int batch_size = sequence_lengths.size(0);
109+
110+
at::Tensor cumsum_lengths;
111+
112+
int vec_size = get_vec_size<scalar_t>(cos);
113+
114+
bool aligned = true;
115+
if (head_dim % vec_size != 0) {
116+
aligned = false;
117+
}
118+
119+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
120+
int block_size_y;
121+
int block_size_x;
122+
123+
if (is_prompts) {
124+
block_size_y = batch_size;
125+
block_size_x = max_seq_len_in_batch;
126+
// TODO: The cumsum operation can be fused into get_cos_and_sin kernel later on.
127+
cumsum_lengths = torch::cumsum(sequence_lengths, 0, torch::kInt32);
128+
}
129+
else{
130+
block_size_y = batch_size;
131+
block_size_x = 1;
132+
}
133+
134+
int thread_nums = (head_dim + vec_size - 1) / vec_size;
135+
136+
dim3 grid(block_size_x, block_size_y);
137+
dim3 block(std::min(thread_nums, 512));
138+
139+
#define GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, __vec_size) \
140+
do { \
141+
if (is_prompts){ \
142+
apply_get_context_cos_and_sin_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
143+
cos.data_ptr<scalar_t>(), \
144+
sin.data_ptr<scalar_t>(), \
145+
cos_cache.data_ptr<scalar_t>(), \
146+
sin_cache.data_ptr<scalar_t>(), \
147+
sequence_lengths.data_ptr<int>(), \
148+
cumsum_lengths.data_ptr<int>(), \
149+
batch_size, \
150+
head_dim \
151+
); \
152+
} \
153+
else { \
154+
apply_get_decode_cos_and_sin_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
155+
cos.data_ptr<scalar_t>(), \
156+
sin.data_ptr<scalar_t>(), \
157+
cos_cache.data_ptr<scalar_t>(), \
158+
sin_cache.data_ptr<scalar_t>(), \
159+
sequence_lengths.data_ptr<int>(), \
160+
batch_size, \
161+
head_dim \
162+
); \
163+
} \
164+
} while(0)
165+
166+
#define GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \
167+
do { \
168+
switch (vec_size) { \
169+
case 1: \
170+
GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 1); \
171+
break; \
172+
case 2: \
173+
GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 2); \
174+
break; \
175+
case 4: \
176+
GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 4); \
177+
break; \
178+
default: \
179+
AT_ERROR("Unsupported vectorized size ", vec_size); \
180+
break; \
181+
} \
182+
} while(0)
183+
184+
if (aligned) {
185+
GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(true);
186+
}
187+
else {
188+
GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(false);
189+
}
190+
191+
AT_CUDA_CHECK(cudaGetLastError());
192+
}
193+
194+
void get_cos_and_sin(
195+
at::Tensor& cos_cache, // [max_rotary_position, head_dim]
196+
at::Tensor& sin_cache, // [max_rotary_position, head_dim]
197+
at::Tensor& cos, // [num_tokens, head_dim]
198+
at::Tensor& sin, // [num_tokens, head_dim]
199+
at::Tensor& sequence_lengths, // [batch_size]
200+
int max_seq_len_in_batch,
201+
bool is_prompts
202+
) {
203+
DISPATCH_FLOAT_HALF_AND_BFLOAT(
204+
cos.scalar_type(),
205+
"get_cos_and_sin",
206+
apply_get_cos_and_sin<scalar_t>(
207+
cos_cache,
208+
sin_cache,
209+
cos,
210+
sin,
211+
sequence_lengths,
212+
max_seq_len_in_batch,
213+
is_prompts
214+
);)
215+
}

extensions/csrc/cuda/pybind/inference.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ void fused_add_rms_layernorm(torch::Tensor& input, // [..., hidden_size]
5151
torch::Tensor& weight, // [hidden_size]
5252
float epsilon);
5353

54+
void get_cos_and_sin(at::Tensor& cos_cache, // [max_rotary_position, head_dim]
55+
at::Tensor& sin_cache, // [max_rotary_position, head_dim]
56+
at::Tensor& cos, // [num_tokens, head_dim]
57+
at::Tensor& sin, // [num_tokens, head_dim]
58+
at::Tensor& sequence_lengths, // [batch_size]
59+
int max_seq_len_in_batch, bool is_prompts);
60+
5461
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5562
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
5663
"Copy the GPU memory of kvcache during the decode stage.");
@@ -60,10 +67,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6067

6168
m.def(
6269
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
63-
"performing Rotary Embedding-related calculations and KVCache Memcopy.");
70+
"Performing Rotary Embedding-related calculations and KVCache Memcopy.");
6471

6572
m.def("rotary_embedding", &rotary_embedding,
66-
"performing Rotary Embedding-related calculations.");
73+
"Performing Rotary Embedding-related calculations.");
6774

6875
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
6976

@@ -72,4 +79,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7279

7380
m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
7481
"In-place fused Add and RMS Normalization.");
82+
83+
m.def("get_cos_and_sin", &get_cos_and_sin,
84+
"Get cos and sin from the cache.");
7585
}

extensions/inference/inference_ops_cuda.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def sources_files(self):
1616
"cuda/fused_rotary_emb_and_cache_kernel.cu",
1717
"cuda/activation_kernel.cu",
1818
"cuda/rms_layernorm_kernel.cu",
19+
"cuda/get_cos_and_sin_kernel.cu",
1920
]
2021
]
2122
return ret
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import numpy as np
2+
import pytest
3+
import torch
4+
5+
from colossalai.kernel.kernel_loader import InferenceOpsLoader
6+
from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
7+
8+
inference_ops = InferenceOpsLoader().load()
9+
10+
11+
def numpy_equal(x, y):
12+
x_numpy = x.detach().cpu().numpy()
13+
y_numpy = y.detach().cpu().numpy()
14+
15+
np.testing.assert_equal(x_numpy, y_numpy)
16+
17+
18+
@pytest.mark.parametrize("BATCH_SIZE", [4])
19+
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
20+
@pytest.mark.parametrize("HEAD_DIM", [64])
21+
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
22+
def test_get_cos_and_sin(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
23+
MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN
24+
cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
25+
sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
26+
lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda").to(torch.int32)
27+
28+
max_seq_len_in_batch = lengths.max()
29+
30+
# prefill
31+
cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
32+
33+
cos = torch.zeros_like(cos_ref)
34+
sin = torch.zeros_like(sin_ref)
35+
36+
inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, True)
37+
38+
numpy_equal(cos, cos_ref)
39+
numpy_equal(sin, sin_ref)
40+
41+
# decoding
42+
ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
43+
44+
cos = torch.zeros_like(ncos_ref)
45+
sin = torch.zeros_like(nsin_ref)
46+
47+
inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, False)
48+
numpy_equal(cos, ncos_ref)
49+
numpy_equal(sin, nsin_ref)
50+
51+
52+
if __name__ == "__main__":
53+
test_get_cos_and_sin(16, 4096, 256, torch.float16)

0 commit comments

Comments
 (0)