[Torch] Add aten.diag decomposition#4613
Draft
alex1xu wants to merge 1 commit into
Draft
Conversation
aten.diag had no first-class torch dialect op, so the FX importer emitted it as a generic torch.operator and LowerToBackendContract failed to legalize it. Add an ODS op, abstract-interp shape/dtype functions, and a decomposition: a 1-D input builds a 2-D matrix via aten.diag_embed; a 2-D input extracts the diagonal via aten.diagonal (mirroring PyTorch's TensorShape.cpp dispatch). The 1-D nonzero-offset case (Diag1DOffsetModule) is added to LINALG_CRASHING_SET because aten.diag_embed's existing linalg lowering does an out-of-bounds extract for nonzero offset (tracked separately). Part of llvm#4575
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
torch.diag(A)fails to compile through the torchdynamo-export-to-torch-backend pipeline witherror: failed to legalize operation 'torch.operator' that was explicitly marked illegal(KernelBench level1/12,return torch.diag(A) @ B). The FX importer has no first-class op foraten.diag, so it emits a generictorch.operator "aten.diag";LowerToBackendContractmarks all such ops illegal and aborts because nothing resolves or decomposes it.This adds
aten.diagas a first-class op and decomposes it.torch.diagis rank-dependent (PyTorchaten/src/ATen/native/TensorShape.cpp): a 1-D input builds a 2-D matrix with the input on thediagonal-th diagonal, and a 2-D input extracts thediagonal-th diagonal as 1-D.DecomposeAtenDiagOpmirrors that dispatch: rank 1 lowers toaten.diag_embed(self, offset=diagonal, dim1=0, dim2=1), rank 2 toaten.diagonal(self, offset=diagonal, dim1=0, dim2=1), with an earlynotifyMatchFailurefor any other rank. Reusing the existingdiag_embed/diagonallowerings rather than adding a new TorchToLinalg pattern keeps the change to a decomposition and inherits those ops' backend coverage.aten::diagintorch_ods_gen.py; shape/dtype inabstract_interp_lib_gen.py(aten〇diag〡shape/〡dtype), regenerated into the checked-inGeneratedTorchOps.td/AbstractInterpLibrary.cpp.addPatternIfTargetOpIsIllegal<DecomposeAtenDiagOp>andtarget.addIllegalOp<AtenDiagOp>()inLowerToBackendContract.cpp.aten.diagwith a nonzero offset on a 1-D input routes toaten.diag_embed, whose existing linalg lowering does an out-of-bounds extract for nonzero offset (the same crash that already listsAtenDiagEmbedOffsetDiaginLINALG_CRASHING_SET);Diag1DOffsetModuleis added there for parity, and fixing that lowering is left as separate work.Testing
test/Dialect/Torch/decompose-complex-ops.mlir:diag_1d(offset 2, lowers todiag_embed),diag_2d(offset 3, lowers todiagonal), anddiag_rank3_no_decompose(rank-3 left intact). Offsets are chosen distinct from thedim1=0/dim2=1constants so FileCheck pins operand wiring rather than matching coincidentally.aten〇diag〡shapecarries a@check_shape_functionthat validates the 1-D/2-D, positive/negative-offset, empty-result, and rank-3-error cases against eagertorch.diagat library-generation time.projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py):Diag1DModule,Diag1DOffsetModule,Diag2DModule,Diag2DNegativeOffsetModule, withxfail_sets.pyupdated — ONNX has noaten.diagimport path; stablehlo mirrors the existingdiag_embed/diagonalxfails; the 1-D nonzero-offset case isLINALG_CRASHING.Part of #4575.