fix gqa rotary dim 1 (#19874)
### Description
The GQA rotary embedding dimension (dimension 1 of the cos/sin caches) was incorrectly assumed to be derived from the head size; it is now taken from the cos/sin cache shapes instead.



### Motivation and Context
This change should enable us to run phi-2 with GQA and Rotary Embedding
fused.
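
As a rough illustration (assuming phi-2's commonly cited configuration of head_size = 80 with a partial rotary factor of 0.4, i.e. a rotary dimension of 32), the old head-size-based formula and the new cache-derived value disagree:

```cpp
// Minimal sketch, not ORT code: contrasts the old head-size-derived rotary
// dimension with the value implied by the cos/sin cache shape. The phi-2
// numbers below (head_size = 80, rotary_dim = 32) are assumptions taken from
// its commonly cited config, not from this commit.
#include <iostream>

int main() {
  const int head_size = 80;       // per-head dimension
  const int cos_cache_dim1 = 16;  // cos/sin caches hold rotary_dim / 2 values

  const int old_rotary_dim = (head_size / 16) * 16;  // 80: assumes an "almost full head" rotation
  const int new_rotary_dim = cos_cache_dim1 * 2;     // 32: read from the cache shape

  std::cout << "old=" << old_rotary_dim << " new=" << new_rotary_dim << "\n";
  return 0;
}
```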
aciddelgado authored Mar 13, 2024
1 parent e771a76 commit 8eb49c5
Showing 5 changed files with 14 additions and 4 deletions.
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -98,6 +98,7 @@ struct GroupQueryAttentionParameters {
int kv_hidden_size;
int kv_num_heads;
int num_splits; // number of splits for splitkv
+ int rotary_dim; // rotary embedding dimension
bool is_unidirectional; // causal
int local_window_size;
bool kv_share_buffer;
@@ -371,6 +371,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int seqlen_q,
int seqlen_k,
int seqlen_k_new,
+ int rotary_dim,
const float softmax_scale,
bool is_causal,
bool is_bf16,
@@ -448,7 +449,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
params.rotary_cos_ptr = rotary_cos;
params.rotary_sin_ptr = rotary_sin;
params.is_rotary_interleaved = is_rotary_interleaved;
- params.rotary_dim = (head_size / 16) * 16;
+ params.rotary_dim = rotary_dim;
}

params.num_splits = num_splits;
@@ -96,6 +96,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int seqlen_q,
int seqlen_k,
int seqlen_k_new,
+ int rotary_dim,
const float softmax_scale,
bool is_causal,
bool is_bf16,
11 changes: 9 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -205,6 +205,7 @@ Status CheckInputs(const Tensor* query,
int total_sequence_length = *((*total_seqlen).template Data<int32_t>());
int present_sequence_length = std::max(total_sequence_length, past_sequence_length);

+ int rotary_dim = 0;
if (cos_cache != nullptr && sin_cache != nullptr) {
const auto& cos_dims = cos_cache->Shape().GetDims();
const auto& sin_dims = sin_cache->Shape().GetDims();
@@ -222,14 +223,19 @@ Status CheckInputs(const Tensor* query,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 0 should be of max_sequence_length.");
}
- if (cos_dims[1] != (head_size / 16) * 8) {
+ if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
}
- if (sin_dims[1] != (head_size / 16) * 8) {
+ if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
}
+ if (cos_dims[1] != sin_dims[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "cos_cache and sin_cache dimension 1 must be the same.");
+ }
+ rotary_dim = static_cast<int>(cos_dims[1] * 2);
} else if (cos_cache != nullptr || sin_cache != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
@@ -248,6 +254,7 @@ Status CheckInputs(const Tensor* query,
output_parameters->head_size = head_size;
output_parameters->kv_hidden_size = kv_hidden_size;
output_parameters->kv_num_heads = kv_num_heads;
+ output_parameters->rotary_dim = rotary_dim;
output_parameters->is_packed_qkv = is_packed_qkv;
output_parameters->is_unidirectional = true;
output_parameters->is_prompt = is_prompt;
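For reference, the validation and derivation added to CheckInputs above can be read as the following standalone sketch (a hypothetical helper, not part of the commit):

```cpp
// Hypothetical standalone restatement of the relaxed cos/sin cache checks
// added above; DeriveRotaryDim is illustrative and not an ORT function.
#include <stdexcept>

int DeriveRotaryDim(int head_size, int cos_dim1, int sin_dim1) {
  // Dimension 1 of each cache holds rotary_dim / 2 entries, so it may be at
  // most head_size / 2 and must be a multiple of 8 (rotary_dim a multiple of 16).
  if (cos_dim1 > (head_size / 16) * 8 || cos_dim1 % 8 != 0) {
    throw std::invalid_argument("cos_cache dim 1 must be <= head_size / 2 and a multiple of 8");
  }
  if (sin_dim1 > (head_size / 16) * 8 || sin_dim1 % 8 != 0) {
    throw std::invalid_argument("sin_cache dim 1 must be <= head_size / 2 and a multiple of 8");
  }
  if (cos_dim1 != sin_dim1) {
    throw std::invalid_argument("cos_cache and sin_cache dim 1 must match");
  }
  return cos_dim1 * 2;  // rotary_dim, e.g. 16 * 2 = 32
}
```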
@@ -530,7 +530,7 @@ Status FlashAttention(
device_prop, stream, query, present_key, present_value, key, value, data.output,
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache,
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
- parameters.seqlen_present_kv_cache, kv_sequence_length,
+ parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved,
parameters.is_packed_qkv));
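Downstream of this call, the flash kernels rotate only the first rotary_dim components of each head vector; a rough scalar sketch of interleaved partial rotary application (illustrative only, not the CUDA kernel) looks like:

```cpp
// Rough scalar sketch of partial, interleaved rotary embedding: only the
// first rotary_dim components of a head vector are rotated, the remaining
// components pass through unchanged. Illustrative only; the actual work
// happens inside the flash attention CUDA kernels.
#include <vector>

void ApplyPartialRotaryInterleaved(std::vector<float>& head,       // length head_size
                                   const std::vector<float>& cos,  // length rotary_dim / 2
                                   const std::vector<float>& sin,  // length rotary_dim / 2
                                   int rotary_dim) {
  for (int i = 0; i < rotary_dim / 2; ++i) {
    const float x0 = head[2 * i];
    const float x1 = head[2 * i + 1];
    head[2 * i]     = x0 * cos[i] - x1 * sin[i];
    head[2 * i + 1] = x0 * sin[i] + x1 * cos[i];
  }
  // Components [rotary_dim, head_size) are left untouched.
}
```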
