From 670a99ae196da892310776f110cfe29dfb68a174 Mon Sep 17 00:00:00 2001 From: Ze Zhang Date: Thu, 11 Jan 2024 10:36:48 -0800 Subject: [PATCH] Handle torch.none type in tosa.clamp op (#2739) This PR updates the torch-to-tosa conversion with following changes: - Support torch.none as min/max input argument for tosa.clamp op - Support negative value as start index for tosa.slice op - Add tosa.logical_or lowering support e2e test: python -m e2e_testing.main --config=tosa LIT tests: cmake --build build --target tools/torch-mlir/all --------- Co-authored-by: Ze Zhang --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 78 +++++++++++++--------- projects/pt1/e2e_testing/xfail_sets.py | 12 ++++ test/Conversion/TorchToTosa/basic.mlir | 71 ++++++++++++++++++++ 3 files changed, 129 insertions(+), 32 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index e123522a4542..6555f06e8702 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -3336,9 +3337,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) return rewriter.notifyMatchFailure(op, "start must be a Scalar constant"); - if (start < 0) - return rewriter.notifyMatchFailure(op, "Currently unsupported: start < 0"); - + if (start < 0) { + start = toPositiveDim(start, selfType.getShape()[dim]); + if (!isValidDim(start, selfType.getShape()[dim])) + return rewriter.notifyMatchFailure(op, "start is not a valid index"); + } start = std::min(selfType.getShape()[dim], start); int64_t end; @@ -3984,36 +3987,46 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only tensor types input are currently supported"); - IntegerAttr min_int, max_int; - FloatAttr min_fp, max_fp; - if (op.getMin().getType().isa()) { - double fp_min, fp_max; - if (!matchPattern(op.getMin(), m_TorchConstantFloat(&fp_min))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `fp_min` should be a torch constant float"); - - if (!matchPattern(op.getMax(), m_TorchConstantFloat(&fp_max))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `fp_max` should be a torch constant float"); - - min_int = rewriter.getI64IntegerAttr(static_cast(fp_min)); - max_int = rewriter.getI64IntegerAttr(static_cast(fp_max)); - min_fp = rewriter.getF32FloatAttr(static_cast(fp_min)); - max_fp = rewriter.getF32FloatAttr(static_cast(fp_max)); - } else { - int64_t int_min, int_max; - if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_min` should be a torch constant int"); - - if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_max` should be a torch constant int"); + IntegerAttr min_int = + rewriter.getI64IntegerAttr(std::numeric_limits::min()); + IntegerAttr max_int = + rewriter.getI64IntegerAttr(std::numeric_limits::max()); + FloatAttr min_fp = + rewriter.getF32FloatAttr(std::numeric_limits::lowest()); + FloatAttr max_fp = + rewriter.getF32FloatAttr(std::numeric_limits::max()); + + auto getValAttr = [&](Value operand, IntegerAttr &intAttr, + FloatAttr &fpAttr) -> LogicalResult { + double valFloat; + int64_t valInt; + if (matchPattern(operand, m_TorchConstantFloat(&valFloat))) { + intAttr = rewriter.getI64IntegerAttr(static_cast(valFloat)); + fpAttr = rewriter.getF32FloatAttr(static_cast(valFloat)); + } else if (matchPattern(operand, m_TorchConstantInt(&valInt))) { + intAttr = rewriter.getI64IntegerAttr(valInt); + fpAttr = rewriter.getF32FloatAttr(static_cast(valInt)); + } else { + return failure(); + } + return success(); + }; - min_int = rewriter.getI64IntegerAttr(int_min); - max_int = rewriter.getI64IntegerAttr(int_max); - min_fp = rewriter.getF32FloatAttr(static_cast(int_min)); - max_fp = rewriter.getF32FloatAttr(static_cast(int_max)); + LogicalResult minAttrResult = getValAttr(op.getMin(), min_int, min_fp); + LogicalResult maxAttrResult = getValAttr(op.getMax(), max_int, max_fp); + if (failed(minAttrResult) && failed(maxAttrResult)) { + return rewriter.notifyMatchFailure( + op, "either `min` or `max` should be a torch constant"); + } + if (failed(minAttrResult) && + succeeded(checkNotNone(rewriter, op, op.getMin()))) { + return rewriter.notifyMatchFailure(op, + "min attr should be a torch constant"); + } + if (failed(maxAttrResult) && + succeeded(checkNotNone(rewriter, op, op.getMax()))) { + return rewriter.notifyMatchFailure(op, + "max attr should be a torch constant"); } auto outType = getTypeConverter()->convertType(op.getType()); @@ -5025,6 +5038,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { patterns.add>(typeConverter, context); INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) + INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) #undef INSERT_BINARY_PATTERN #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 98cde05a8f73..de68680a82f2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1035,6 +1035,15 @@ "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", "ElementwiseAtenDivIntScalarModule_basic", "ElementwiseAtenIsinfOpModule_basic", + "ElementwiseAtenLogicalOrOpBrodcastModule_basic", + "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs3Module_basic", + "ElementwiseAtenLogicalOrOpModule_basic", + "ElementwiseAtenLogicalOrOpNegativeModule_basic", + "ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic", + "ElementwiseAtenLogicalOrOpRandomFloatModule_basic", + "ElementwiseAtenLogicalOrOpRandomModule_basic", "ElementwiseAtenWhereSelfModule_basic", "ElementwiseBinaryModule_basic", "ElementwiseBinaryStaticShapeModule_basic", @@ -1047,6 +1056,9 @@ "ElementwiseBitwiseXorModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", "ElementwiseCeilModule_basic", + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampModule_basic", "ElementwiseCloneChannelsLastMemoryFormatModule_basic", "ElementwiseCloneContiguousModule_basic", "ElementwiseCloneModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 180f48bcef2b..b36acc779547 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -645,6 +645,22 @@ func.func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // ----- +// CHECK-LABEL: func.func @torch.aten.logical_or$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.logical_or %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +// CHECK: } +func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.logical_or %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> +} + +// ----- + // CHECK-LABEL: func.func @forward( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32> @@ -1055,6 +1071,61 @@ func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> return %0 : !torch.vtensor<[1,1,128,128],si64> } +// ----- +// CHECK-LABEL: func.func @torch.aten.slice.negative_start( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 100 +// CHECK: %[[VAL_5:.*]] = torch.constant.int -16 +// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<4x65x256xf32>) -> tensor<4x16x256xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x16x256xf32> -> !torch.vtensor<[4,16,256],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,16,256],f32> +// CHECK: } +func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int100 = torch.constant.int 100 + %int-16 = torch.constant.int -16 + %0 = torch.aten.slice.Tensor %arg0, %int1, %int-16, %int100, %int1 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,16,256],f32> + return %0 : !torch.vtensor<[4,16,256],f32> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.clamp.min_none( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0.000000e+00 : f32, max_int = 0 : i64, min_fp = -3.40282347E+38 : f32, min_int = -9223372036854775808 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64> +// CHECK: } +func.func @torch.aten.clamp.min_none(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { + %int0 = torch.constant.int 0 + %none = torch.constant.none + %0 = torch.aten.clamp %arg0, %none, %int0 : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.int -> !torch.vtensor<[1,1,128,128],si64> + return %0 : !torch.vtensor<[1,1,128,128],si64> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.clamp.max_none( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64> +// CHECK: } +func.func @torch.aten.clamp.max_none(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { + %int0 = torch.constant.int 0 + %none = torch.constant.none + %0 = torch.aten.clamp %arg0, %int0, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,1,128,128],si64> + return %0 : !torch.vtensor<[1,1,128,128],si64> +} + // ----- // CHECK-LABEL: func.func @torch.aten.clamp( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {