Skip to content

[Torch] Add aten.smooth_l1_loss decomposition#4614

Draft
alex1xu wants to merge 1 commit into
llvm:mainfrom
alex1xu:smooth-l1-loss-decomp
Draft

[Torch] Add aten.smooth_l1_loss decomposition#4614
alex1xu wants to merge 1 commit into
llvm:mainfrom
alex1xu:smooth-l1-loss-decomp

Conversation

@alex1xu

@alex1xu alex1xu commented Jun 18, 2026

Copy link
Copy Markdown

torch.nn.functional.smooth_l1_loss fails to compile through the torchdynamo-export-to-torch-backend pipeline with error: failed to legalize operation 'torch.operator' that was explicitly marked illegal (KernelBench level1/96). The FX importer emits torch.operator "aten.smooth_l1_loss" because there is no first-class op, and LowerToBackendContract aborts on it.

This adds aten.smooth_l1_loss as a first-class op and decomposes it into elementwise ops plus a reduction, following the PyTorch kernel (aten/src/ATen/native/cpu/BinaryOpsKernel.cpp): with z = |self - target|, the per-element loss is 0.5*z*z/beta when z < beta and z - 0.5*beta otherwise, then the reduction. DecomposeAtenSmoothL1LossOp builds the piecewise form with aten.where.self over an aten.lt.Scalar mask (strict <, matching the kernel) and follows DecomposeAtenL1LossOp for the reduction (none keeps the elementwise shape, mean is sum/numel, sum is sum). beta == 0 is emitted directly as L1 (abs(self - target)): PyTorch has no special case, but z < 0 is never true so its kernel always takes the linear branch z - 0, which equals L1, and the shortcut avoids a 0/0 in the quadratic term. beta < 0 is rejected with notifyMatchFailure, matching PyTorch's TORCH_CHECK(beta >= 0).

ODS, shape, and dtype mirror aten.l1_loss (the closest precedent) with the added Torch_FloatType:$beta; the pattern registration and target.addIllegalOp<AtenSmoothL1LossOp>() sit next to the L1 entries.

Testing

  • test/Dialect/Torch/decompose-complex-ops.mlir: smooth_l1_loss_mean (beta 3.0, so the 1.5/0.5/3.0 constants are distinct and prove beta routing through mul.Scalar/div.Scalar/lt.Scalar); smooth_l1_loss_sum (CHECK-NOT: numel); smooth_l1_loss_beta_zero (degenerates to abs, CHECK-NOT: where.self); smooth_l1_loss_negative_beta (op left undecomposed, CHECK-NOT on where.self/abs).
  • e2e (test_suite/reduction.py): the three reduction modes; a non-default beta (SmoothL1LossBetaModule, whose inputs straddle the quadratic/linear branches); and mixed input dtypes (SmoothL1LossDifferentElemTypeModule, exercising the dtype-promotion path). xfail_sets.py updated — ONNX has no import path; the decomposition's ops lower on linalg/TOSA/stablehlo.

Part of #4575.

aten.smooth_l1_loss had no first-class torch dialect op, so the FX importer
emitted it as a generic torch.operator and LowerToBackendContract failed to
legalize it. Add an ODS op, abstract-interp shape/dtype functions, and a
decomposition into elementwise ops plus a reduction, mirroring the PyTorch
reference: with z = |self - target|, elem = 0.5*z*z/beta when z < beta else
z - 0.5*beta, then the reduction. beta == 0 degenerates to L1 and beta < 0 is
rejected via notifyMatchFailure, matching PyTorch's runtime check.

Part of llvm#4575
@alex1xu alex1xu force-pushed the smooth-l1-loss-decomp branch from b1eb5e1 to 057889b Compare June 18, 2026 03:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant