@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel(
 void rotary_embedding(
     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size]
+                           // [num_tokens, num_heads * head_size] or
+                           // [batch_size, seq_len, num_heads, head_size] or
+                           // [num_tokens, num_heads, head_size]
     torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size]
+                         // [num_tokens, num_kv_heads * head_size] or
+                         // [batch_size, seq_len, num_heads, head_size] or
+                         // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox) {
-  int64_t num_tokens = query.numel() / query.size(-1);
+  // num_tokens = batch_size * seq_len
+  int64_t num_tokens = positions.numel();
+  int positions_ndim = positions.dim();
+
+  // Make sure num_tokens dim is consistent across positions, query, and key.
+  TORCH_CHECK(
+      positions_ndim == 1 || positions_ndim == 2,
+      "positions must have shape [num_tokens] or [batch_size, seq_len]");
+  if (positions_ndim == 1) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+        "query, key and positions must have the same number of tokens");
+  }
+  if (positions_ndim == 2) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) &&
+            key.size(0) == positions.size(0) &&
+            query.size(1) == positions.size(1) &&
+            key.size(1) == positions.size(1),
+        "query, key and positions must have the same batch_size and seq_len");
+  }
+
+  // Make sure head_size is valid for query and key
+  // hidden_size = num_heads * head_size
+  int query_hidden_size = query.numel() / num_tokens;
+  int key_hidden_size = key.numel() / num_tokens;
+  TORCH_CHECK(query_hidden_size % head_size == 0);
+  TORCH_CHECK(key_hidden_size % head_size == 0);
+
+  // Make sure query and key have consistent number of heads
+  int num_heads = query_hidden_size / head_size;
+  int num_kv_heads = key_hidden_size / head_size;
+  TORCH_CHECK(num_heads % num_kv_heads == 0);
+
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
+  int seq_dim_idx = positions_ndim - 1;
+  int64_t query_stride = query.stride(seq_dim_idx);
+  int64_t key_stride = key.stride(seq_dim_idx);

   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
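As a quick illustration of what these checks buy, here is a minimal standalone sketch (not part of the patch; the sizes and the main() harness are invented for the example). It shows that deriving num_tokens from positions and the hidden size from query.numel() / num_tokens resolves the same head count and token stride for both the flattened and non-flattened query layouts, whereas the old query.size(-1) / head_size misreports the 3-D layout:

// Minimal sketch (hypothetical example, libtorch C++ API only).
#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t head_size = 64, num_heads = 8, n = 16;  // example sizes
  auto positions = torch::zeros({n}, torch::kLong);     // [num_tokens]
  // The same data in flattened and non-flattened layouts.
  auto query_2d = torch::zeros({n, num_heads * head_size});
  auto query_3d = torch::zeros({n, num_heads, head_size});

  for (const auto& query : {query_2d, query_3d}) {
    int64_t num_tokens = positions.numel();
    int64_t hidden_size = query.numel() / num_tokens;
    TORCH_CHECK(hidden_size % head_size == 0);
    int seq_dim_idx = positions.dim() - 1;  // 0 for 1-D positions
    std::cout << "num_heads=" << hidden_size / head_size   // 8 for both
              << " stride=" << query.stride(seq_dim_idx)   // 512 for both
              << "\n";
    // The old query.size(-1) / head_size would report 1 head for query_3d,
    // and query.stride(-2) would return the per-head stride (64), not the
    // per-token stride the kernel needs.
  }
  return 0;
}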
@@ -165,19 +201,58 @@ and process in batched manner.
 void batched_rotary_embedding(
     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size]
+                           // [num_tokens, num_heads * head_size] or
+                           // [batch_size, seq_len, num_heads, head_size] or
+                           // [num_tokens, num_heads, head_size]
     torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size]
+                         // [num_tokens, num_kv_heads * head_size] or
+                         // [batch_size, seq_len, num_heads, head_size] or
+                         // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox, int64_t rot_dim,
-    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
+    torch::Tensor& cos_sin_cache_offsets  // [num_tokens] or [batch_size]
 ) {
+  // num_tokens = batch_size * seq_len
   int64_t num_tokens = cos_sin_cache_offsets.size(0);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
+  TORCH_CHECK(
+      positions.size(0) == num_tokens || positions.numel() == num_tokens,
+      "positions must have the same num_tokens or batch_size as "
+      "cos_sin_cache_offsets");
+
+  int positions_ndim = positions.dim();
+  // Make sure num_tokens dim is consistent across positions, query, and key.
+  TORCH_CHECK(
+      positions_ndim == 1 || positions_ndim == 2,
+      "positions must have shape [num_tokens] or [batch_size, seq_len]");
+  if (positions_ndim == 1) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+        "query, key and positions must have the same number of tokens");
+  }
+  if (positions_ndim == 2) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) &&
+            key.size(0) == positions.size(0) &&
+            query.size(1) == positions.size(1) &&
+            key.size(1) == positions.size(1),
+        "query, key and positions must have the same batch_size and seq_len");
+  }
+
+  // Make sure head_size is valid for query and key
+  int query_hidden_size = query.numel() / num_tokens;
+  int key_hidden_size = key.numel() / num_tokens;
+  TORCH_CHECK(query_hidden_size % head_size == 0);
+  TORCH_CHECK(key_hidden_size % head_size == 0);
+
+  // Make sure query and key have consistent number of heads
+  int num_heads = query_hidden_size / head_size;
+  int num_kv_heads = key_hidden_size / head_size;
+  TORCH_CHECK(num_heads % num_kv_heads == 0);
+
+  int seq_dim_idx = positions_ndim - 1;
+  int64_t query_stride = query.stride(seq_dim_idx);
+  int64_t key_stride = key.stride(seq_dim_idx);

   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
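Both functions finish with the same launch geometry. The sketch below spells out the arithmetic with hypothetical example sizes (plain C++ rather than an actual kernel launch), under the assumption that the kernel assigns one block per token and, when the 512-thread cap applies, has each thread stride over several of the rot_dim / 2 rotated pairs per head:

// Hypothetical sketch of the launch-shape arithmetic (example sizes only).
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_tokens = 4096;  // batch_size * seq_len, example value
  const int64_t num_heads = 32;     // example query head count
  const int64_t rot_dim = 128;      // rotary dim, i.e. cos_sin_cache.size(1)

  // One block per token; each head rotates rot_dim / 2 (x, y) pairs.
  const int64_t pairs_per_token = num_heads * rot_dim / 2;  // 2048
  const int64_t block = std::min<int64_t>(pairs_per_token, 512);
  // With the 512 cap, each thread covers pairs_per_token / block pairs
  // (4 here) instead of exactly one.
  std::printf("grid=%lld block=%lld\n", (long long)num_tokens,
              (long long)block);
  return 0;
}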