[Torch] Add aten.smooth_l1_loss decomposition #4614
Draft
alex1xu wants to merge 1 commit into
aten.smooth_l1_loss had no first-class torch dialect op, so the FX importer emitted it as a generic torch.operator and LowerToBackendContract failed to legalize it. Add an ODS op, abstract-interp shape/dtype functions, and a decomposition into elementwise ops plus a reduction, mirroring the PyTorch reference: with z = |self - target|, elem = 0.5*z*z/beta when z < beta else z - 0.5*beta, then the reduction. beta == 0 degenerates to L1 and beta < 0 is rejected via notifyMatchFailure, matching PyTorch's runtime check. Part of llvm#4575
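The piecewise definition above can be sketched in plain Python. This is a minimal reference sketch of the semantics, not the decomposition code from this PR: the name `smooth_l1_loss_ref`, the string-valued `reduction` argument, and the use of `ValueError` as a stand-in for the rejection of negative beta are all illustrative assumptions.

```python
def smooth_l1_loss_ref(inputs, targets, reduction="mean", beta=1.0):
    """Reference smooth L1 loss over two equal-length lists of floats.

    Hypothetical helper for illustration only; it mirrors the formula in
    the commit message, not the actual torch-mlir decomposition.
    """
    if beta < 0:
        # Stand-in for the negative-beta rejection described above.
        raise ValueError("beta must be non-negative")
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")

    elems = []
    for x, t in zip(inputs, targets):
        z = abs(x - t)
        if beta == 0:
            # Degenerate case: plain L1, which avoids 0/0 in the quadratic term.
            elems.append(z)
        elif z < beta:
            # Quadratic branch (strict <).
            elems.append(0.5 * z * z / beta)
        else:
            # Linear branch.
            elems.append(z - 0.5 * beta)

    if reduction == "none":
        return elems
    total = sum(elems)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(elems)
    raise ValueError(f"unknown reduction: {reduction!r}")
```

For example, `smooth_l1_loss_ref([0.0, 5.0], [1.0, 0.0], "none", 3.0)` takes the quadratic branch for the first element (`z = 1 < 3`) and the linear branch for the second (`z = 5`, giving `5 - 1.5 = 3.5`).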
`torch.nn.functional.smooth_l1_loss` fails to compile through the torchdynamo-export-to-torch-backend pipeline with `error: failed to legalize operation 'torch.operator' that was explicitly marked illegal` (KernelBench level1/96). The FX importer emits `torch.operator "aten.smooth_l1_loss"` because there is no first-class op, and `LowerToBackendContract` aborts on it.

This adds `aten.smooth_l1_loss` as a first-class op and decomposes it into elementwise ops plus a reduction, following the PyTorch kernel (`aten/src/ATen/native/cpu/BinaryOpsKernel.cpp`): with `z = |self - target|`, the per-element loss is `0.5*z*z/beta` when `z < beta` and `z - 0.5*beta` otherwise, then the reduction.

`DecomposeAtenSmoothL1LossOp` builds the piecewise form with `aten.where.self` over an `aten.lt.Scalar` mask (strict `<`, matching the kernel) and follows `DecomposeAtenL1LossOp` for the reduction (none keeps the elementwise shape, mean is `sum/numel`, sum is `sum`).

`beta == 0` is emitted directly as L1 (`abs(self - target)`): PyTorch has no special case, but `z < 0` is never true so its kernel always takes the linear branch `z - 0`, which equals L1, and the shortcut avoids a 0/0 in the quadratic term. `beta < 0` is rejected with `notifyMatchFailure`, matching PyTorch's `TORCH_CHECK(beta >= 0)`.

ODS, shape, and dtype mirror `aten.l1_loss` (the closest precedent) with the added `Torch_FloatType:$beta`; the pattern registration and `target.addIllegalOp<AtenSmoothL1LossOp>()` sit next to the L1 entries.

Testing

- `test/Dialect/Torch/decompose-complex-ops.mlir`: `smooth_l1_loss_mean` (beta 3.0, so the `1.5`/`0.5`/`3.0` constants are distinct and prove beta routing through `mul.Scalar`/`div.Scalar`/`lt.Scalar`); `smooth_l1_loss_sum` (`CHECK-NOT: numel`); `smooth_l1_loss_beta_zero` (degenerates to `abs`, `CHECK-NOT: where.self`); `smooth_l1_loss_negative_beta` (op left undecomposed, `CHECK-NOT` on `where.self`/`abs`).
- `test_suite/reduction.py`: the three reduction modes; a non-default beta (`SmoothL1LossBetaModule`, whose inputs straddle the quadratic/linear branches); and mixed input dtypes (`SmoothL1LossDifferentElemTypeModule`, exercising the dtype-promotion path).
- `xfail_sets.py` updated: ONNX has no import path; the decomposition's ops lower on linalg/TOSA/stablehlo.

Part of #4575.
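The choice of beta 3.0 in the lit test can be checked with a little arithmetic. The sketch below only lists the scalar constants that the piecewise formula needs for a given beta; the helper names are hypothetical, and how each constant maps onto a specific op in the emitted IR is not modeled here.

```python
def decomposition_constants(beta):
    """Scalar constants the piecewise formula needs for a given beta.

    Illustrative helper only; the key names are assumptions, not
    identifiers from the PR.
    """
    return {
        "quadratic_factor": 0.5,      # the 0.5 in 0.5*z*z/beta
        "beta": beta,                 # divisor and comparison threshold
        "linear_offset": 0.5 * beta,  # the 0.5*beta in z - 0.5*beta
    }


def all_distinct(constants):
    """True when no two constants share the same value."""
    values = list(constants.values())
    return len(set(values)) == len(values)


print(all_distinct(decomposition_constants(1.0)))  # False: 0.5*beta collides with 0.5
print(all_distinct(decomposition_constants(3.0)))  # True: 0.5, 3.0, 1.5
```

With the default beta of 1.0, a test could not tell the `0.5` factor apart from `0.5*beta`, which is why a value such as 3.0 makes a stronger check.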