From 0ea235e4ed37bdad802915ef984cc79da6349a61 Mon Sep 17 00:00:00 2001 From: Raghavan Raman Date: Fri, 15 Sep 2023 20:29:45 -0700 Subject: [PATCH] [Tcp] Handle int inputs in sqrt (#2467) This PR adds support for integer inputs in `tcp.sqrt`. --- .../Dialect/Tcp/IR/TcpBase.td | 3 +- .../Dialect/Tcp/IR/TcpOps.td | 28 ++++----- lib/Conversion/TorchToTcp/DataMovement.cpp | 4 +- lib/Conversion/TorchToTcp/Elementwise.cpp | 60 +++++++++++-------- test/Conversion/TorchToTcp/elementwise.mlir | 13 ++++ 5 files changed, 64 insertions(+), 44 deletions(-) diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/Tcp/IR/TcpBase.td b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/Tcp/IR/TcpBase.td index 1edab840dc6c..81ea87c437c9 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/Tcp/IR/TcpBase.td +++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/Tcp/IR/TcpBase.td @@ -88,8 +88,7 @@ class Tcp_UnaryElementwiseOp traits = []> : Tcp_Op { + SameOperandsAndResultShape])> { } class Tcp_BinaryElementwiseOp traits = []> : diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/Tcp/IR/TcpOps.td b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/Tcp/IR/TcpOps.td index ed26382d11dd..4a5f00751ee5 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/Tcp/IR/TcpOps.td +++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/Tcp/IR/TcpOps.td @@ -15,7 +15,7 @@ include "torch-mlir-dialects/Dialect/Tcp/IR/TcpEnums.td" include "mlir/IR/OpBase.td" -def Tcp_TanhOp : Tcp_UnaryElementwiseOp<"tanh"> { +def Tcp_TanhOp : Tcp_UnaryElementwiseOp<"tanh", [SameOperandsAndResultElementType]> { let summary = "Computes tanh of input, elementwise"; let description = [{ @@ -33,7 +33,7 @@ def Tcp_TanhOp : Tcp_UnaryElementwiseOp<"tanh"> { let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)"; } -def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp"> { +def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp", [SameOperandsAndResultElementType]> { let summary = "Clamps input tensor to the given min and/or max"; let description = [{ @@ -65,7 +65,7 @@ def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp"> { let hasVerifier = 1; } -def Tcp_SigmoidOp : Tcp_UnaryElementwiseOp<"sigmoid"> { +def Tcp_SigmoidOp : Tcp_UnaryElementwiseOp<"sigmoid", [SameOperandsAndResultElementType]> { let summary = "Computes sigmoid of input, elementwise"; let description = [{ @@ -312,7 +312,7 @@ def Tcp_IsolatedGroupOp : Tcp_Op<"isolated_group", [ let hasVerifier = 1; } -def Tcp_SqrtOp : Tcp_UnaryElementwiseOp<"sqrt", [SameOperandsAndResultElementType]> { +def Tcp_SqrtOp : Tcp_UnaryElementwiseOp<"sqrt"> { let summary = "Computes square root of input, elementwise"; let description = [{ @@ -320,11 +320,11 @@ def Tcp_SqrtOp : Tcp_UnaryElementwiseOp<"sqrt", [SameOperandsAndResultElementTyp }]; let arguments = (ins - Tcp_FloatOrComplexTensor:$in + Tcp_FloatOrIntTensor:$in ); let results = (outs - Tcp_FloatOrComplexTensor:$out + Tcp_FloatTensor:$out ); let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)"; @@ -351,7 +351,7 @@ def Tcp_ConcatOp : Tcp_Op<"concat", [SameOperandsAndResultElementType]> { let hasVerifier = 1; } -def Tcp_CeilOp : Tcp_UnaryElementwiseOp<"ceil"> { +def Tcp_CeilOp : Tcp_UnaryElementwiseOp<"ceil", [SameOperandsAndResultElementType]> { let summary = "Computes ceil of input, elementwise"; let description = [{ @@ -369,7 +369,7 @@ def Tcp_CeilOp : Tcp_UnaryElementwiseOp<"ceil"> { let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)"; } -def Tcp_FloorOp : Tcp_UnaryElementwiseOp<"floor"> { +def Tcp_FloorOp : Tcp_UnaryElementwiseOp<"floor", [SameOperandsAndResultElementType]> { let summary = "Computes floor of input, elementwise"; let description = [{ @@ -387,7 +387,7 @@ def Tcp_FloorOp : Tcp_UnaryElementwiseOp<"floor"> { let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)"; } -def Tcp_CosOp : Tcp_UnaryElementwiseOp<"cos"> { +def Tcp_CosOp : Tcp_UnaryElementwiseOp<"cos", [SameOperandsAndResultElementType]> { let summary = "Computes cosine of input, elementwise"; let description = [{ @@ -405,7 +405,7 @@ def Tcp_CosOp : Tcp_UnaryElementwiseOp<"cos"> { let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)"; } -def Tcp_SinOp : Tcp_UnaryElementwiseOp<"sin"> { +def Tcp_SinOp : Tcp_UnaryElementwiseOp<"sin", [SameOperandsAndResultElementType]> { let summary = "Computes sine of input, elementwise"; let description = [{ @@ -423,7 +423,7 @@ def Tcp_SinOp : Tcp_UnaryElementwiseOp<"sin"> { let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)"; } -def Tcp_AbsOp : Tcp_UnaryElementwiseOp<"abs"> { +def Tcp_AbsOp : Tcp_UnaryElementwiseOp<"abs", [SameOperandsAndResultElementType]> { let summary = "Computes absolute of input, elementwise"; let description = [{ @@ -441,7 +441,7 @@ def Tcp_AbsOp : Tcp_UnaryElementwiseOp<"abs"> { let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)"; } -def Tcp_LogOp : Tcp_UnaryElementwiseOp<"log"> { +def Tcp_LogOp : Tcp_UnaryElementwiseOp<"log", [SameOperandsAndResultElementType]> { let summary = "Computes natural logarithm of input, elementwise"; let description = [{ @@ -459,7 +459,7 @@ def Tcp_LogOp : Tcp_UnaryElementwiseOp<"log"> { let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)"; } -def Tcp_NegOp : Tcp_UnaryElementwiseOp<"neg"> { +def Tcp_NegOp : Tcp_UnaryElementwiseOp<"neg", [SameOperandsAndResultElementType]> { let summary = "Computes the negation of input, elementwise"; let description = [{ @@ -477,7 +477,7 @@ def Tcp_NegOp : Tcp_UnaryElementwiseOp<"neg"> { let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)"; } -def Tcp_AtanOp : Tcp_UnaryElementwiseOp<"atan"> { +def Tcp_AtanOp : Tcp_UnaryElementwiseOp<"atan", [SameOperandsAndResultElementType]> { let summary = "Computes the arcus tangent value of input, elementwise"; let description = [{ diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp index 92072e953601..9fe5b6480e9a 100644 --- a/lib/Conversion/TorchToTcp/DataMovement.cpp +++ b/lib/Conversion/TorchToTcp/DataMovement.cpp @@ -35,8 +35,8 @@ class ConvertAtenCatOp : public OpConversionPattern { return rewriter.notifyMatchFailure( catOp, "aten.cat operands must be a list of tensors"); - SmallVector tensorInputs = getTypeConvertedValues( - rewriter, catOp->getLoc(), getTypeConverter(), inputs); + auto tensorInputs = getTypeConvertedValues(rewriter, catOp->getLoc(), + getTypeConverter(), inputs); int64_t dim; if (!matchPattern(catOp.getDim(), m_TorchConstantInt(&dim))) diff --git a/lib/Conversion/TorchToTcp/Elementwise.cpp b/lib/Conversion/TorchToTcp/Elementwise.cpp index e8fb580ba21e..276f0958259c 100644 --- a/lib/Conversion/TorchToTcp/Elementwise.cpp +++ b/lib/Conversion/TorchToTcp/Elementwise.cpp @@ -447,12 +447,14 @@ class ConvertAtenReluOp : public OpConversionPattern { } }; -class ConvertAtenAbsOp : public OpConversionPattern { +template +class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult - matchAndRewrite(AtenAbsOp op, OpAdaptor adaptor, + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value input = adaptor.getSelf(); RankedTensorType inputType = input.getType().dyn_cast(); @@ -464,13 +466,18 @@ class ConvertAtenAbsOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "Abs input tensor must have integer or floating-point datatype"); - rewriter.replaceOpWithNewOp(op, inputType, input); + RankedTensorType resultType = + OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + rewriter.replaceOpWithNewOp(op, resultType, input); return success(); } }; template -class ConvertAtenUnaryOp : public OpConversionPattern { +class ConvertAtenUnaryFpOnlyOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; @@ -680,7 +687,6 @@ void torch_to_tcp::populateElementwisePatternsAndLegality( target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -688,29 +694,31 @@ void torch_to_tcp::populateElementwisePatternsAndLegality( target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, + patterns.add>( + typeConverter, context); + patterns.add>(typeConverter, + context); + patterns.add>( + typeConverter, context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); target.addIllegalOp(); - patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add>(typeConverter, + context); + patterns.add>( + typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/test/Conversion/TorchToTcp/elementwise.mlir b/test/Conversion/TorchToTcp/elementwise.mlir index ba0e33db0d69..9f3855a3fab5 100644 --- a/test/Conversion/TorchToTcp/elementwise.mlir +++ b/test/Conversion/TorchToTcp/elementwise.mlir @@ -327,6 +327,19 @@ func.func @torch.aten.sqrt(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[ // ----- +// CHECK-LABEL: func.func @torch.aten.sqrt_int( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[T1:.*]] = tcp.sqrt %[[T0]] : tensor -> tensor +// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.sqrt_int(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.sqrt %arg0 : !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.ceil( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor