[Shape Inference] Fix GQA shape inference for present outputs#27250

Open
Honry wants to merge 1 commit into microsoft:main from Honry:fix-gqa-shape-inference
Conversation


@Honry Honry commented Feb 5, 2026

Description

When using a pre-allocated KV cache with freeDimensionOverrides, shape inference for the present_key and present_value outputs failed silently. Downstream graph operations then received tensors with unknown dynamic shapes, causing unexpected fallbacks in execution providers such as WebNN, which currently doesn't support dynamic shapes.

Motivation and Context

Root cause:
In BaseGroupQueryAttentionTypeAndShapeInference(), the shape inference logic for use_max_past_present_buffer == -1 only propagated shapes when BOTH conditions were met:

  1. total_sequence_length_value was a concrete value (> 0)
  2. past_dims[2] had a concrete dimension value

When either condition failed (e.g., when freeDimensionOverrides results in a dynamic past_sequence_length), the present output shapes were left uninitialized.

Additionally, when past_key/past_value are not provided (prefill/first-token mode), no shape inference was performed for the present outputs at all.

Fix:

  1. For use_max_past_present_buffer == -1:

    • Always construct and propagate present_shape
    • Compute present_sequence_length = max(past_sequence_length, total_sequence_length) when both values are concrete
    • Fall back to copying past_key's sequence dimension when exact value cannot be computed
  2. Add new else-if branch to handle prefill mode (no past_key/past_value input):

    • Infer head_size from query shape and num_heads/kv_num_heads attrs
    • Handle both separate Q/K/V and packed QKV input formats
    • Construct present shape from query dims, kv_num_heads, and total_sequence_length or kv_sequence_length


Honry commented Feb 5, 2026

@guschmue, @tianleiwu, PTAL, thanks!

