fix: resolve attention mask broadcasting error when cache sizes mismatch #463
Summary
Fixes a critical `ValueError: Incompatible shapes for broadcasting` that occurs during attention computation when the attention mask cache dimension doesn't match the logits cache dimension. This bug causes runtime failures during model inference, particularly in multi-turn conversations and when using padded inputs.

Problem
From issue #407, users encounter this error during model execution:
Stack trace location:
Shape mismatch:
- `attn_mask` after `expand_dims`: `(1, 1447, 1, 5234)` ← cache_size = 5234
- `logits`: `(1, 1447, 8, 4096)` ← cache_size = 4096
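For illustration, here is a minimal, self-contained reproduction of the broadcasting failure using the shapes reported above (the variable names are illustrative, not the library's internals):

```python
import jax.numpy as jnp

# Shapes from the report: the mask's cache dim (5234) != the logits' cache dim (4096).
attn_mask = jnp.ones((1, 1447, 1, 5234), dtype=bool)
logits = jnp.zeros((1, 1447, 8, 4096))

try:
    # jnp.where must broadcast the mask against the logits; the trailing
    # dimensions 5234 vs. 4096 are incompatible, so this raises.
    jnp.where(attn_mask, logits, -1e30)
except (ValueError, TypeError) as e:
    print(type(e).__name__, e)  # incompatible shapes for broadcasting
```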
Root Cause

The attention mechanism flow (sketched with toy sizes below):
1. `attention_mask` is created with a `cache_length` that includes padding for static shapes and previous turns.
2. The attention logits are computed via `einsum` with shape `[B, L, num_heads, actual_cache_size]`.
3. `jnp.where()` then tries to broadcast the mask against the logits, and the cache size mismatch causes the failure.

This typically happens in multi-turn conversations and with padded inputs, where the mask's `cache_length` diverges from the cache size used for the logits.
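A schematic sketch of this flow with made-up toy sizes (not the real model configuration), showing where the two different cache dimensions come from:

```python
import jax.numpy as jnp

B, L, num_heads, head_dim = 1, 4, 2, 16
actual_cache_size = 8   # keys actually present for the logits
cache_length = 12       # static mask length incl. padding / previous turns

# Logits come out of the q·k einsum with the *actual* cache size ...
q = jnp.zeros((B, L, num_heads, head_dim))
k = jnp.zeros((B, actual_cache_size, num_heads, head_dim))
logits = jnp.einsum("BTNH,BSNH->BTNS", q, k)             # (1, 4, 2, 8)

# ... while the mask is built against the padded cache_length and then
# given a singleton head axis.
attn_mask = jnp.expand_dims(
    jnp.ones((B, L, cache_length), dtype=bool), axis=-2)  # (1, 4, 1, 12)

# Trailing dims 8 vs. 12 cannot broadcast inside jnp.where().
print(logits.shape, attn_mask.shape)
```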
Solution
Add defensive validation and slicing in `Attention.__call__()` to ensure the attention mask cache dimension matches the logits cache dimension before broadcasting; a sketch of the check is shown below.

Why this is safe: the mask positions beyond the logits' cache dimension correspond to cache padding, so slicing them off does not change which keys a query can actually attend to.
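The following is a minimal sketch of this kind of defensive slice, not the exact diff applied to `_modules.py`; the function name, the mask fill value, and the assumption that the surplus mask entries are trailing padding are illustrative:

```python
import jax.numpy as jnp

def apply_attn_mask(logits, attn_mask, fill_value=-1e30):
    """Masks attention logits, slicing the mask if its cache dim is too long.

    logits:    [B, L, num_heads, cache_size]
    attn_mask: [B, L, 1, mask_cache_size] with mask_cache_size >= cache_size
    """
    cache_size = logits.shape[-1]
    if attn_mask.shape[-1] != cache_size:
        # Assumes the extra mask entries are trailing padding beyond the
        # keys that actually exist in the cache, so they can be dropped.
        attn_mask = attn_mask[..., :cache_size]
    # Masked positions get a large negative logit (~zero softmax weight).
    return jnp.where(attn_mask, logits, fill_value)

# With the mismatched shapes from the report, the slice restores broadcastability.
out = apply_attn_mask(
    jnp.zeros((1, 1447, 8, 4096)),
    jnp.ones((1, 1447, 1, 5234), dtype=bool),
)
print(out.shape)  # (1, 1447, 8, 4096)
```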
Changes
1. `gemma/gm/nn/_modules.py`
Modified: Lines 271-277 (added 7 lines)
Added detailed comments explaining:
2. `gemma/gm/nn/_modules_mask_test.py` (new file)
Created: comprehensive regression test suite (a hedged sketch of one such test follows the list):
- Test 1: `test_attention_mask_broadcasting_shape_mismatch`
- Test 2: `test_attention_mask_broadcasting_with_auto_slice_fix`
- Test 3: `test_attention_mask_broadcasting_correct_shapes`
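As referenced above, here is a hedged sketch of what such a regression test could look like; the helper and test names are illustrative rather than the ones actually used in `_modules_mask_test.py`:

```python
import jax.numpy as jnp


def _apply_attn_mask(logits, attn_mask, fill_value=-1e30):
    # Re-statement of the defensive slice sketched in the Solution section.
    cache_size = logits.shape[-1]
    if attn_mask.shape[-1] != cache_size:
        attn_mask = attn_mask[..., :cache_size]
    return jnp.where(attn_mask, logits, fill_value)


def test_mask_longer_than_logits_cache_is_sliced():
    logits = jnp.zeros((1, 4, 2, 8))                 # cache_size = 8
    attn_mask = jnp.ones((1, 4, 1, 12), dtype=bool)  # padded cache_length = 12
    out = _apply_attn_mask(logits, attn_mask)
    assert out.shape == logits.shape


def test_matching_shapes_are_left_untouched():
    logits = jnp.zeros((1, 4, 2, 8))
    attn_mask = jnp.ones((1, 4, 1, 8), dtype=bool)
    out = _apply_attn_mask(logits, attn_mask)
    assert out.shape == logits.shape
```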
Testing

Regression tests:
Result: 3/3 tests passing
Existing test suite:
Result: 15/15 tests passing (no regressions)
Full test suite:
Result: 109/111 tests passing (2 expected failures for GCS checkpoints)
Performance Impact
Edge Cases Handled
Impact
Verification Steps
To verify this fix resolves the issue:
1. Reproduce the original error:
2. Verify the fix works:
3. Check that there are no regressions (e.g., by re-running the new test file, as sketched below):
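For example, assuming `pytest` is available in the environment, the regression test file added by this PR can be invoked directly from Python:

```python
import sys
import pytest

# Run only the new mask regression tests; exit code 0 means they all pass.
sys.exit(pytest.main(["-v", "gemma/gm/nn/_modules_mask_test.py"]))
```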
Related Issues
Checklist
Future Considerations
While this fix handles the immediate issue defensively, potential upstream improvements could include:
However, the defensive approach is appropriate here as it: