Fix GQA Rotary Embedding sequence length (#19801)
### Description
Previously, GQA incorrectly required the rotary cos and sin caches to have a first dimension exactly equal to the present sequence length. It now requires that dimension to be greater than or equal to the present sequence length, since, to match the Rotary Embedding op, the caches should be of length max_sequence_length. A sketch of the relaxed condition follows.
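The condition being relaxed can be sketched as below; this is illustrative, with hypothetical names rather than the kernel's own, assuming present_sequence_length is the past token count plus the tokens in the current step:

```cpp
#include <cstdint>

// Illustrative sketch of the relaxed check (names are not from the kernel).
// present_sequence_length = past tokens + tokens in the current step, so it
// is at most max_sequence_length; a cache precomputed with
// max_sequence_length rows must therefore be accepted even when it is
// longer than present_sequence_length.
bool RotaryCacheLengthOk(int64_t cache_rows, int64_t past_sequence_length,
                         int64_t sequence_length) {
  const int64_t present_sequence_length = past_sequence_length + sequence_length;
  return cache_rows >= present_sequence_length;  // previously: cache_rows == present_sequence_length
}
```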



### Motivation and Context
Fixes an issue with fusing Rotary Embedding and GQA for certain models that prefer this optimization.
aciddelgado authored Mar 6, 2024
1 parent db8d0c8 commit 8bd1335
Showing 1 changed file with 4 additions and 4 deletions.
```diff
@@ -214,13 +214,13 @@ Status CheckInputs(const Tensor* query,
                            "head_size shall be a multiple of 16. Got head_size % 16 == ",
                            head_size % 16);
   }
-  if (cos_dims[0] != present_sequence_length) {
+  if (cos_dims[0] < present_sequence_length) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "cos_cache dimension 0 must be of present_sequence_length.");
+                           "cos_cache dimension 0 should be of max_sequence_length.");
   }
-  if (sin_dims[0] != present_sequence_length) {
+  if (sin_dims[0] < present_sequence_length) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "sin_cache dimension 0 must be of present_sequence_length.");
+                           "sin_cache dimension 0 should be of max_sequence_length.");
   }
   if (cos_dims[1] != (head_size / 16) * 8) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
```
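The relaxed bound matches how fused models typically build these inputs: the cos/sin tables are precomputed once for the longest context the model supports (max_sequence_length), and each step only reads the first present_sequence_length rows. Below is a minimal sketch of such a precompute step; the function name, theta default, and row-major layout are illustrative assumptions, not ONNX Runtime code.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: build a rotary cos cache once, sized for the longest
// context the model supports. Shape: [max_sequence_length, rotary_dim / 2],
// row-major. A sin cache would be built identically with std::sin.
std::vector<float> PrecomputeCosCache(int max_sequence_length, int rotary_dim,
                                      float theta = 10000.0f) {
  const int half_dim = rotary_dim / 2;
  std::vector<float> cache(static_cast<std::size_t>(max_sequence_length) * half_dim);
  for (int pos = 0; pos < max_sequence_length; ++pos) {
    for (int i = 0; i < half_dim; ++i) {
      // Standard RoPE frequency schedule: theta^(-2i / rotary_dim).
      const float inv_freq = 1.0f / std::pow(theta, (2.0f * i) / rotary_dim);
      cache[static_cast<std::size_t>(pos) * half_dim + i] = std::cos(pos * inv_freq);
    }
  }
  return cache;
}
```

Since present_sequence_length grows step by step but never exceeds max_sequence_length, a cache with more rows than the current step needs is valid input, which is why the check is now `<` (rejecting only caches that are too short) rather than `!=`.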
