From 567ed44fd058de0fe9b6553a28b69972d722efcb Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Tue, 3 Sep 2024 10:51:03 +0530
Subject: [PATCH] [MLIR][TORCH] Add E2E support for aten.polar op (#3671)

Signed-Off By: Vivek Khandelwal
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 24 +++++++
 .../TorchToLinalg/Uncategorized.cpp           | 68 +++++++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      | 42 ++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        |  4 ++
 .../build_tools/abstract_interp_lib_gen.py    | 14 ++++
 .../build_tools/torch_ods_gen.py              |  1 +
 .../torch_mlir_e2e_test/test_suite/basic.py   | 52 ++++++++++++++
 7 files changed, 205 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index b2cd8f307f24..91c5d2fa261d 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5332,6 +5332,30 @@ def Torch_AtenSoftshrinkOp : Torch_Op<"aten.softshrink", [
   }];
 }
 
+def Torch_AtenPolarOp : Torch_Op<"aten.polar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::polar : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$abs,
+    AnyTorchTensorType:$angle
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenPolarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenPolarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 29e1e80d9732..cf4e2b4f07f0 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -3295,6 +3295,72 @@ class ConvertAtenLinalgDetOp : public OpConversionPattern<AtenLinalgDetOp> {
 };
 } // namespace
 
+namespace {
+class ConvertAtenPolarOp : public OpConversionPattern<AtenPolarOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenPolarOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op.getLoc();
+    const TypeConverter *typeConverter = getTypeConverter();
+    MLIRContext *context = rewriter.getContext();
+
+    Value absTensor = adaptor.getAbs();
+    Value angleTensor = adaptor.getAngle();
+
+    RankedTensorType resultType =
+        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
+    auto elementType = resultType.getElementType();
+
+    SmallVector<Value> resultShape;
+    for (int64_t i = 0; i < resultType.getRank(); i++) {
+      auto currentDimSize = rewriter.create<tensor::DimOp>(loc, absTensor, i);
+      resultShape.push_back(currentDimSize);
+    }
+
+    Value outTensor = rewriter.create<tensor::EmptyOp>(
+        loc, getAsOpFoldResult(resultShape), elementType);
+
+    SmallVector<AffineExpr> outputExpr;
+    for (unsigned i = 0; i < resultType.getRank(); i++) {
+      outputExpr.push_back(getAffineDimExpr(i, context));
+    }
+
+    AffineMap identityMap =
+        AffineMap::get(resultType.getRank(), 0, outputExpr, op->getContext());
+
+    SmallVector<AffineMap> indexingMaps{identityMap, identityMap, identityMap};
+    SmallVector<utils::IteratorType> iteratorTypes(
+        resultType.getRank(), utils::IteratorType::parallel);
+    auto complexVar =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outTensor.getType(), ValueRange{absTensor, angleTensor},
+                outTensor, indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  // out = abs⋅cos(angle) + abs⋅sin(angle)⋅j
+                  Value abs = args[0];
+                  Value angle = args[1];
+                  Value realVal = b.create<math::CosOp>(loc, angle);
+                  Value imagVal = b.create<math::SinOp>(loc, angle);
+                  realVal = b.create<arith::MulFOp>(loc, abs, realVal);
+                  imagVal = b.create<arith::MulFOp>(loc, abs, imagVal);
+                  Value complexVal = b.create<complex::CreateOp>(
+                      loc, elementType, realVal, imagVal);
+                  b.create<linalg::YieldOp>(loc, complexVal);
+                })
+            .getResult(0);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, complexVar);
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -3355,4 +3421,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add<ConvertInterpolateOp>(typeConverter, context);
   target.addIllegalOp<AtenLinalgDetOp>();
   patterns.add<ConvertAtenLinalgDetOp>(typeConverter, context);
+  target.addIllegalOp<AtenPolarOp>();
+  patterns.add<ConvertAtenPolarOp>(typeConverter, context);
 }
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 67c36633232f..fb82bb914017 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6715,6 +6715,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.polar\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.mish\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -11276,6 +11280,44 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.polar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %int9 = torch.constant.int 9\n"
+"    %int6 = torch.constant.int 6\n"
+"    %int10 = torch.constant.int 10\n"
+"    %int7 = torch.constant.int 7\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %false = torch.constant.bool false\n"
+"    %true = torch.constant.bool true\n"
+"    %0 = torch.prim.Uninitialized : !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %3 = torch.aten.eq.int %1#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
+"    %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n"
+"      torch.prim.If.yield %true, %int10 : !torch.bool, !torch.int\n"
+"    } else {\n"
+"      %7 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
+"      %8:2 = torch.prim.If %7 -> (!torch.bool, !torch.int) {\n"
+"        torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n"
+"      } else {\n"
+"        torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
+"      }\n"
+"      torch.prim.If.yield %8#0, %8#1 : !torch.bool, !torch.int\n"
+"    }\n"
+"    %6 = torch.prim.If %5#0 -> (!torch.int) {\n"
+"      torch.prim.If.yield %5#1 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %1#1 : !torch.int\n"
+"    }\n"
+"    return %6 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.logit\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 7000499f0700..637593cbf836 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2428,6 +2428,8 @@
     "AtenMmQMixedSigni8_basic",
     "AtenMmQint8_basic",
     "AtenMmQuint8_basic",
+    "AtenPolarFloatModule_basic",
+    "AtenPolarDoubleModule_basic",
     "AtenRealView128Module_basic",
     "AtenRealView64Module_basic",
     "AtenSubFloatModule_basic",
@@ -3794,6 +3796,8 @@
     "AtenMmQMixedSigni8_basic",
     "AtenMmQint8_basic",
     "AtenMmQuint8_basic",
+    "AtenPolarFloatModule_basic",
+    "AtenPolarDoubleModule_basic",
     "AtenRealView128Module_basic",
     "AtenRealView64Module_basic",
     "AtenRoundFloatHalfToEvenModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index b252b7e503d9..8bd60a7ef8ae 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -322,6 +322,9 @@ def aten〇hardshrink〡shape(self: List[int], lambd: float = 0.5) -> List[int]:
 def aten〇softshrink〡shape(self: List[int], lambd: float = 0.5) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇polar〡shape(abs: List[int], angle: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(abs)
+
 def aten〇mish〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -2595,6 +2598,17 @@ def aten〇softshrink〡dtype(self_rank_dtype: Tuple[int, int], lambd: Union[int
     self_rank, self_dtype = self_rank_dtype
     return _get_dtype_of_floating_point_op(self_dtype)
 
+
+def aten〇polar〡dtype(abs_rank_dtype: Tuple[int, int], angle_rank_dtype: Tuple[int, int]) -> int:
+    _, abs_dtype = abs_rank_dtype
+    _, angle_dtype = angle_rank_dtype
+    assert (abs_dtype == angle_dtype)
+    if abs_dtype == torch.float64:
+        return torch.complex128
+    elif abs_dtype == torch.float32:
+        return torch.complex64
+    return abs_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇logit〡dtype(self_rank_dtype: Tuple[int, int], eps: Optional[float] = None) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 17f44d3422b6..6fe5248bfa97 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -501,6 +501,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::log_sigmoid : (Tensor) -> (Tensor)")
     emit("aten::hardshrink : (Tensor, Scalar) -> (Tensor)")
     emit("aten::softshrink : (Tensor, Scalar) -> (Tensor)")
(Tensor, Scalar) -> (Tensor)") + emit("aten::polar : (Tensor, Tensor) -> (Tensor)") # Ops with dynamic number of outputs emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 2bda11410682..481a89b189a3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -5761,3 +5761,55 @@ def forward(self, input): @register_test_case(module_factory=lambda: UnfoldModule()) def UnfoldModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 5, 3, 4)) + + +# ============================================================================== + + +class AtenPolarFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.unfold = torch.nn.Unfold(kernel_size=(2, 3)) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, abs_, angle): + return torch.ops.aten.polar(torch.ops.aten.abs(abs_), angle) + + +@register_test_case(module_factory=lambda: AtenPolarFloatModule()) +def AtenPolarFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 5, 3, 4), tu.rand(2, 5, 3, 4)) + + +# ============================================================================== + + +class AtenPolarDoubleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.unfold = torch.nn.Unfold(kernel_size=(2, 3)) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float64, True), + ([-1, -1, -1, -1], torch.float64, True), + ] + ) + def forward(self, abs_, angle): + return torch.ops.aten.polar(torch.ops.aten.abs(abs_), angle) + + +@register_test_case(module_factory=lambda: AtenPolarDoubleModule()) +def AtenPolarDoubleModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64) + )