Skip to content

Commit

Permalink
ADDED SUPPORT FLOAT VALUE IN ARANGE
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-TyRnT authored and newling committed Feb 26, 2024
1 parent 3cbe6c9 commit 0ee752b
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 13 deletions.
59 changes: 46 additions & 13 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include <iostream>

using namespace mlir;
using namespace mlir::torch;
Expand Down Expand Up @@ -4067,28 +4068,60 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
op, "unimplemented: pin_memory must be either None or false");
}

int64_t start, step, end;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
double start, step, end;
int64_t start_int, step_int, end_int;
bool is_all_inp_int; //Flag to check whether all inputs are integer
is_all_inp_int = op.getStart().getType().isa<Torch::IntType>() && op.getEnd().getType().isa<Torch::IntType>() && op.getStep().getType().isa<Torch::IntType>();

if (matchPattern(op.getStart(), m_TorchConstantInt(&start_int)))
{
start = (double)(start_int);
}

else if(!matchPattern(op.getStart(), m_TorchConstantFloat(&start)))
return rewriter.notifyMatchFailure(
op, "unimplemented: value `start` should be a torch constant int");
op, "unimplemented: value `start` should be a torch constant int or float");

if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
if (matchPattern(op.getEnd(), m_TorchConstantInt(&end_int)))
{
end = (double)(end_int);
}
else if (!matchPattern(op.getEnd(), m_TorchConstantFloat(&end)))
return rewriter.notifyMatchFailure(
op, "unimplemented: value `end` should be a torch constant int");
op, "unimplemented: value `end` should be a torch constant int or float");

if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
if (matchPattern(op.getStep(), m_TorchConstantInt(&step_int)))
{

step = (double)(step_int);
}

else if (!matchPattern(op.getStep(), m_TorchConstantFloat(&step)))
return rewriter.notifyMatchFailure(
op, "unimplemented: value `step` should be a torch constant int");
op, "unimplemented: value `step` should be a torch constant int or float");

// The result will always be a 1-d tensor.
// The size of the result is calculated as follows:
// ceil((end - start)/step)
int64_t resultShape = ceil((float)(end - start) / (float)step);
SmallVector<int64_t> values(resultShape, start);
for (unsigned i = 1; i < resultShape; i++)
values[i] += i * step;
Value result =
tosa::getConstTensor<int64_t>(rewriter, op, values, resultShape).value();
int64_t resultShape = ceil((end - start) / step);
Value result;
if (is_all_inp_int)
{
SmallVector<int64_t> values(resultShape, start);
for (unsigned i = 1; i < resultShape; i++)
values[i] += i * step;

result = tosa::getConstTensor<int64_t>(rewriter, op, values, resultShape).value();
}

else
{
SmallVector<float> values(resultShape, start);
for (unsigned i = 1; i < resultShape; i++)
values[i] += (i * step);

result = tosa::getConstTensor<float>(rewriter, op, values, resultShape).value();
}

rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, result);
return success();
Expand Down
214 changes: 214 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,212 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"Convolution2DStridedModule_basic",
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"AliasModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
"ConstantBoolParameterModule_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseUnaryModule_basic",
"ElementwiseBinaryModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseReluModule_basic",
"ElementwiseLeakyReluModule_basic",
"ElementwiseEluModule_basic",
"ElementwiseEluNonDefaultModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseFloorIntModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",
"ElementwiseMinimumModule_basic",
"ElementwiseMinimumIntModule_basic",
"ElementwiseMinOtherIntModule_basic",
"ElementwiseMinOtherModule_basic",
"ElementwiseMaximumModule_basic",
"ElementwiseMaximumIntModule_basic",
"ElementwiseMaxOtherIntModule_basic",
"ElementwiseMaxOtherModule_basic",
"GluStaticModule_basic",
"ViewDoubleMergeStaticModule_basic",
"ViewCollapseOnesMiddleModule_basic",
"ViewFiveTestStaticModule_basic",
"ViewOffsetTestStaticModule_basic",
"ViewTwoFiveThreeStaticModule_basic",
"ViewTwoToThreeStaticModule_basic",
"ViewExpandOnesMiddleOppModule_basic",
"ViewOffsetBackwardTestStaticModule_basic",
"TanhBackward_basic",
"HardtanhBackward_basic",
"ElementwiseAddModule_basic",
"ReturnThreeTensorFloat32_basic",
"AddCMulModule_basic",
"AddCDivModule_basic",
"SqueezeModule_broadcast",
"BoolTensorReturnFalseModule_basic",
"BoolTensorReturnTrueModule_basic",
"BoolTensorReturnMixedModule_basic",
"BoolTensorHandleSignless_basic",
"ElementwiseRsqrtModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SqueezeModule_static",
"SqueezeModule_noUnitDim",
"SqueezeModule_allUnitDim",
"TModuleRank1_basic",
"TModuleRank0_basic",
"ElementwiseToDtypeIdentityModule_basic",
"AtenToDeviceModule_basic",
"View1DFoldModule_basic",
"UnsafeView1DFoldModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"SqueezeDimModule_static",
"SqueezeDimModule_identity",
"SqueezeDimModule_unitDim",
"ReturnTwoTensorF32I64_basic",
"ElementwiseSignModule_basic",
"ElementwisePowModule_basic",
"BmmFloatModule_basic",
"MmDagModule_basic",
"Matmul4dStatic_basic",
"Matmul_dot",
"Matmul_3d",
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"ElementwiseBitwiseAndModule_basic",
"ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt32Module_basic",
"ElementwiseBitwiseNotInt64Module_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseOrTensorModule_basic",
"ElementwiseBitwiseOrModule_basic",
"ElementwiseBitwiseOrStaticShapeModule_basic",
"ElementwiseBitwiseXorModule_basic",
"ElementwiseBitwiseXorStaticShapeModule_basic",
"ElementwiseGeFloatIntScalarModule_basic",
"ElementwiseGeFloatScalarModule_basic",
"ElementwiseGeIntScalarModule_basic",
"ElementwiseGeMixedIntScalarModule_basic",
"ElementwiseGtFloatScalarModule_basic",
"ElementwiseGtIntScalarModule_basic",
"ElementwiseGtMixed2ScalarModule_basic",
"ElementwiseGtFloatTensorModule_basic",
"ElementwiseGtIntTensorModule_basic",
"ElementwiseLtFloatScalarModule_basic",
"ElementwiseLtIntScalarModule_basic",
"ElementwiseLtDiffWidthScalarModule_basic",
"ElementwiseLtFloatTensorModule_basic",
"ElementwiseLtIntTensorModule_basic",
"ElementwiseEqFloatScalarModule_basic",
"ElementwiseEqIntScalarModule_basic",
"ElementwiseEqBoolScalarModule_basic",
"ElementwiseEqDiffWidthScalarModule_basic",
"ElementwiseEqFloatTensorModule_basic",
"ElementwiseEqIntTensorModule_basic",
"ElementwiseNeFloatScalarModule_basic",
"ElementwiseNeFloatTensorModule_basic",
"ElementwiseNeFloatTensorStaticModule_basic",
"ElementwiseNeIntTensorModule_basic",
"ElementwiseNeIntTensorStaticModule_basic",
"ElementwiseMulScalarModule_int",
"ElementwiseMulScalarModule_float",
"ElementwiseMulTensorIntModule_basic",
"ElementwiseDivScalarModule_basic",
"ElementwiseAtenDivIntScalarModule_basic",
"ElementwiseSubScalarFloatModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseMulScalarModule_float",
"ElementwiseCeilModule_basic",
"ElementwiseReciprocalModule_basic",
"ElementwiseIsnanModule_basic",
"ElementwiseIsinfModule_basic",
"TypePromotionAlphaWiderModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",
"FlattenStaticModule_basic",
"UnflattenStaticModule_basic",
"FlattenRank0Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"SquareModule_basic",
"MaxPool2dStaticModule_basic",
"MaxPool2dStaticCeilModeTrueModule_basic",
"ResNet18StaticModule_basic",
"ReduceAmaxKeepDim_basic",
"NativeLayerNormModule4D_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"ElementwiseLog2Module_basic",
"Threshold1dIntI32Module_basic",
"Threshold1dFloatModule_basic",
"Threshold2dFloatModule_basic",
"Threshold3dFloatModule_basic",
"ElementwiseSubScalarIntModule_basic",
"ElementwiseAddScalarIntModule_basic",
"ElementwiseMulScalarModule_basic",
"ZerosModuleDefaultDtype_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
"ZerosModuleFloat2D_basic",
"ZerosModuleFloat3D_basic",
"ZerosModuleFalsePinMemory_basic",
"OnesModuleDefaultDtype_basic",
"OnesModuleInt_basic",
"OnesModuleFloat_basic",
"OnesModuleFalsePinMemory_basic",
"OnesModuleCPUDevice_basic",
"NewZerosModuleDefaultDtype_basic",
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"NewZerosModuleFloat2D_basic",
"NewZerosModuleFloat3D_basic",
"NewZerosModuleFalsePinMemory_basic",
"NewOnesModuleDefaultDtype_basic",
"NewOnesModuleInt2D_basic",
"NewOnesModuleInt3D_basic",
"NewOnesModuleFloat2D_basic",
"NewOnesModuleFloat3D_basic",
"NewOnesModuleFalsePinMemory_basic",
"SiluModule_basic",
"DropoutEvalIntModule_basic",
"DropoutEvalFloatModule_basic",
"ContiguousModule_basic",
"DropoutModule_basic",
"ViewExpandModule_basic",
"ViewExpandOnesModule_basic",
"ViewExpandOnesBeforeAndAfterModule_basic",
"ViewExpandOnesMiddleModule_basic",
"ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewNegativeStaticModule_basic",
"ViewNoChangeStaticModule_basic",
"UnsafeViewExpandModule_basic",
"ReshapeCollapseModule_basic",
"ReshapeAsModule_basic",
"ElementwiseGeluModule_basic",
"GeluBackwardModule_basic",
"ElementwiseNeIntScalarModule_basic",
"Convolution2DStaticModule_basic",
"ElementwiseNegModule_basic",
"TestMultipleTensorReturn_basic",
"TypeAsSameModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"AddCDivModule_basic",
Expand All @@ -888,6 +1094,14 @@
"ArangeStartOutViewModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
"ArangeFloatModule_basic",
"ArangeNegativeStartFloatModule_basic",
"ArangeStartFloatModule_basic",
"ArangeStartNegativeStepFloatModule_basic",
"ArangeStartOutDtypeModule_basic",
"ArangeStartStepFloatModule_basic",
"ArgmaxModule_keepDim",
"ArgmaxModule_with_dim",
"AtenComplex64Module_basic",
Expand Down

0 comments on commit 0ee752b

Please sign in to comment.