From 0ee752bc688c49d3b55995bb85db9262c8fdaad0 Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Sat, 13 Jan 2024 08:44:23 +0530 Subject: [PATCH] ADDED SUPPORT FLOAT VALUE IN ARANGE --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 59 ++++-- projects/pt1/e2e_testing/xfail_sets.py | 214 +++++++++++++++++++++ 2 files changed, 260 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index b49c9af8adce..0b5819631f52 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -26,6 +26,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include using namespace mlir; using namespace mlir::torch; @@ -4067,28 +4068,60 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: pin_memory must be either None or false"); } - int64_t start, step, end; - if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) + double start, step, end; + int64_t start_int, step_int, end_int; + bool is_all_inp_int; //Flag to check whether all inputs are integer + is_all_inp_int = op.getStart().getType().isa() && op.getEnd().getType().isa() && op.getStep().getType().isa(); + + if (matchPattern(op.getStart(), m_TorchConstantInt(&start_int))) + { + start = (double)(start_int); + } + + else if(!matchPattern(op.getStart(), m_TorchConstantFloat(&start))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `start` should be a torch constant int"); + op, "unimplemented: value `start` should be a torch constant int or float"); - if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) + if (matchPattern(op.getEnd(), m_TorchConstantInt(&end_int))) + { + end = (double)(end_int); + } + else if (!matchPattern(op.getEnd(), m_TorchConstantFloat(&end))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `end` should be a torch constant int"); + op, "unimplemented: value `end` should be a torch constant int or float"); - if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) + if (matchPattern(op.getStep(), m_TorchConstantInt(&step_int))) + { + + step = (double)(step_int); + } + + else if (!matchPattern(op.getStep(), m_TorchConstantFloat(&step))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `step` should be a torch constant int"); + op, "unimplemented: value `step` should be a torch constant int or float"); // The result will always be a 1-d tensor. // The size of the result is calculated as follows: // ceil((end - start)/step) - int64_t resultShape = ceil((float)(end - start) / (float)step); - SmallVector values(resultShape, start); - for (unsigned i = 1; i < resultShape; i++) - values[i] += i * step; - Value result = - tosa::getConstTensor(rewriter, op, values, resultShape).value(); + int64_t resultShape = ceil((end - start) / step); + Value result; + if (is_all_inp_int) + { + SmallVector values(resultShape, start); + for (unsigned i = 1; i < resultShape; i++) + values[i] += i * step; + + result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + } + + else + { + SmallVector values(resultShape, start); + for (unsigned i = 1; i < resultShape; i++) + values[i] += (i * step); + + result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + } rewriter.replaceOpWithNewOp(op, resultType, result); return success(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 70f26fe421e0..c9768152da25 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -871,6 +871,212 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "Convolution2DStridedModule_basic", + "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "AliasModule_basic", + "MaxPool2dEmptyStrideStaticModule_basic", + "ConstantBoolParameterModule_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneChannelsLastMemoryFormatModule_basic", + "ElementwiseCloneModule_basic", + "ElementwiseUnaryModule_basic", + "ElementwiseBinaryModule_basic", + "ElementwiseSigmoidModule_basic", + "ElementwiseExpModule_basic", + "ElementwiseReluModule_basic", + "ElementwiseLeakyReluModule_basic", + "ElementwiseEluModule_basic", + "ElementwiseEluNonDefaultModule_basic", + "ElementwiseFloorModule_basic", + "ElementwiseFloorIntModule_basic", + "ElementwiseLogModule_basic", + "ElementwiseBinaryStaticShapeModule_basic", + "ElementwiseMinimumModule_basic", + "ElementwiseMinimumIntModule_basic", + "ElementwiseMinOtherIntModule_basic", + "ElementwiseMinOtherModule_basic", + "ElementwiseMaximumModule_basic", + "ElementwiseMaximumIntModule_basic", + "ElementwiseMaxOtherIntModule_basic", + "ElementwiseMaxOtherModule_basic", + "GluStaticModule_basic", + "ViewDoubleMergeStaticModule_basic", + "ViewCollapseOnesMiddleModule_basic", + "ViewFiveTestStaticModule_basic", + "ViewOffsetTestStaticModule_basic", + "ViewTwoFiveThreeStaticModule_basic", + "ViewTwoToThreeStaticModule_basic", + "ViewExpandOnesMiddleOppModule_basic", + "ViewOffsetBackwardTestStaticModule_basic", + "TanhBackward_basic", + "HardtanhBackward_basic", + "ElementwiseAddModule_basic", + "ReturnThreeTensorFloat32_basic", + "AddCMulModule_basic", + "AddCDivModule_basic", + "SqueezeModule_broadcast", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnTrueModule_basic", + "BoolTensorReturnMixedModule_basic", + "BoolTensorHandleSignless_basic", + "ElementwiseRsqrtModule_basic", + "SelectIntNegativeDimAndIndexStaticModule_basic", + "SqueezeModule_static", + "SqueezeModule_noUnitDim", + "SqueezeModule_allUnitDim", + "TModuleRank1_basic", + "TModuleRank0_basic", + "ElementwiseToDtypeIdentityModule_basic", + "AtenToDeviceModule_basic", + "View1DFoldModule_basic", + "UnsafeView1DFoldModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "SqueezeDimModule_static", + "SqueezeDimModule_identity", + "SqueezeDimModule_unitDim", + "ReturnTwoTensorF32I64_basic", + "ElementwiseSignModule_basic", + "ElementwisePowModule_basic", + "BmmFloatModule_basic", + "MmDagModule_basic", + "Matmul4dStatic_basic", + "Matmul_dot", + "Matmul_3d", + "RsubFloatModule_basic", + "RsubFloatModule_noalpha_basic", + "RsubInt0d_NumToTensor_Module_basic", + "ElementwiseBitwiseAndModule_basic", + "ElementwiseBitwiseAndStaticShapeModule_basic", + "ElementwiseBitwiseNotInt32Module_basic", + "ElementwiseBitwiseNotInt64Module_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseOrTensorModule_basic", + "ElementwiseBitwiseOrModule_basic", + "ElementwiseBitwiseOrStaticShapeModule_basic", + "ElementwiseBitwiseXorModule_basic", + "ElementwiseBitwiseXorStaticShapeModule_basic", + "ElementwiseGeFloatIntScalarModule_basic", + "ElementwiseGeFloatScalarModule_basic", + "ElementwiseGeIntScalarModule_basic", + "ElementwiseGeMixedIntScalarModule_basic", + "ElementwiseGtFloatScalarModule_basic", + "ElementwiseGtIntScalarModule_basic", + "ElementwiseGtMixed2ScalarModule_basic", + "ElementwiseGtFloatTensorModule_basic", + "ElementwiseGtIntTensorModule_basic", + "ElementwiseLtFloatScalarModule_basic", + "ElementwiseLtIntScalarModule_basic", + "ElementwiseLtDiffWidthScalarModule_basic", + "ElementwiseLtFloatTensorModule_basic", + "ElementwiseLtIntTensorModule_basic", + "ElementwiseEqFloatScalarModule_basic", + "ElementwiseEqIntScalarModule_basic", + "ElementwiseEqBoolScalarModule_basic", + "ElementwiseEqDiffWidthScalarModule_basic", + "ElementwiseEqFloatTensorModule_basic", + "ElementwiseEqIntTensorModule_basic", + "ElementwiseNeFloatScalarModule_basic", + "ElementwiseNeFloatTensorModule_basic", + "ElementwiseNeFloatTensorStaticModule_basic", + "ElementwiseNeIntTensorModule_basic", + "ElementwiseNeIntTensorStaticModule_basic", + "ElementwiseMulScalarModule_int", + "ElementwiseMulScalarModule_float", + "ElementwiseMulTensorIntModule_basic", + "ElementwiseDivScalarModule_basic", + "ElementwiseAtenDivIntScalarModule_basic", + "ElementwiseSubScalarFloatModule_basic", + "ElementwiseAddScalarFloatModule_basic", + "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", + "ElementwiseMulScalarModule_float", + "ElementwiseCeilModule_basic", + "ElementwiseReciprocalModule_basic", + "ElementwiseIsnanModule_basic", + "ElementwiseIsinfModule_basic", + "TypePromotionAlphaWiderModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "BatchNorm1DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "BatchNorm1DStaticShapeModule_basic", + "FlattenStaticModule_basic", + "UnflattenStaticModule_basic", + "FlattenRank0Module_basic", + "ElementwiseFlattenBroadcastModule_basic", + "SquareModule_basic", + "MaxPool2dStaticModule_basic", + "MaxPool2dStaticCeilModeTrueModule_basic", + "ResNet18StaticModule_basic", + "ReduceAmaxKeepDim_basic", + "NativeLayerNormModule4D_basic", + "LayerNormNormalizeOverAllDimsModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "ElementwiseLog2Module_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dFloatModule_basic", + "Threshold2dFloatModule_basic", + "Threshold3dFloatModule_basic", + "ElementwiseSubScalarIntModule_basic", + "ElementwiseAddScalarIntModule_basic", + "ElementwiseMulScalarModule_basic", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleFalsePinMemory_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleInt_basic", + "OnesModuleFloat_basic", + "OnesModuleFalsePinMemory_basic", + "OnesModuleCPUDevice_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleFalsePinMemory_basic", + "SiluModule_basic", + "DropoutEvalIntModule_basic", + "DropoutEvalFloatModule_basic", + "ContiguousModule_basic", + "DropoutModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewNegativeStaticModule_basic", + "ViewNoChangeStaticModule_basic", + "UnsafeViewExpandModule_basic", + "ReshapeCollapseModule_basic", + "ReshapeAsModule_basic", + "ElementwiseGeluModule_basic", + "GeluBackwardModule_basic", + "ElementwiseNeIntScalarModule_basic", + "Convolution2DStaticModule_basic", + "ElementwiseNegModule_basic", + "TestMultipleTensorReturn_basic", + "TypeAsSameModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddCDivModule_basic", @@ -888,6 +1094,14 @@ "ArangeStartOutViewModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", + "ArangeDtypeIntModule_basic", + "ArangeFalsePinMemoryModule_basic", + "ArangeFloatModule_basic", + "ArangeNegativeStartFloatModule_basic", + "ArangeStartFloatModule_basic", + "ArangeStartNegativeStepFloatModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartStepFloatModule_basic", "ArgmaxModule_keepDim", "ArgmaxModule_with_dim", "AtenComplex64Module_basic",