[pt2e] Skip linear+bn fusion when input is higher than 2-D #4242
andrewor14 merged 1 commit into pytorch:main
Linear always operates on the last dimension, while BatchNorm1d normalizes along dim 1 (channels). These two coincide only for 2-D inputs (N, C). For higher-rank inputs such as 3-D (N, C, L), fusing the BN parameters into the Linear weights silently produces incorrect results because the scale/shift is applied along the wrong axis.
Add an ndim check in _fuse_linear_bn_ that skips fusion and emits a warning when the linear input has more than 2 dimensions.
Fixes pytorch#4116
Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
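To make the axis mismatch concrete, here is a minimal pure-Python sketch of the arithmetic. It is not the torchao implementation: the helpers `linear`, `bn_eval`, and `fuse` and all numeric values are made up for illustration, and the shapes assume Linear(3, 5) followed by BatchNorm1d(5) in eval mode. Folding BN into the linear layer gives one scale/shift per output feature (last dimension), which matches BatchNorm1d only when the channel axis is also the last axis.

```python
import random

EPS = 1e-5

def linear(x, W, b):
    """Apply y = W x + b to one feature vector (the last dimension)."""
    return [sum(w * xi for w, xi in zip(row, x)) + bj for row, bj in zip(W, b)]

def bn_eval(v, c, stats):
    """Eval-mode batch norm of scalar v using the statistics of channel c."""
    mean, var, gamma, beta = stats
    return gamma[c] * (v - mean[c]) / (var[c] + EPS) ** 0.5 + beta[c]

def fuse(W, b, stats):
    """Fold BN into the linear layer: one scale/shift per output feature."""
    mean, var, gamma, beta = stats
    scale = [g / (v + EPS) ** 0.5 for g, v in zip(gamma, var)]
    W_f = [[s * w for w in row] for s, row in zip(scale, W)]
    b_f = [s * (bj - m) + be for s, bj, m, be in zip(scale, b, mean, beta)]
    return W_f, b_f

random.seed(0)
C = 5  # out_features == BatchNorm1d num_features
W = [[random.uniform(-1, 1) for _ in range(3)] for _ in range(C)]
b = [random.uniform(-1, 1) for _ in range(C)]
stats = (
    [0.0, 0.5, -0.5, 1.0, -1.0],  # running mean (illustrative values)
    [0.25, 1.0, 4.0, 0.5, 2.0],   # running var
    [0.5, 1.0, 1.5, 2.0, 2.5],    # gamma
    [0.1, -0.2, 0.3, -0.4, 0.5],  # beta
)
W_f, b_f = fuse(W, b, stats)

def rand_vec():
    return [random.uniform(-1, 1) for _ in range(3)]

# 2-D input (N, 3): linear output is (N, C), so the BN channel index
# is the output-feature index j, exactly what the fused layer scales by.
x2 = [rand_vec() for _ in range(4)]
err_2d = max(
    abs(bn_eval(y, j, stats) - yf)
    for row in x2
    for j, (y, yf) in enumerate(zip(linear(row, W, b), linear(row, W_f, b_f)))
)

# 3-D input (N, C, 3): linear output is (N, C, C). BN uses dim 1 (index c)
# as the channel, but the fused layer still scales by the last-dim index j.
x3 = [[rand_vec() for _ in range(C)] for _ in range(4)]
err_3d = max(
    abs(bn_eval(y, c, stats) - yf)
    for sample in x3
    for c, row in enumerate(sample)
    for y, yf in zip(linear(row, W, b), linear(row, W_f, b_f))
)

print(f"2-D max abs error: {err_2d:.2e}")
print(f"3-D max abs error: {err_3d:.2e}")
```

In this sketch the 2-D error stays at floating-point rounding level, while the 3-D error is large. Note that the 3-D case only runs without a shape error when the middle dimension happens to equal the number of BN features, which is part of why the bug is silent.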
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4242
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 1 Pending
As of commit 9d9a724 with merge base 6e7a6e9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "module: pt2e_quant"
@claude review please?
Claude finished @andrewor14's task in 2m 1s
Code Review: [pt2e] Skip linear+bn fusion when input is higher than 2-D
Overall: This is a correct and well-targeted fix. The root cause analysis is accurate — A few observations: Production code (
Looks good, thanks for fixing this! Let's wait on CI and the Claude review.
Summary
Fix _fuse_linear_bn_ silently producing incorrect results when the linear input has more than 2 dimensions.
- Linear always operates on the last dimension, while BatchNorm1d normalizes along dim 1 (channels). These coincide only for 2-D inputs (N, C). For 3-D inputs (N, C, L), the BN scale/shift is applied along the wrong axis during fusion, producing silently incorrect output.
- Check linear_input.ndim via node.meta["val"] before fusion. Skip fusion with a warning when ndim > 2.
Fixes #4116
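The ndim guard described in this summary could look roughly like the sketch below. This is not the code from the PR: `should_fuse_linear_bn` is a hypothetical helper, the warning text is invented, and falling back to fusion when no `"val"` metadata is present is an assumption. The only thing taken from the summary is the idea of reading the rank from `node.meta["val"]` and skipping with a warning when it exceeds 2. Plain `SimpleNamespace` objects stand in for FX nodes and fake tensors so the snippet runs without PyTorch.

```python
import warnings
from types import SimpleNamespace

def should_fuse_linear_bn(linear_input_node):
    """Hypothetical helper: return False when the linear input has rank > 2.

    Expects a node-like object with a `meta` dict whose "val" entry exposes
    `.ndim`. If the rank is unknown, this sketch assumes fusion proceeds.
    """
    val = linear_input_node.meta.get("val")
    ndim = getattr(val, "ndim", None)
    if ndim is not None and ndim > 2:
        warnings.warn(
            f"Skipping linear+bn fusion: linear input has {ndim} dimensions, "
            "but BatchNorm1d normalizes dim 1 while Linear acts on the last dim."
        )
        return False
    return True

# Stand-ins for FX nodes carrying tensor metadata (illustrative only).
node_2d = SimpleNamespace(meta={"val": SimpleNamespace(ndim=2)})
node_3d = SimpleNamespace(meta={"val": SimpleNamespace(ndim=3)})

fuse_2d = should_fuse_linear_bn(node_2d)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fuse_3d = should_fuse_linear_bn(node_3d)

print(fuse_2d, fuse_3d, len(caught))
```

Returning a boolean rather than raising keeps the rest of the graph pass running, so an unfusable pattern leaves the BatchNorm node in place instead of aborting the whole transformation.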
Reproduction
Model: Linear(3, 5) → BatchNorm1d(5) with trained (non-trivial) BN statistics.
Before (3-D input fused incorrectly)
After (3-D fusion skipped, output correct)
Test results
Test plan
- test_linear_bn_fusion: existing test still passes (2-D fusion works)
- test_linear_bn_fusion_skipped_for_3d_input: for 3-D input, fusion is skipped, a warning is emitted, and the output matches the reference
- test_linear_bn_fusion_correct_for_2d_input: for 2-D input with multiple dimension combos and bias variants, fusion is correct and BN is removed
- ruff check + ruff format pass