 namespace vllm {
 
 template <typename scalar_t, bool IS_NEOX>
-inline __device__ void apply_rotary_embedding(
+inline __device__ void apply_token_rotary_embedding(
   scalar_t* __restrict__ arr,
   const scalar_t* __restrict__ cos_ptr,
   const scalar_t* __restrict__ sin_ptr,
@@ -38,22 +38,18 @@ inline __device__ void apply_rotary_embedding(
 }
 
 template <typename scalar_t, bool IS_NEOX>
-__global__ void rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+inline __device__ void apply_rotary_embedding(
   scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
   scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
-  const int rot_dim,
-  const int64_t query_stride,
-  const int64_t key_stride,
+  const scalar_t* cache_ptr,
+  const int head_size,
   const int num_heads,
   const int num_kv_heads,
-  const int head_size) {
-  // Each thread block is responsible for one token.
-  const int token_idx = blockIdx.x;
-  int64_t pos = positions[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
-
+  const int rot_dim,
+  const int token_idx,
+  const int64_t query_stride,
+  const int64_t key_stride)
+{
   const int embed_dim = rot_dim / 2;
   const scalar_t* cos_ptr = cache_ptr;
   const scalar_t* sin_ptr = cache_ptr + embed_dim;
@@ -63,7 +59,7 @@ __global__ void rotary_embedding_kernel(
     const int head_idx = i / embed_dim;
     const int64_t token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                               sin_ptr, rot_offset, embed_dim);
   }
 
@@ -72,11 +68,53 @@ __global__ void rotary_embedding_kernel(
     const int head_idx = i / embed_dim;
     const int64_t token_head = token_idx * key_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                               sin_ptr, rot_offset, embed_dim);
   }
 }
 
+template <typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
+  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
+  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
+  const int rot_dim,
+  const int64_t query_stride,
+  const int64_t key_stride,
+  const int num_heads,
+  const int num_kv_heads,
+  const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+
+  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
+}
+
+template <typename scalar_t, bool IS_NEOX>
+__global__ void batched_rotary_embedding_kernel(
+  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
+  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
+  const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len] or [num_tokens]
+  const int rot_dim,
+  const int64_t query_stride,
+  const int64_t key_stride,
+  const int num_heads,
+  const int num_kv_heads,
+  const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
+
+  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
+}
+
 } // namespace vllm
 
 void rotary_embedding(
@@ -128,3 +166,61 @@ void rotary_embedding(
       }
     });
 }
+
+/*
+  Batched version of rotary embedding that packs multiple LoRAs together
+  and processes them in a batched manner.
+*/
+void batched_rotary_embedding(
+  torch::Tensor& positions,              // [batch_size, seq_len] or [num_tokens]
+  torch::Tensor& query,                  // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
+  torch::Tensor& key,                    // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
+  int head_size,
+  torch::Tensor& cos_sin_cache,          // [max_position, rot_dim]
+  bool is_neox,
+  int rot_dim,
+  torch::Tensor& cos_sin_cache_offsets   // [num_tokens]
+) {
+  int64_t num_tokens = cos_sin_cache_offsets.size(0);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int64_t query_stride = query.stride(-2);
+  int64_t key_stride = key.stride(-2);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    query.scalar_type(),
+    "rotary_embedding",
+    [&] {
+      if (is_neox) {
+        vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          cos_sin_cache_offsets.data_ptr<int64_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      } else {
+        vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          cos_sin_cache_offsets.data_ptr<int64_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      }
+    });
+}
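
For readers, a minimal host-side sketch of how cos_sin_cache_offsets could be built for batched_rotary_embedding above. It assumes the per-LoRA cos/sin tables are stacked into one cache, each LoRA occupying a contiguous block of max_position rows of length rot_dim, so that the kernel's row index (cos_sin_cache_offset + pos) lands in the right block; this layout is inferred from the indexing in batched_rotary_embedding_kernel, and the helper name and token_lora_indices input are illustrative assumptions, not identifiers from this change.

#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the diff above): give every token an offset of
// lora_index * max_position, so the kernel's row (offset + pos) falls inside that
// LoRA's block of the stacked cos/sin cache.
std::vector<int64_t> make_cos_sin_cache_offsets(
    const std::vector<int64_t>& token_lora_indices,  // LoRA index chosen per token
    int64_t max_position) {
  std::vector<int64_t> offsets(token_lora_indices.size());
  for (size_t i = 0; i < token_lora_indices.size(); ++i) {
    offsets[i] = token_lora_indices[i] * max_position;
  }
  return offsets;
}

The resulting offsets tensor is passed as cos_sin_cache_offsets, one entry per token, matching the [num_tokens] shape noted in the kernel signature.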