Skip to content

Commit

Permalink
[TOSA] Add aten.Index.Tensor support (llvm#1771)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chi_Liu authored and gpetters94 committed May 8, 2023
1 parent 25ee0c9 commit 145b1f0
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@
"TypePromotionSameCategoryDifferentWidthModule_basic",
"TypePromotionZeroRankHigherCategoryModule_basic",
"GatherStaticModule_basic",
"IndexTensorStaticModule_basic",
"LiftFreshCopyModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDimIntListFloatModule_basic",
Expand Down Expand Up @@ -719,6 +720,7 @@
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputNonContiguous_basic",
"IndexTensorMultiInputOneDim_basic",
Expand Down
80 changes: 80 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -3232,6 +3233,84 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
AtenIndexTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// t = tf.constant([[1, 2, 3, 4, 5],[6,7,8,9,10],
// [11,12,13,14,15],[16,17,18,19,20]]) # 4*5
// i = tf.constant([[1,2,3], [3,2,1]]) # 2*3
// i_expand = tf.expand_dims(i,axis=2) # 2*3*1
// IndexTensorOutput = tf.gather_nd(t,tf.i_expand)
// = torch.ops.aten.index(t, (i, )) = t[i] # 2*3*5
// [[[ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]],
// [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6, 7, 8, 9, 10]]]
auto input = adaptor.getSelf();
auto inputTensorType =
adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
// Check input is a tensor type.
if (!inputTensorType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");

// Deal with torch.prim.ListConstruct of non const value to get the index
auto tensorList = op.getIndices();
SmallVector<Value> tensorsTorchType;
if (!getListConstructElements(tensorList, tensorsTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
auto tensors = getTypeConvertedValues(rewriter, op->getLoc(),
getTypeConverter(), tensorsTorchType);

// TODO add support for multiple index
if ( tensors.size() > 1){
return op.emitError(
"unimplemented: the index tensor list from list construct > 1");
}
auto index = tensors[0];
// TODO add support for none index input like torch.ops.aten.index(x, (None, index1, index2, None))
if (!index.getImpl())
return rewriter.notifyMatchFailure(
op, "Only list ranked tensor types index are supported");
auto indexType = index.getType().dyn_cast<RankedTensorType>();
auto indexShape = indexType.getShape();
// index i64 to i32 for tosa compatible
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
}

auto outType = getTypeConverter()->convertType(op.getType());

// Expand last dim of index to tf indices [2,3] -> [2,3,1]
SmallVector<int64_t> indicesShape;
for (auto shape : indexShape) {
indicesShape.push_back(shape);
}
indicesShape.push_back(1);
auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
rewriter.getDenseI64ArrayAttr(indicesShape));

if (!indicesTf) {
return rewriter.notifyMatchFailure(op,
"Convert TorchIndex To TfIndices fail.");
}
// do the tf gathernp algorithm with tf style indices as input.
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
indicesTf.getResult());

if (!result) {
return rewriter.notifyMatchFailure(
op, "Convert GatherNdOp fail for index tensor.");
}
rewriter.replaceOp(op, {result.value()});

return success();
}

template <>
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
AtenWhereSelfOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -4196,6 +4275,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenClampOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
Expand Down
21 changes: 21 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,27 @@ def IndexTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5), tu.randint(2, 3, high=4))


# ==============================================================================
class IndexTensorStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([4, 5], torch.float32, True),
([2, 3], torch.int64, True),
])
def forward(self, x, index):
return torch.ops.aten.index(x, (index, ))


@register_test_case(module_factory=lambda: IndexTensorStaticModule())
def IndexTensorStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4))


# ==============================================================================


Expand Down

0 comments on commit 145b1f0

Please sign in to comment.