[Shape Inference] Fix GQA shape inference for present outputs #27250
Open
Honry wants to merge 1 commit into microsoft:main from
Conversation
[Shape Inference] Fix GQA shape inference for present outputs
Issue:
When using a pre-allocated KV cache with freeDimensionOverrides, the shape
inference for the present_key and present_value outputs failed silently.
This caused downstream graph operations to receive tensors with unknown
dynamic shapes, leading to unexpected fallbacks in execution providers such
as WebNN, which currently doesn't support dynamic shapes.
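For context, a free dimension override pins a named symbolic dimension to a concrete value at session-creation time. A minimal sketch with the ONNX Runtime C++ API; the dimension names and model path below are placeholders that must match the symbolic dims in your model:

```cpp
// Sketch: pre-allocate the KV cache by pinning free (symbolic) dimensions
// before session creation. "past_sequence_length"/"total_sequence_length"
// are placeholder names, not guaranteed to match any particular model.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "gqa-demo"};
  Ort::SessionOptions options;
  options.AddFreeDimensionOverrideByName("past_sequence_length", 1024);
  options.AddFreeDimensionOverrideByName("total_sequence_length", 1024);
  Ort::Session session{env, ORT_TSTR("model.onnx"), options};
  return 0;
}
```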
Root cause:
In BaseGroupQueryAttentionTypeAndShapeInference(), the shape inference
logic for use_max_past_present_buffer == -1 only propagated shapes when
BOTH conditions were met:
1. total_sequence_length_value was a concrete value (> 0)
2. past_dims[2] had a concrete dimension value
When either condition failed (e.g., using freeDimensionOverrides, which
results in a dynamic past_sequence_length), the present output shapes were
left uninitialized.
Additionally, when past_key is not provided (prefill/first-token mode),
no shape inference was performed for the present outputs at all.
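In shape-inference terms, the gate looked roughly like the sketch below; this is an illustration built on ONNX's shape-inference helpers, not the verbatim source:

```cpp
#include <algorithm>
#include <cstdint>
#include "onnx/defs/shape_inference.h"

// Illustrative sketch of the failing gate: present_shape was only written
// when BOTH values were concrete, so a symbolic past_sequence_length
// (dim_param instead of dim_value) skipped the whole block and left
// present_key/present_value without any inferred shape.
void BuggyPresentShapeSketch(ONNX_NAMESPACE::InferenceContext& ctx,
                             const ONNX_NAMESPACE::TensorShapeProto& past_shape,
                             int64_t total_sequence_length_value) {
  if (total_sequence_length_value > 0 && past_shape.dim(2).has_dim_value()) {
    ONNX_NAMESPACE::TensorShapeProto present_shape = past_shape;
    present_shape.mutable_dim(2)->set_dim_value(
        std::max(total_sequence_length_value, past_shape.dim(2).dim_value()));
    ONNX_NAMESPACE::updateOutputShape(ctx, 1, present_shape);  // present_key
    ONNX_NAMESPACE::updateOutputShape(ctx, 2, present_shape);  // present_value
  }
  // else: nothing is propagated, so downstream nodes see unknown shapes.
}
```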
Fix:
1. For use_max_past_present_buffer == -1 (see the first sketch after this list):
   - Always construct and propagate present_shape
   - Compute present_sequence_length = max(past_sequence_length,
     total_sequence_length) when both values are concrete
   - Fall back to copying past_key's sequence dimension when the exact
     value cannot be computed
2. Add a new else-if branch to handle prefill mode, i.e., no past_key input
   (see the second sketch after this list):
   - Infer head_size from the query shape and the num_heads/kv_num_heads attrs
   - Handle both separate Q/K/V and packed QKV input formats
   - Construct the present shape from the query dims, kv_num_heads, and
     total_sequence_length or kv_sequence_length
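A hedged sketch of item 1, assuming a past_key layout of (batch_size, kv_num_heads, past_sequence_length, head_size); the function and variable names follow this description, not necessarily the merged source:

```cpp
#include <algorithm>
#include <cstdint>
#include "onnx/defs/shape_inference.h"

// Fix 1 sketch (use_max_past_present_buffer == -1): always emit a present
// shape. Use the concrete max when both values are known; otherwise keep
// past_key's (possibly symbolic) sequence dim so the output is never left
// uninitialized.
void FixedPresentShapeSketch(ONNX_NAMESPACE::InferenceContext& ctx,
                             const ONNX_NAMESPACE::TensorShapeProto& past_shape,
                             int64_t total_sequence_length_value) {
  ONNX_NAMESPACE::TensorShapeProto present_shape = past_shape;
  if (total_sequence_length_value > 0 && past_shape.dim(2).has_dim_value()) {
    present_shape.mutable_dim(2)->set_dim_value(
        std::max(total_sequence_length_value, past_shape.dim(2).dim_value()));
  }
  // Fallback: dim 2 already carries past_key's sequence dim (value or symbol).
  ONNX_NAMESPACE::updateOutputShape(ctx, 1, present_shape);  // present_key
  ONNX_NAMESPACE::updateOutputShape(ctx, 2, present_shape);  // present_value
}
```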
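And a hedged sketch of item 2's prefill branch. The packed-QKV hidden-size split and the use of the query's dim 1 as kv_sequence_length are assumptions made for illustration:

```cpp
#include <cstdint>
#include "onnx/defs/shape_inference.h"

// Fix 2 sketch (prefill: no past_key input). head_size is derived from the
// query hidden size: separate Q/K/V means hidden = num_heads * head_size;
// packed QKV means hidden = (num_heads + 2 * kv_num_heads) * head_size.
void PrefillPresentShapeSketch(ONNX_NAMESPACE::InferenceContext& ctx,
                               const ONNX_NAMESPACE::TensorShapeProto& query_shape,
                               int64_t num_heads, int64_t kv_num_heads,
                               int64_t total_sequence_length_value,
                               bool packed_qkv) {
  // Assumed query layout: (batch_size, sequence_length, hidden_size).
  if (!query_shape.dim(2).has_dim_value()) return;
  const int64_t hidden = query_shape.dim(2).dim_value();
  const int64_t head_size = packed_qkv
      ? hidden / (num_heads + 2 * kv_num_heads)
      : hidden / num_heads;

  ONNX_NAMESPACE::TensorShapeProto present_shape;
  *present_shape.add_dim() = query_shape.dim(0);        // batch_size
  present_shape.add_dim()->set_dim_value(kv_num_heads);
  if (total_sequence_length_value > 0) {
    present_shape.add_dim()->set_dim_value(total_sequence_length_value);
  } else {
    *present_shape.add_dim() = query_shape.dim(1);      // kv_sequence_length
  }
  present_shape.add_dim()->set_dim_value(head_size);
  ONNX_NAMESPACE::updateOutputShape(ctx, 1, present_shape);  // present_key
  ONNX_NAMESPACE::updateOutputShape(ctx, 2, present_shape);  // present_value
}
```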
Contributor (Author)
@guschmue, @tianleiwu, PTAL, thanks!