
Fix GPT2 attention scaling ignored in SDPA/FlashAttention #44397

Merged
vasqu merged 8 commits into huggingface:main from weiguangli-io:codex/transformers-44380-gpt2-sdpa-scaling
Mar 4, 2026

Conversation

@weiguangli-io
Contributor

What does this PR do?

Fixes #44380

`GPT2Attention.forward()` did not pass the `scaling` parameter to `attention_interface`, causing the `scale_attn_weights` and `scale_attn_by_inverse_layer_idx` config options to be silently ignored when using the SDPA or FlashAttention backends.

The eager attention implementation (`eager_attention_forward`) reads these flags directly from the module and applies scaling correctly. However, SDPA and FlashAttention rely on the `scaling` parameter passed to the attention function call, which GPT2 never provided.

Changes

  1. `modeling_gpt2.py`: Compute `self.scaling` in `GPT2Attention.__init__` by combining `scale_attn_weights` (1/√d_k) and `scale_attn_by_inverse_layer_idx` (1/(layer_idx+1)), following the same pattern used by LLaMA and other models.
  2. `modeling_gpt2.py`: Pass `scaling=self.scaling` to `attention_interface()` in `forward()`.
  3. `test_modeling_gpt2.py`: Add `test_gpt2_sdpa_matches_eager_with_scaling_configs`, which verifies that SDPA and eager produce equivalent outputs when using non-default scaling configs.
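Step 1 can be sketched as a standalone helper (hypothetical function name; the PR computes this inline in `GPT2Attention.__init__`):

```python
def combined_scaling(head_dim: int, layer_idx: int,
                     scale_attn_weights: bool = True,
                     scale_attn_by_inverse_layer_idx: bool = False) -> float:
    # 1/sqrt(d_k) when scale_attn_weights is set, otherwise no scaling.
    scaling = head_dim ** -0.5 if scale_attn_weights else 1.0
    # Extra 1/(layer_idx + 1) factor when the per-layer flag is set.
    if scale_attn_by_inverse_layer_idx:
        scaling /= float(layer_idx + 1)
    return scaling
```

For example, with `head_dim=64` the default config gives 0.125, and layer 3 with the inverse-layer-idx flag gives 0.125 / 4.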

Behavior table (before → after)

| Config | Eager | SDPA before fix | SDPA after fix |
| --- | --- | --- | --- |
| `scale_attn_weights=True` (default) | ÷√d_k | ÷√d_k (PyTorch default, coincidental match) | ÷√d_k ✓ |
| `scale_attn_weights=False` | no scaling | ÷√d_k (wrong) | no scaling ✓ |
| `scale_attn_by_inverse_layer_idx=True` | ÷(layer+1) | ignored | ÷(layer+1) ✓ |

liweiguang and others added 2 commits March 3, 2026 00:14
…ends

GPT2Attention.forward() did not pass the `scaling` parameter to
`attention_interface`, causing `scale_attn_weights` and
`scale_attn_by_inverse_layer_idx` config options to be silently
ignored when using SDPA or FlashAttention backends.

Compute the combined scaling factor in __init__ (following the pattern
used by LLaMA and other models) and forward it to the attention
interface so all backends produce consistent results.

Fixes huggingface#44380
…2Attention)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Contributor

@vasqu vasqu left a comment


Some smaller comments but we got lucky on the default case tbh. Very good finding

Comment on lines +105 to +106
if self.scale_attn_weights:
self.scaling = self.head_dim**-0.5
Contributor


I think we got lucky here because SDPA and FA will default to exactly this

Comment on lines +107 to +108
if self.scale_attn_by_inverse_layer_idx:
self.scaling /= float(self.layer_idx + 1)
Contributor


This is what was silently ignored then

Contributor


Imo, we can just use self.scaling in the eager forward as well then and copy from bert or similar (meaning the eager forward). No need for extra treatment then

Contributor Author


Thanks for the review! Addressed in the latest commit (04f9ba9):

  • self.scaling is computed once in __init__ and used in both the _upcast_and_reordered_attn path (via baddbmm alpha) and the standard attention_interface path (via scaling=self.scaling).
  • The old per-forward scale factor computation in _upcast_and_reordered_attn is removed.

So no extra treatment — just self.scaling everywhere, same pattern as bert.
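The baddbmm-alpha route mentioned above can be sketched like this (function name and shapes are illustrative, not the PR's code): with `beta=0` the placeholder input is ignored, and `alpha` scales the scores inside the same fused matmul, so no separate division is needed.

```python
import torch

def scaled_attn_scores(query: torch.Tensor, key: torch.Tensor,
                       scaling: float) -> torch.Tensor:
    # query/key: (batch, seq_len, head_dim). With beta=0 the placeholder
    # is ignored, so this computes alpha * (query @ key^T) in one call.
    placeholder = torch.empty(query.size(0), query.size(1), key.size(1),
                              dtype=query.dtype)
    return torch.baddbmm(placeholder, query, key.transpose(1, 2),
                         beta=0, alpha=scaling)
```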

Contributor


)
result.loss.backward()

def test_gpt2_sdpa_matches_eager_with_scaling_configs(self):
Contributor


Could we also check FA?

Contributor Author


Added! The FA test is at test_gpt2_fa2_matches_eager_with_scaling_configs — both tests use model.set_attn_implementation() to switch backends without reloading.

Comment thread tests/models/gpt2/test_modeling_gpt2.py Outdated
config.scale_attn_by_inverse_layer_idx = True

# Eager attention (known-correct reference)
config._attn_implementation = "eager"
Contributor


I'd rather use this after init, i.e. model.set_attn_implementation("eager")

Comment thread tests/models/gpt2/test_modeling_gpt2.py Outdated
output_eager = model_eager(input_ids, token_type_ids=token_type_ids).logits

# SDPA attention (was buggy: ignored scaling configs)
config._attn_implementation = "sdpa"
Contributor


Same here

Comment thread tests/models/gpt2/test_modeling_gpt2.py Outdated
Comment on lines +264 to +265
model_sdpa = GPT2LMHeadModel(config).to(torch_device).eval()
model_sdpa.load_state_dict(model_eager.state_dict())
Contributor


We don't need this reloading stuff; we just switch up the flags with the set_attn...

…ve tests

Per reviewer feedback:
- Replace inline scale_factor computation with self.scaling in
  _upcast_and_reordered_attn for both GPT2 and DecisionTransformer
- Use model.set_attn_implementation() instead of model reloading in tests
- Add FlashAttention2 vs eager comparison test
@weiguangli-io
Contributor Author

@vasqu Thanks for the review! Addressed all feedback:

  1. self.scaling in _upcast_and_reordered_attn: Replaced the inline scale_factor computation with self.scaling in both GPT2 and DecisionTransformer. Now all attention paths consistently use self.scaling computed once in __init__.

  2. set_attn_implementation in tests: Switched from model reloading with config._attn_implementation to model.set_attn_implementation("eager")/"sdpa" — single model, flag swap.

  3. FlashAttention2 test: Added test_gpt2_fa2_matches_eager_with_scaling_configs gated behind @require_torch_accelerator and @require_flash_attn.

Contributor

@vasqu vasqu left a comment


Thanks for iterating, left another round of smaller comments to simplify a bit / follow some conventions

Comment thread tests/models/gpt2/test_modeling_gpt2.py Outdated
with torch.no_grad():
output_eager = model(input_ids, token_type_ids=token_type_ids).logits

# FlashAttention2
Contributor


Suggested change
# FlashAttention2
# Flash Attention 2 (was buggy: ignored scaling configs)

Comment thread tests/models/gpt2/test_modeling_gpt2.py Outdated
with torch.no_grad():
output_fa2 = model(input_ids, token_type_ids=token_type_ids).logits

torch.testing.assert_close(output_eager, output_fa2, atol=5e-3, rtol=5e-3)
Contributor


Suggested change
torch.testing.assert_close(output_eager, output_fa2, atol=5e-3, rtol=5e-3)
torch.testing.assert_close(output_eager, output_fa2, atol=1e-2, rtol=1e-2)

Any reason for that atol/rtol? I remember FA being flaky at times so raising it would be personally preferred

Comment thread tests/models/gpt2/test_modeling_gpt2.py Outdated
with torch.no_grad():
output_sdpa = model(input_ids, token_type_ids=token_type_ids).logits

torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)
Contributor


Suggested change
torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)
torch.testing.assert_close(output_eager, output_sdpa, atol=1e-4, rtol=1e-4)

same here, just a small raise to avoid flakiness


Comment thread tests/models/gpt2/test_modeling_gpt2.py Outdated

torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)

@require_torch_accelerator
Contributor


Suggested change
@require_torch_accelerator
@require_torch_gpu
@mark.flash_attn_test

Contributor Author


Addressed in 802cd58:

  • Refactored eager_attention_forward to accept scaling and dropout params, matching Bert's pattern — no more manual module.scale_attn_weights / module.scale_attn_by_inverse_layer_idx checks.
  • Applied the reviewer's tolerance suggestions: SDPA atol=1e-4, FA2 atol=1e-2.
  • Added @require_torch_gpu + @pytest.mark.flash_attn_test to the FA test.
  • Fixed the FA comment text.
  • Synced DecisionTransformer's # Copied from copies.
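The refactor described in the first bullet can be sketched roughly as follows (signature and shapes are assumptions based on the Bert-style convention, not the merged code): scaling and dropout arrive as explicit parameters instead of being read from module flags.

```python
import torch

def eager_attention_forward(query, key, value, attention_mask,
                            scaling, dropout=0.0):
    # query/key/value: (batch, heads, seq_len, head_dim).
    # Scaling is applied explicitly, mirroring what SDPA/FA receive.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights
```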

…, fix test decorators and tolerances

- Refactor eager_attention_forward to accept scaling/dropout params (like Bert's pattern)
  instead of reading module.scale_attn_weights/scale_attn_by_inverse_layer_idx directly
- Reorganize GPT2Attention.__init__ to group config attrs together with clearer comment
- Sync DecisionTransformerGPT2Attention (Copied from GPT2Attention)
- Tests: use atol=1e-4/rtol=1e-4 for SDPA, atol=1e-2/rtol=1e-2 for FA2
- Tests: add @require_torch_gpu and @pytest.mark.flash_attn_test to FA2 test
- Tests: fix FA2 comment to clarify the bug being tested
Contributor

@vasqu vasqu left a comment


See my last comment but looks good

Comment thread tests/models/gpt2/test_modeling_gpt2.py Outdated
@require_flash_attn
@pytest.mark.flash_attn_test
def test_gpt2_fa2_matches_eager_with_scaling_configs(self):
"""Test that FlashAttention2 and eager produce equivalent outputs when scaling configs differ."""
Contributor


Last nit: Let's add a small reference to the issue for both added tests

Checking with our slow CI in a second but looks good to me

Contributor Author


Done in 309bb7f — added issue #44380 reference to both test docstrings.

@vasqu
Contributor

vasqu commented Mar 3, 2026

run-slow: decision_transformer, gpt2

@github-actions
Contributor

github-actions bot commented Mar 3, 2026

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/decision_transformer", "models/gpt2"]
quantizations: []

@github-actions
Contributor

github-actions bot commented Mar 3, 2026

CI Results

Workflow Run ⚙️

Commit Info

| Context | Commit | Description |
| --- | --- | --- |
| RUN | 832b6319 | workflow commit (merge commit) |
| PR | 802cd589 | branch commit (from PR) |
| main | f6bd5abb | base commit (on main) |

✅ No failing test specific to this PR 🎉 👏 !

Link both SDPA and FA2 scaling tests to the original issue huggingface#44380.
@weiguangli-io
Contributor Author

Added issue #44380 reference to both test docstrings in 309bb7f. Ready for slow CI.

@vasqu vasqu enabled auto-merge (squash) March 3, 2026 16:34
@github-actions
Contributor

github-actions bot commented Mar 4, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: decision_transformer, gpt2

@vasqu
Contributor

vasqu commented Mar 4, 2026

This is super weird, no idea why the CI is failing here

@vasqu vasqu merged commit 8757098 into huggingface:main Mar 4, 2026
26 checks passed
@vasqu
Contributor

vasqu commented Mar 4, 2026

Thanks for contributing btw 🤗



Development

Successfully merging this pull request may close these issues.

GPT2 attention scaling config is ignored when using SDPA / FlashAttention backends
