Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10251,6 +10251,30 @@ def Torch_AtenDiagEmbedOp : Torch_Op<"aten.diag_embed", [
}];
}

def Torch_AtenDiagOp : Torch_Op<"aten.diag", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::diag : (Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$diagonal
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDiagOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenDiagOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_Aten_WeightNormInterfaceOp : Torch_Op<"aten._weight_norm_interface", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
56 changes: 56 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7025,6 +7025,58 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %18 = torch.aten.append.t %7, %17 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" return %7 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.diag\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: input must be 1D or 2D\"\n"
" %true = torch.constant.bool true\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %7 = torch.aten.eq.int %6, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
" %6 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.operator \"prim.abs.int\"(%arg1) : (!torch.int) -> !torch.int \n"
" %8 = torch.aten.add.int %6, %7 : !torch.int, !torch.int -> !torch.int\n"
" %9 = torch.prim.ListConstruct %8, %8 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %9 : !torch.list<int>\n"
" } else {\n"
" %6 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %8 = torch.aten.sub.int %7, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %9 = torch.prim.min.int %6, %8 : !torch.int, !torch.int -> !torch.int\n"
" %10 = torch.prim.max.int %9, %int0 : !torch.int, !torch.int -> !torch.int\n"
" %11 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %12 = torch.prim.If %11 -> (!torch.int) {\n"
" %14 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.aten.add.int %14, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %17 = torch.prim.min.int %15, %16 : !torch.int, !torch.int -> !torch.int\n"
" %18 = torch.prim.max.int %17, %int0 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %18 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %10 : !torch.int\n"
" }\n"
" %13 = torch.prim.ListConstruct %12 : (!torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %13 : !torch.list<int>\n"
" }\n"
" return %5 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_tensor_affine\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -14301,6 +14353,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.diag\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.uniform\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
43 changes: 43 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10354,6 +10354,48 @@ class DecomposeAtenL1LossOp : public OpRewritePattern<AtenL1LossOp> {
};
} // namespace

namespace {
// `aten.diag` is a rank-dependent alias (aten/src/ATen/native/TensorShape.cpp):
// a 1-D input builds a 2-D matrix with the input on the `diagonal`-th diagonal
// (aten.diag_embed); a 2-D input extracts the `diagonal`-th diagonal as a 1-D
// tensor (aten.diagonal).
class DecomposeAtenDiagOp : public OpRewritePattern<AtenDiagOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenDiagOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
std::optional<unsigned> maybeRank = getTensorRank(self);
if (!maybeRank)
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
unsigned rank = *maybeRank;
if (rank != 1 && rank != 2)
return rewriter.notifyMatchFailure(op,
"expected input to be rank 1 or 2");

Value zero =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
Value one =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
Value diagonal = op.getDiagonal();

if (rank == 1) {
Value embed = AtenDiagEmbedOp::create(rewriter, loc, op.getType(), self,
/*offset=*/diagonal, /*dim1=*/zero,
/*dim2=*/one);
rewriter.replaceOp(op, embed);
return success();
}
Value diag = AtenDiagonalOp::create(rewriter, loc, op.getType(), self,
/*offset=*/diagonal, /*dim1=*/zero,
/*dim2=*/one);
rewriter.replaceOp(op, diag);
return success();
}
};
} // namespace

namespace {
// Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op
class DecomposeAtenNormScalarOptDimOp
Expand Down Expand Up @@ -13620,6 +13662,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenL1LossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenDiagOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNormScalarOptDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLerpTensorOp>();
target.addIllegalOp<AtenMseLossOp>();
target.addIllegalOp<AtenL1LossOp>();
target.addIllegalOp<AtenDiagOp>();
target.addIllegalOp<AtenPoissonNllLossOp>();
target.addIllegalOp<AtenRandintLowOp>();
target.addIllegalOp<AtenRandintOp>();
Expand Down
19 changes: 19 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@
}

LINALG_CRASHING_SET = {
# aten.diag with a nonzero offset lowers to aten.diag_embed, whose linalg
# lowering does an out-of-bounds extract (same crash as AtenDiagEmbedOffsetDiag).
"Diag1DOffsetModule_basic",
# Runtime op verification: Out of bounds access
"AtenDiagEmbedNegOffsetDiag_basic",
"AtenDiagEmbedNonDefault4DDiag_basic",
Expand Down Expand Up @@ -550,6 +553,12 @@
}

FX_IMPORTER_STABLEHLO_XFAIL_SET = {
# aten.diag decomposes to aten.diag_embed / aten.diagonal, which are xfail
# on the stablehlo backend (same as AtenDiagEmbed*/DiagonalModule).
"Diag1DModule_basic",
"Diag1DOffsetModule_basic",
"Diag2DModule_basic",
"Diag2DNegativeOffsetModule_basic",
"ArgsortTensor_basic",
"ArgsortTensorInteger_basic",
"AddFloatIntModule_basic",
Expand Down Expand Up @@ -2779,6 +2788,11 @@
}

ONNX_XFAIL_SET = {
# aten.diag has no ONNX import path, so these fail at import time.
"Diag1DModule_basic",
"Diag1DOffsetModule_basic",
"Diag2DModule_basic",
"Diag2DNegativeOffsetModule_basic",
"ToDtypeIntFromFloatModule_basic",
# This test is expected to time out
"TimeOutModule_basic",
Expand Down Expand Up @@ -4101,6 +4115,11 @@
}

ONNX_TOSA_XFAIL_SET = {
# aten.diag has no ONNX import path, so these fail at import time.
"Diag1DModule_basic",
"Diag1DOffsetModule_basic",
"Diag2DModule_basic",
"Diag2DNegativeOffsetModule_basic",
"AtenFftRfft2DLastDim_basic",
"AtenFftRfft2DMiddleDim_basic",
"AtenStftCenter1D_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,26 @@ def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim

return diagonal

@check_shape_function([
Invocation(TensorOfShape(3)), # 1-D: builds a square matrix.
Invocation(TensorOfShape(3), diagonal=2), # 1-D positive offset.
Invocation(TensorOfShape(3), diagonal=-2), # 1-D negative offset.
Invocation(TensorOfShape(3, 5)), # 2-D: extracts the main diagonal.
Invocation(TensorOfShape(3, 5), diagonal=1), # 2-D positive offset.
Invocation(TensorOfShape(3, 5), diagonal=-1), # 2-D negative offset.
Invocation(TensorOfShape(5, 3), diagonal=3), # 2-D empty result (offset past last column).
ErrorInvocation(TensorOfShape(2, 3, 4)), # Rank > 2 unsupported.
])
def aten〇diag〡shape(self: List[int], diagonal: int = 0) -> List[int]:
assert len(self) == 1 or len(self) == 2, "input must be 1D or 2D"
if len(self) == 1:
side = self[0] + abs(diagonal)
return [side, side]
diag_size = max(min(self[0], self[1] - diagonal), 0)
if diagonal < 0:
diag_size = max(min(self[0] + diagonal, self[1]), 0)
return [diag_size]

def aten〇fake_quantize_per_tensor_affine〡shape(self: List[int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> List[int]:
return upstream_shape_functions.unary(self)

Expand Down Expand Up @@ -4065,6 +4085,11 @@ def aten〇diagonal〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, d
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇diag〡dtype(self_rank_dtype: Tuple[int, int], diagonal: int = 0) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., to: float = 1., generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)"
)
emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)")
emit("aten::diag : (Tensor, int) -> (Tensor)")
emit("aten::_weight_norm_interface : (Tensor, Tensor, int) -> (Tensor, Tensor)")
emit("aten::rot90 : (Tensor, int, int[]) -> (Tensor)", has_verifier=True)
emit("aten::count_nonzero : (Tensor, int?) -> (Tensor)", has_verifier=True)
Expand Down
92 changes: 92 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,95 @@ def forward(self, a):
@register_test_case(module_factory=lambda: DiagonalWithDimsOffsetModule())
def DiagonalModule_with_dims_and_offset(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))


# ==============================================================================


class Diag1DModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.diag(a)


@register_test_case(module_factory=lambda: Diag1DModule())
def Diag1DModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5))


# ==============================================================================


class Diag1DOffsetModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.diag(a, diagonal=2)


@register_test_case(module_factory=lambda: Diag1DOffsetModule())
def Diag1DOffsetModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5))


# ==============================================================================


class Diag2DModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.diag(a)


@register_test_case(module_factory=lambda: Diag2DModule())
def Diag2DModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 6))


# ==============================================================================


class Diag2DNegativeOffsetModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.diag(a, diagonal=-1)


@register_test_case(module_factory=lambda: Diag2DNegativeOffsetModule())
def Diag2DNegativeOffsetModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 6))
Loading