@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel(
124124void rotary_embedding (
125125 torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
126126 torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
127- // [num_tokens, num_heads * head_size]
127+ // [num_tokens, num_heads * head_size] or
128+ // [batch_size, seq_len, num_heads, head_size] or
129+ // [num_tokens, num_heads, head_size]
128130 torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
129- // [num_tokens, num_kv_heads * head_size]
131+ // [num_tokens, num_kv_heads * head_size] or
132+ // [batch_size, seq_len, num_kv_heads, head_size] or
133+ // [num_tokens, num_kv_heads, head_size]
130134 int64_t head_size,
131135 torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
132136 bool is_neox) {
133- int64_t num_tokens = query.numel () / query.size (-1 );
137+ // num_tokens = batch_size * seq_len
138+ int64_t num_tokens = positions.numel ();
139+ int positions_ndim = positions.dim ();
140+
141+ // Make sure num_tokens dim is consistent across positions, query, and key.
142+ TORCH_CHECK (
143+ positions_ndim == 1 || positions_ndim == 2 ,
144+ " positions must have shape [num_tokens] or [batch_size, seq_len]" );
145+ if (positions_ndim == 1 ) {
146+ TORCH_CHECK (
147+ query.size (0 ) == positions.size (0 ) && key.size (0 ) == positions.size (0 ),
148+ " query, key and positions must have the same number of tokens" );
149+ }
150+ if (positions_ndim == 2 ) {
151+ TORCH_CHECK (
152+ query.size (0 ) == positions.size (0 ) &&
153+ key.size (0 ) == positions.size (0 ) &&
154+ query.size (1 ) == positions.size (1 ) &&
155+ key.size (1 ) == positions.size (1 ),
156+ " query, key and positions must have the same batch_size and seq_len" );
157+ }
158+
159+ // Make sure head_size is valid for query and key
160+ // hidden_size = num_heads * head_size
161+ int query_hidden_size = query.numel () / num_tokens;
162+ int key_hidden_size = key.numel () / num_tokens;
163+ TORCH_CHECK (query_hidden_size % head_size == 0 );
164+ TORCH_CHECK (key_hidden_size % head_size == 0 );
165+
166+ // Make sure query and key have consistent number of heads
167+ int num_heads = query_hidden_size / head_size;
168+ int num_kv_heads = key_hidden_size / head_size;
169+ TORCH_CHECK (num_heads % num_kv_heads == 0 );
170+
134171 int rot_dim = cos_sin_cache.size (1 );
135- int num_heads = query.size (-1 ) / head_size;
136- int num_kv_heads = key.size (-1 ) / head_size;
137- int64_t query_stride = query.stride (-2 );
138- int64_t key_stride = key.stride (-2 );
172+ int seq_dim_idx = positions_ndim - 1 ;
173+ int64_t query_stride = query.stride (seq_dim_idx);
174+ int64_t key_stride = key.stride (seq_dim_idx);
139175
140176 dim3 grid (num_tokens);
141177 dim3 block (std::min<int64_t >(num_heads * rot_dim / 2 , 512 ));
@@ -165,19 +201,58 @@ and process in batched manner.
165201void batched_rotary_embedding (
166202 torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
167203 torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
168- // [num_tokens, num_heads * head_size]
204+ // [num_tokens, num_heads * head_size] or
205+ // [batch_size, seq_len, num_heads, head_size] or
206+ // [num_tokens, num_heads, head_size]
169207 torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
170- // [num_tokens, num_kv_heads * head_size]
208+ // [num_tokens, num_kv_heads * head_size] or
209+ // [batch_size, seq_len, num_kv_heads, head_size] or
210+ // [num_tokens, num_kv_heads, head_size]
171211 int64_t head_size,
172212 torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
173213 bool is_neox, int64_t rot_dim,
174- torch::Tensor& cos_sin_cache_offsets // [num_tokens]
214+ torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size]
175215) {
216+ // num_tokens = batch_size * seq_len
176217 int64_t num_tokens = cos_sin_cache_offsets.size (0 );
177- int num_heads = query.size (-1 ) / head_size;
178- int num_kv_heads = key.size (-1 ) / head_size;
179- int64_t query_stride = query.stride (-2 );
180- int64_t key_stride = key.stride (-2 );
218+ TORCH_CHECK (
219+ positions.size (0 ) == num_tokens || positions.numel () == num_tokens,
220+ " positions must have the same num_tokens or batch_size as "
221+ " cos_sin_cache_offsets" );
222+
223+ int positions_ndim = positions.dim ();
224+ // Make sure num_tokens dim is consistent across positions, query, and key.
225+ TORCH_CHECK (
226+ positions_ndim == 1 || positions_ndim == 2 ,
227+ " positions must have shape [num_tokens] or [batch_size, seq_len]" );
228+ if (positions_ndim == 1 ) {
229+ TORCH_CHECK (
230+ query.size (0 ) == positions.size (0 ) && key.size (0 ) == positions.size (0 ),
231+ " query, key and positions must have the same number of tokens" );
232+ }
233+ if (positions_ndim == 2 ) {
234+ TORCH_CHECK (
235+ query.size (0 ) == positions.size (0 ) &&
236+ key.size (0 ) == positions.size (0 ) &&
237+ query.size (1 ) == positions.size (1 ) &&
238+ key.size (1 ) == positions.size (1 ),
239+ " query, key and positions must have the same batch_size and seq_len" );
240+ }
241+
242+ // Make sure head_size is valid for query and key
243+ int query_hidden_size = query.numel () / num_tokens;
244+ int key_hidden_size = key.numel () / num_tokens;
245+ TORCH_CHECK (query_hidden_size % head_size == 0 );
246+ TORCH_CHECK (key_hidden_size % head_size == 0 );
247+
248+ // Make sure query and key have consistent number of heads
249+ int num_heads = query_hidden_size / head_size;
250+ int num_kv_heads = key_hidden_size / head_size;
251+ TORCH_CHECK (num_heads % num_kv_heads == 0 );
252+
253+ int seq_dim_idx = positions_ndim - 1 ;
254+ int64_t query_stride = query.stride (seq_dim_idx);
255+ int64_t key_stride = key.stride (seq_dim_idx);
181256
182257 dim3 grid (num_tokens);
183258 dim3 block (std::min<int64_t >(num_heads * rot_dim / 2 , 512 ));
0 commit comments