
[pt2e] Skip linear+bn fusion when input is higher than 2-D #4242

Merged

andrewor14 merged 1 commit into pytorch:main from Lidang-Jiang:fix/linear-bn-fusion-4116 on Apr 7, 2026

Conversation

@Lidang-Jiang
Contributor

Summary

Fix _fuse_linear_bn_ silently producing incorrect results when the linear input has more than 2 dimensions.

  • Root cause: Linear always operates on the last dimension, while BatchNorm1d normalizes along dim 1 (channels). These coincide only for 2-D inputs (N, C). For 3-D inputs (N, C, L), the BN scale/shift is applied along the wrong axis during fusion, producing silently incorrect output.
  • Fix: Check linear_input.ndim via node.meta["val"] before fusion. Skip fusion with a warning when ndim > 2.
  • Tests: Added 2 new test cases covering the 3-D skip and 2-D correctness.

Fixes #4116
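The guard described above can be sketched as follows. This is a simplified illustration, not the actual code in `torchao/quantization/pt2e/utils.py`: `should_fuse_linear_bn` is a hypothetical helper, and `val` stands in for the fake-tensor metadata the real code reads from `linear_input_node.meta["val"]`.

```python
import warnings

# Simplified sketch of the ndim guard this PR adds to _fuse_linear_bn_.
# `val` plays the role of node.meta["val"]: a fake tensor, or None when
# the graph was not traced with fake tensors.
def should_fuse_linear_bn(val, node_name="linear"):
    """Return True when linear+bn fusion is dimensionally safe."""
    if val is None:
        # No shape metadata: preserve the old behavior and attempt fusion.
        return True
    if val.ndim > 2:
        warnings.warn(
            f"Not fusing linear+bn for node '{node_name}': the linear input "
            f"is {val.ndim}-D so Linear and BatchNorm operate on different "
            "dimensions"
        )
        return False
    return True
```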

Reproduction

Model: Linear(3, 5) followed by BatchNorm1d(5), with trained (non-trivial) BN statistics.

Before (3-D input fused incorrectly)
============================================================
Test: 3D input (BUG: dims differ)
Input shape: torch.Size([2, 5, 3])
============================================================
BN fused: True
Max absolute diff:  3.742362e-01
Mean absolute diff: 6.834684e-02
RESULT: FAIL - silent incorrect output (max diff = 3.742362e-01)

============================================================
Test: 2D input (correct fusion)
Input shape: torch.Size([2, 3])
============================================================
BN fused: True
Max absolute diff:  2.384186e-07
Mean absolute diff: 6.295740e-08
RESULT: PASS

============================================================
Summary:
  3D (bug case): FAIL (diff=3.742362e-01)
  2D (correct):  PASS (diff=2.384186e-07)
============================================================
After (3-D fusion skipped, output correct)
torchao/quantization/pt2e/utils.py:990: UserWarning: Not fusing linear+bn for node 'linear': the linear input is 3-D so Linear and BatchNorm operate on different dimensions

============================================================
Test: 3D input (BUG: dims differ)
Input shape: torch.Size([2, 5, 3])
============================================================
BN fused: False
Max absolute diff:  0.000000e+00
Mean absolute diff: 0.000000e+00
RESULT: PASS

============================================================
Test: 2D input (correct fusion)
Input shape: torch.Size([2, 3])
============================================================
BN fused: True
Max absolute diff:  2.384186e-07
Mean absolute diff: 6.295740e-08
RESULT: PASS

============================================================
Summary:
  3D (bug case): PASS (diff=0.000000e+00)
  2D (correct):  PASS (diff=2.384186e-07)
============================================================
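The dimensional mismatch can also be reproduced outside the pt2e flow by folding BN into the linear parameters by hand. The script below is a sketch mirroring the repro setup above, not the PR's actual test code; the fusion math is the standard BN-folding formula, valid only when BN's channel dim (dim 1) is the dim Linear produces (the last dim), i.e. for 2-D inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Linear(3, 5) -> BatchNorm1d(5), mirroring the repro setup above.
linear = nn.Linear(3, 5)
bn = nn.BatchNorm1d(5).eval()
# Non-trivial BN statistics so fusion actually changes the computation.
bn.running_mean = torch.randn(5)
bn.running_var = torch.rand(5) + 0.5

# Hand-rolled fusion: fold the BN scale/shift into the linear parameters,
# implicitly assuming BN normalizes the dimension Linear produces.
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
fused = nn.Linear(3, 5)
fused.weight.data = linear.weight * scale.unsqueeze(1)
fused.bias.data = (linear.bias - bn.running_mean) * scale + bn.bias

x2d = torch.randn(2, 3)     # (N, C): BN dim 1 == Linear's last dim, exact
x3d = torch.randn(2, 5, 3)  # (N, C, L): BN dim 1 != Linear's last dim

match_2d = torch.allclose(bn(linear(x2d)), fused(x2d), atol=1e-5)
mismatch_3d = (bn(linear(x3d)) - fused(x3d)).abs().max()
print(match_2d, mismatch_3d.item())
```

For the 2-D input the fused module matches the reference to float32 tolerance; for the 3-D input the scale/shift lands on the wrong axis and the outputs diverge, which is exactly the silent corruption the PR guards against.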
Test results
$ python -m pytest test/quantization/pt2e/test_quantize_pt2e.py -k "linear_bn_fusion" -v

test_linear_bn_fusion PASSED
test_linear_bn_fusion_skipped_for_3d_input PASSED
test_linear_bn_fusion_correct_for_2d_input PASSED

3 passed in 14.10s

Test plan

  • test_linear_bn_fusion — existing test still passes (2-D fusion works)
  • test_linear_bn_fusion_skipped_for_3d_input — 3-D input: fusion skipped, warning emitted, output matches reference
  • test_linear_bn_fusion_correct_for_2d_input — 2-D input with multiple dimension combos and bias variants: fusion correct, BN removed
  • ruff check + ruff format pass
  • pre-commit hooks pass

Linear always operates on the last dimension while BatchNorm1d
normalizes along dim 1 (channels). These two coincide only for 2-D
inputs (N, C). For higher-rank inputs like 3-D (N, C, L), fusing
the BN parameters into Linear weights silently produces incorrect
results because the scale/shift is applied along the wrong axis.

Add an ndim check in _fuse_linear_bn_ that skips fusion and emits
a warning when the linear input has more than 2 dimensions.

Fixes pytorch#4116

Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
@pytorch-bot

pytorch-bot bot commented Apr 6, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4242

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 9d9a724 with merge base 6e7a6e9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 6, 2026
@Lidang-Jiang
Contributor Author

@pytorchbot label "module: pt2e_quant"

@pytorch-bot pytorch-bot bot added the module: pt2e_quant pt2 export quantization (prepare_pt2e, convert_pt2e, quantizer) label Apr 6, 2026
@andrewor14
Contributor

@claude review please?

@claude

claude bot commented Apr 7, 2026

Claude finished @andrewor14's task in 2m 1s


Code Review: [pt2e] Skip linear+bn fusion when input is higher than 2-D

  • Read the PR diff and changed files
  • Review the production code changes (torchao/quantization/pt2e/utils.py)
  • Review the test changes (test/quantization/pt2e/test_quantize_pt2e.py)
  • Post review feedback

Overall: This is a correct and well-targeted fix. The root cause analysis is accurate — Linear operates on the last dimension while BatchNorm1d normalizes along dim 1, so fusion is only valid for 2-D inputs where those coincide. The implementation is clean and the tests cover the important cases.

A few observations:

Production code (torchao/quantization/pt2e/utils.py)

  1. Redundant isinstance check (minor, nit): At line 987, isinstance(linear_input_val, torch.Tensor) is technically unnecessary because meta["val"] in an exported graph is always a FakeTensor (which is a torch.Tensor subclass) or None. The existing pattern in this codebase (e.g., torchao/quantization/pt2e/inductor_passes/x86.py:694-698) just checks for None and then uses .dim() directly. That said, the extra guard is defensive and not harmful — just a style inconsistency.

  2. Behavior when meta["val"] is missing: If linear_input_node.meta.get("val") returns None (e.g., the graph wasn't traced with fake tensors), the code falls through and attempts fusion anyway, which preserves backward compatibility. This is the right default.

  3. stacklevel=1 in warnings.warn (line 996): stacklevel=1 points to the warnings.warn call itself inside _fuse_linear_bn_. This is fine for this context since the function is internal and the warning message already identifies the problematic node by name, making it easy to diagnose. No change needed.

  4. Comment quality: The block comment (lines 974–981) explaining the dimensional mismatch is clear and includes the issue link. Good.

Tests (test/quantization/pt2e/test_quantize_pt2e.py)

  1. test_linear_bn_fusion_skipped_for_3d_input: Good coverage — it checks both that a warning is emitted and that output correctness is preserved. Testing both bias=True and bias=False is appropriate.

  2. test_linear_bn_fusion_correct_for_2d_input: This test is very similar to the existing test_linear_bn_fusion (line 186). The main additions are testing (5, 5) dimensions and verifying BN removal. Consider whether this is worth maintaining as a separate test vs. extending the existing one, though it's fine either way.

  3. Missing check in 3-D test: The 3-D test verifies that BN fusion is skipped (via warning) and output matches, but doesn't explicitly verify that BN nodes are still present in the graph. This would make the test more robust by confirming that the graph was truly left unmodified.

Summary

The fix is correct, minimal, and well-tested. The PR is ready to land with perhaps only the optional improvement of verifying BN node preservation in the 3-D skip test. Nice contribution.


@andrewor14
Contributor

Looks good, thanks for fixing this! Let's wait on CI and claude review

@andrewor14 andrewor14 merged commit 4efd83a into pytorch:main Apr 7, 2026
21 checks passed

Development

Successfully merging this pull request may close these issues.

Linear+BatchNorm fusion produces incorrect model in most cases