Fix GPT2 attention scaling ignored in SDPA/FlashAttention #44397
Conversation
…ends GPT2Attention.forward() did not pass the `scaling` parameter to `attention_interface`, causing the `scale_attn_weights` and `scale_attn_by_inverse_layer_idx` config options to be silently ignored when using SDPA or FlashAttention backends. Compute the combined scaling factor in `__init__` (following the pattern used by LLaMA and other models) and forward it to the attention interface so all backends produce consistent results. Fixes huggingface#44380
…2Attention)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
vasqu
left a comment
Some smaller comments but we got lucky on the default case tbh. Very good finding
```python
if self.scale_attn_weights:
    self.scaling = self.head_dim**-0.5
```
I think we got lucky here because SDPA and FA will default to exactly this
```python
if self.scale_attn_by_inverse_layer_idx:
    self.scaling /= float(self.layer_idx + 1)
```
This is what was silently ignored then
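For context, the combined factor can be sketched in plain Python (the `head_dim`/`layer_idx` values below are hypothetical, not taken from the PR):

```python
# Sketch of the combined scale factor computed once in __init__ (illustration only).
head_dim = 64      # hypothetical head dimension
layer_idx = 11     # hypothetical layer index
scale_attn_weights = True
scale_attn_by_inverse_layer_idx = True

scaling = 1.0
if scale_attn_weights:
    # 1/sqrt(d_k): identical to the SDPA/FA default, which is why the
    # default config happened to mask the missing `scaling` argument
    scaling = head_dim ** -0.5
if scale_attn_by_inverse_layer_idx:
    # the extra per-layer factor the backends never received
    scaling /= float(layer_idx + 1)

print(scaling)  # ~0.01042 for these values, far from the 0.125 default
```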
Imo, we can just use `self.scaling` in the eager forward as well and copy from bert or similar (meaning the eager forward). No need for extra treatment then
Thanks for the review! Addressed in the latest commit (04f9ba9):
- `self.scaling` is computed once in `__init__` and used in both the `_upcast_and_reordered_attn` path (via `baddbmm` alpha) and the standard `attention_interface` path (via `scaling=self.scaling`).
- The old per-forward scale factor computation in `_upcast_and_reordered_attn` is removed.

So no extra treatment, just `self.scaling` everywhere, same pattern as bert.
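A quick torch sketch (hypothetical shapes and scale, not the PR code) confirming the two application points are equivalent, i.e. folding the factor into the matmul via `baddbmm`'s `alpha` matches scaling the scores afterwards:

```python
import torch

b, t, d = 2, 3, 4                  # hypothetical batch, seq length, head dim
q = torch.randn(b, t, d)
k = torch.randn(b, t, d)
scaling = d ** -0.5 / 2.0          # some combined scale factor

# standard path: compute scores, then multiply by the scale
scores_after = torch.bmm(q, k.transpose(1, 2)) * scaling

# baddbmm path: fold the scale into the matmul via alpha (beta=0 ignores `empty`)
empty = torch.empty(b, t, t)
scores_alpha = torch.baddbmm(empty, q, k.transpose(1, 2), beta=0, alpha=scaling)

torch.testing.assert_close(scores_after, scores_alpha)
```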
Let's also address https://github.com/OiPunk/transformers/blob/04f9ba9ff0e3de8a8b21c801ba79509328ff14da/src/transformers/models/gpt2/modeling_gpt2.py#L54-L79 then
We can be closer to Bert since we unify into self.scaling instead https://github.com/OiPunk/transformers/blob/04f9ba9ff0e3de8a8b21c801ba79509328ff14da/src/transformers/models/bert/modeling_bert.py#L115-L140
```python
        )
        result.loss.backward()

    def test_gpt2_sdpa_matches_eager_with_scaling_configs(self):
```
Added! The FA test is at `test_gpt2_fa2_matches_eager_with_scaling_configs`; both tests use `model.set_attn_implementation()` to switch backends without reloading.
```python
config.scale_attn_by_inverse_layer_idx = True

# Eager attention (known-correct reference)
config._attn_implementation = "eager"
```
I'd rather use this after init, i.e. `model.set_attn_implementation("eager")`
```python
output_eager = model_eager(input_ids, token_type_ids=token_type_ids).logits

# SDPA attention (was buggy: ignored scaling configs)
config._attn_implementation = "sdpa"
model_sdpa = GPT2LMHeadModel(config).to(torch_device).eval()
model_sdpa.load_state_dict(model_eager.state_dict())
```
We don't need this reloading stuff; we just switch up the flags with the set_attn...
…ve tests

Per reviewer feedback:
- Replace inline scale_factor computation with self.scaling in _upcast_and_reordered_attn for both GPT2 and DecisionTransformer
- Use model.set_attn_implementation() instead of model reloading in tests
- Add FlashAttention2 vs eager comparison test
@vasqu Thanks for the review! Addressed all feedback:
vasqu
left a comment
Thanks for iterating, left another round of smaller comments to simplify a bit / follow some conventions
```python
with torch.no_grad():
    output_eager = model(input_ids, token_type_ids=token_type_ids).logits

# FlashAttention2
```
```diff
- # FlashAttention2
+ # Flash Attention 2 (was buggy: ignored scaling configs)
```
```python
with torch.no_grad():
    output_fa2 = model(input_ids, token_type_ids=token_type_ids).logits

torch.testing.assert_close(output_eager, output_fa2, atol=5e-3, rtol=5e-3)
```
```diff
- torch.testing.assert_close(output_eager, output_fa2, atol=5e-3, rtol=5e-3)
+ torch.testing.assert_close(output_eager, output_fa2, atol=1e-2, rtol=1e-2)
```
Any reason for that atol/rtol? I remember FA being flaky at times, so raising it would be my personal preference
```python
with torch.no_grad():
    output_sdpa = model(input_ids, token_type_ids=token_type_ids).logits

torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)
```
```diff
- torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)
+ torch.testing.assert_close(output_eager, output_sdpa, atol=1e-4, rtol=1e-4)
```
same here, just a small raise to avoid flakiness
```python
torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)

@require_torch_accelerator
```
```diff
- @require_torch_accelerator
+ @require_torch_gpu
+ @mark.flash_attn_test
```
Addressed in 802cd58:

- Refactored `eager_attention_forward` to accept `scaling` and `dropout` params, matching Bert's pattern; no more manual `module.scale_attn_weights` / `module.scale_attn_by_inverse_layer_idx` checks.
- Applied the reviewer's tolerance suggestions: SDPA `atol=1e-4`, FA2 `atol=1e-2`.
- Added `@require_torch_gpu` + `@pytest.mark.flash_attn_test` to the FA test.
- Fixed the FA comment text.
- Synced DecisionTransformer's `# Copied from` copies.
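Roughly, the Bert-style shape being matched looks like this (a sketch of the pattern, not the exact PR code):

```python
import torch
from torch import nn

def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # `scaling` arrives precomputed (self.scaling); no per-call config checks
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights
```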
…, fix test decorators and tolerances

- Refactor eager_attention_forward to accept scaling/dropout params (like Bert's pattern) instead of reading module.scale_attn_weights/scale_attn_by_inverse_layer_idx directly
- Reorganize GPT2Attention.__init__ to group config attrs together with a clearer comment
- Sync DecisionTransformerGPT2Attention (Copied from GPT2Attention)
- Tests: use atol=1e-4/rtol=1e-4 for SDPA, atol=1e-2/rtol=1e-2 for FA2
- Tests: add @require_torch_gpu and @pytest.mark.flash_attn_test to the FA2 test
- Tests: fix the FA2 comment to clarify the bug being tested
vasqu
left a comment
See my last comment but looks good
```python
@require_flash_attn
@pytest.mark.flash_attn_test
def test_gpt2_fa2_matches_eager_with_scaling_configs(self):
    """Test that FlashAttention2 and eager produce equivalent outputs when scaling configs differ."""
```
Last nit: Let's add a small reference to the issue for both added tests
Checking with our slow CI in a second but looks good to me
run-slow: decision_transformer, gpt2

This comment contains models: ["models/decision_transformer", "models/gpt2"]
Link both SDPA and FA2 scaling tests to the original issue huggingface#44380.
[For maintainers] Suggested jobs to run (before merge): run-slow: decision_transformer, gpt2

This is super weird, no idea why the CI is failing here

Thanks for contributing btw 🤗
What does this PR do?
Fixes #44380
`GPT2Attention.forward()` did not pass the `scaling` parameter to `attention_interface`, causing the `scale_attn_weights` and `scale_attn_by_inverse_layer_idx` config options to be silently ignored when using SDPA or FlashAttention backends. The eager attention implementation (`eager_attention_forward`) reads these flags directly from the module and applies scaling correctly. However, SDPA and FlashAttention rely on the `scaling` parameter passed to the attention function call, which GPT2 never provided.

Changes

- `modeling_gpt2.py`: Compute `self.scaling` in `GPT2Attention.__init__` by combining `scale_attn_weights` (1/√d_k) and `scale_attn_by_inverse_layer_idx` (1/(layer_idx+1)), following the same pattern used by LLaMA and other models.
- `modeling_gpt2.py`: Pass `scaling=self.scaling` to `attention_interface()` in `forward()`.
- `test_modeling_gpt2.py`: Add `test_gpt2_sdpa_matches_eager_with_scaling_configs`, which verifies SDPA and eager produce equivalent outputs when using non-default scaling configs.

Behavior table (before → after)

- `scale_attn_weights=True` (default)
- `scale_attn_weights=False`
- `scale_attn_by_inverse_layer_idx=True`
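For intuition, here is a dependency-free sketch (hypothetical raw scores, not from the test suite) of how the ignored factor changes the attention distribution:

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

head_dim, layer_idx = 64, 11
raw_scores = [1.0, 2.0, 3.0]                     # hypothetical q·k scores for one query

backend_scale = head_dim ** -0.5                 # what SDPA/FA applied by default
config_scale = backend_scale / (layer_idx + 1)   # requested via scale_attn_by_inverse_layer_idx

probs_backend = softmax([s * backend_scale for s in raw_scores])
probs_config = softmax([s * config_scale for s in raw_scores])

# The distributions differ, i.e. before the fix SDPA/FA silently dropped the
# 1/(layer_idx + 1) factor; with the default config the two scales coincide.
print(probs_backend, probs_config)
```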