Skip to content

Commit

Permalink
[TOSA] Add aten.Index.Tensor support
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis authored and AmosLewis committed Jan 17, 2023
1 parent 3f49ba9 commit 6ec4143
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 0 deletions.
1 change: 1 addition & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@
"TypePromotionSameCategoryDifferentWidthModule_basic",
"TypePromotionZeroRankHigherCategoryModule_basic",
"GatherStaticModule_basic",
"IndexTensorStaticModule_basic",
"LiftFreshCopyModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDimIntListFloatModule_basic",
Expand Down
69 changes: 69 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -3232,6 +3233,73 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
AtenIndexTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Python example of this algorithm:
// https://gist.github.com/AmosLewis/c90c1148a96291db93408b3fa39f9ae2
auto input = adaptor.getSelf();
auto inputTensorType =
adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
// Check input is a tensor type.
if (!inputTensorType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");

// Deal with torch.prim.ListConstruct of non const value to get the index
auto tensorList = op.getIndices();
SmallVector<Value> tensorsTorchType;
if (!getListConstructElements(tensorList, tensorsTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
auto tensors = getTypeConvertedValues(rewriter, op->getLoc(),
getTypeConverter(), tensorsTorchType);

auto index = tensors[0];
// TODO figure out why the index is empty for IndexTensorMultiInputContiguousCenter e2e test
if (!index.getImpl())
return rewriter.notifyMatchFailure(
op, "Only list ranked tensor types index are supported");
auto indexType = index.getType().dyn_cast<RankedTensorType>();
auto indexShape = indexType.getShape();
// index i64 to i32 for tosa compatible
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
}

auto outType = getTypeConverter()->convertType(op.getType());

// Expand last dim of index to tf indices [2,3] -> [2,3,1]
SmallVector<int64_t> indicesShape;
for (auto shape : indexShape) {
indicesShape.push_back(shape);
}
indicesShape.push_back(1);
auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
rewriter.getDenseI64ArrayAttr(indicesShape));

if (!indicesTf) {
return rewriter.notifyMatchFailure(op,
"Convert TorchIndex To TfIndices fail.");
}
// do the tf gathernp algorithm with tf style indices as input.
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
indicesTf.getResult());

if (!result) {
return rewriter.notifyMatchFailure(
op, "Convert GatherNdOp fail for index tensor.");
}
rewriter.replaceOp(op, {result.value()});

return success();
}

template <>
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
AtenWhereSelfOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -4196,6 +4264,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenClampOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
Expand Down
21 changes: 21 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,27 @@ def IndexTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5), tu.randint(2, 3, high=4))


# ==============================================================================
class IndexTensorStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([4, 5], torch.float32, True),
([2, 3], torch.int64, True),
])
def forward(self, x, index):
return torch.ops.aten.index(x, (index, ))


@register_test_case(module_factory=lambda: IndexTensorStaticModule())
def IndexTensorStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4))


# ==============================================================================


Expand Down
26 changes: 26 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,32 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v
return %0 : !torch.vtensor<[1,4,2],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.index.Tensor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],si64>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,2],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_0]] : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1],si64> -> tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK: %[[VAL_6:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array<i64: 1, 1, 2>} : (tensor<1x2xf32>) -> tensor<1x1x2xf32>
// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_6]]) {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.mul"(%[[VAL_8]], %[[VAL_9]]) {shift = 0 : i32} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_11:.*]] = "tosa.reduce_sum"(%[[VAL_10]]) {axis = 1 : i64} : (tensor<1x1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_12:.*]] = "tosa.reshape"(%[[VAL_11]]) {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_13:.*]] = "tosa.gather"(%[[VAL_7]], %[[VAL_12]]) : (tensor<1x1x2xf32>, tensor<1x1xi32>) -> tensor<1x1x2xf32>
// CHECK: %[[VAL_14:.*]] = "tosa.reshape"(%[[VAL_13]]) {new_shape = array<i64: 1, 2>} : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<1x2xf32> -> !torch.vtensor<[1,2],f32>
// CHECK: return %[[VAL_15]] : !torch.vtensor<[1,2],f32>
// CHECK: }
func.func @torch.aten.index.Tensor(%arg0: !torch.vtensor<[1],si64>, %arg1: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,2],f32> {
%0 = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%1 = torch.aten.index.Tensor %arg1, %0 : !torch.vtensor<[1,2],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,2],f32>
return %1 : !torch.vtensor<[1,2],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.add$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>,
Expand Down

0 comments on commit 6ec4143

Please sign in to comment.