[Torch] Add aten.diag decomposition #4613

Draft
alex1xu wants to merge 1 commit into llvm:main from alex1xu:diag-decomp

@alex1xu alex1xu commented Jun 18, 2026

torch.diag(A) fails to compile through the torchdynamo-export-to-torch-backend pipeline with the error "failed to legalize operation 'torch.operator' that was explicitly marked illegal" (seen on KernelBench level1/12, which returns torch.diag(A) @ B). The FX importer has no first-class op for aten.diag, so it emits a generic torch.operator "aten.diag". LowerToBackendContract marks all such ops illegal and aborts because nothing resolves or decomposes it.

This adds aten.diag as a first-class op and decomposes it. torch.diag is rank-dependent (PyTorch aten/src/ATen/native/TensorShape.cpp): a 1-D input builds a 2-D matrix with the input placed on the diagonal selected by the diagonal argument, and a 2-D input extracts that diagonal as a 1-D tensor. DecomposeAtenDiagOp mirrors that dispatch: rank 1 lowers to aten.diag_embed(self, offset=diagonal, dim1=0, dim2=1), rank 2 lowers to aten.diagonal(self, offset=diagonal, dim1=0, dim2=1), and any other rank returns early with notifyMatchFailure. Reusing the existing diag_embed/diagonal lowerings, rather than adding a new TorchToLinalg pattern, keeps the change confined to a decomposition and inherits those ops' backend coverage.
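For readers who want the rank dispatch spelled out, here is a minimal pure-Python reference model over nested lists. It is a sketch of the semantics described above, not the C++ decomposition in this PR: the helper names (rank_of, diag_embed_1d, diagonal_2d, diag) are hypothetical, and the ValueError for other ranks merely stands in for the pattern declining to match.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def rank_of(x) -> int:
    """Count list nesting depth; an empty list counts as rank 1."""
    rank = 0
    while isinstance(x, list):
        rank += 1
        x = x[0] if x else None
    return rank


def diag_embed_1d(x: Vector, offset: int) -> Matrix:
    """Model of diag_embed(x, offset, dim1=0, dim2=1) for a 1-D input."""
    size = len(x) + abs(offset)
    out = [[0.0] * size for _ in range(size)]
    for i, value in enumerate(x):
        # Positive offsets shift columns right; negative offsets shift rows down.
        row = i if offset >= 0 else i - offset
        col = i + offset if offset >= 0 else i
        out[row][col] = value
    return out


def diagonal_2d(m: Matrix, offset: int) -> Vector:
    """Model of diagonal(m, offset, dim1=0, dim2=1) for a 2-D input."""
    rows = len(m)
    cols = len(m[0]) if rows else 0
    row0 = max(-offset, 0)
    col0 = max(offset, 0)
    # An offset that falls outside the matrix yields an empty result.
    length = max(min(rows - row0, cols - col0), 0)
    return [m[row0 + i][col0 + i] for i in range(length)]


def diag(x, diagonal: int = 0):
    """Rank dispatch: 1-D builds a matrix, 2-D extracts a diagonal."""
    rank = rank_of(x)
    if rank == 1:
        return diag_embed_1d(x, diagonal)
    if rank == 2:
        return diagonal_2d(x, diagonal)
    # Stand-in for the decomposition declining to match other ranks.
    raise ValueError("diag expects a 1-D or 2-D input")
```

For example, diag([1.0, 2.0], 1) produces a 3x3 matrix with the input on the first superdiagonal, while diag([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], 1) extracts [2.0, 6.0].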

  • ODS: aten::diag in torch_ods_gen.py; shape/dtype in abstract_interp_lib_gen.py (aten〇diag〡shape/〡dtype), regenerated into the checked-in GeneratedTorchOps.td / AbstractInterpLibrary.cpp.
  • Registration: addPatternIfTargetOpIsIllegal<DecomposeAtenDiagOp> and target.addIllegalOp<AtenDiagOp>() in LowerToBackendContract.cpp.

aten.diag with a nonzero offset on a 1-D input routes to aten.diag_embed, whose existing linalg lowering does an out-of-bounds extract for nonzero offsets. This is the same crash for which AtenDiagEmbedOffsetDiag is already listed in LINALG_CRASHING_SET, so Diag1DOffsetModule is added there for parity. Fixing that lowering is left as separate work.
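To illustrate why the offset matters for bounds, the sketch below writes the 1-D case as a gather over output elements in plain Python. It does not reproduce the existing linalg lowering. The function names and the choice of min(i, j) as the source index are assumptions made for illustration. The point is only that the output has len(x) + |offset| rows, so a read indexed from output coordinates is in bounds everywhere only when the offset is zero, unless it is guarded.

```python
from typing import List


def diag_embed_gather(x: List[float], offset: int) -> List[List[float]]:
    """Gather-style model: each output element picks its own source index."""
    n = len(x)
    size = n + abs(offset)
    out = []
    for i in range(size):
        row = []
        for j in range(size):
            on_diagonal = (j - i) == offset
            src = min(i, j)  # source element for a position on the diagonal
            # Guard the read: off-diagonal positions can have src >= n.
            if on_diagonal and 0 <= src < n:
                row.append(x[src])
            else:
                row.append(0.0)
        out.append(row)
    return out


def unguarded_reads_in_bounds(n: int, offset: int) -> bool:
    """True if reading x[min(i, j)] at every output position stays in bounds."""
    size = n + abs(offset)
    return all(min(i, j) < n for i in range(size) for j in range(size))
```

With offset 0 the output is n x n and every index is valid. With any nonzero offset the last output position has min(i, j) = n + |offset| - 1, which is past the end of the input.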

Testing

  • test/Dialect/Torch/decompose-complex-ops.mlir: diag_1d (offset 2, lowers to diag_embed), diag_2d (offset 3, lowers to diagonal), and diag_rank3_no_decompose (rank-3 left intact). Offsets are chosen distinct from the dim1=0/dim2=1 constants so FileCheck pins operand wiring rather than matching coincidentally.
  • aten〇diag〡shape carries a @check_shape_function that validates the 1-D/2-D, positive/negative-offset, empty-result, and rank-3-error cases against eager torch.diag at library-generation time.
  • e2e (projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py): Diag1DModule, Diag1DOffsetModule, Diag2DModule, Diag2DNegativeOffsetModule, with xfail_sets.py updated as follows: ONNX has no aten.diag import path; stablehlo mirrors the existing diag_embed/diagonal xfails; the 1-D nonzero-offset case is in LINALG_CRASHING_SET.
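The shape cases listed above (1-D/2-D, positive/negative offset, empty result, rank-3 error) can be sketched as a standalone function over List[int]. This is an illustrative approximation, not the aten〇diag〡shape function added in this PR. The name diag_shape is hypothetical, and raising ValueError is an assumption about how the rank error surfaces.

```python
from typing import List


def diag_shape(self: List[int], diagonal: int = 0) -> List[int]:
    """Result shape of diag for a 1-D or 2-D input shape."""
    if len(self) == 1:
        # 1-D input: square matrix large enough to hold the shifted diagonal.
        size = self[0] + abs(diagonal)
        return [size, size]
    if len(self) == 2:
        rows, cols = self
        if diagonal >= 0:
            length = min(rows, cols - diagonal)
        else:
            length = min(rows + diagonal, cols)
        # An offset outside the matrix gives an empty 1-D result.
        return [max(length, 0)]
    raise ValueError("diag expects a 1-D or 2-D input")
```

For instance, diag_shape([3], 2) gives [5, 5], diag_shape([3, 4], -1) gives [2], and diag_shape([3, 4], 5) gives [0].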

Part of #4575.

aten.diag had no first-class torch dialect op, so the FX importer emitted it
as a generic torch.operator and LowerToBackendContract failed to legalize it.
Add an ODS op, abstract-interp shape/dtype functions, and a decomposition:
a 1-D input builds a 2-D matrix via aten.diag_embed; a 2-D input extracts the
diagonal via aten.diagonal (mirroring PyTorch's TensorShape.cpp dispatch).

The 1-D nonzero-offset case (Diag1DOffsetModule) is added to LINALG_CRASHING_SET
because aten.diag_embed's existing linalg lowering does an out-of-bounds extract
for nonzero offset (tracked separately).

Part of llvm#4575