[TOSA] Add aten.Index.Tensor support
AmosLewis authored and committed on Jan 3, 2023
1 parent d44bdd2 commit 4b57b24
Showing 2 changed files with 72 additions and 0 deletions.
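
For context, aten.Index.Tensor implements PyTorch's advanced indexing with a list of index tensors, i.e. x[idx]. A minimal sketch of the semantics this commit legalizes, in plain PyTorch (illustrative only; the variable names are not from the patch):

import torch

x = torch.randn(1, 2)       # the input tensor ("self")
idx = torch.tensor([0])     # an i64 index tensor
y = x[idx]                  # traces to aten.index.Tensor with indices = [idx]
assert y.shape == (1, 2)    # row 0 of x is gathered; see the MLIR test below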
47 changes: 47 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -3180,6 +3181,51 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
    AtenIndexTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Reject inputs that are not ranked tensors.
  auto input = adaptor.getSelf();
  auto inputTensorType =
      adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
  if (!inputTensorType)
    return rewriter.notifyMatchFailure(
        op, "only ranked tensor inputs are currently supported");

  // Handle a torch.prim.ListConstruct of non-constant index tensors.
  auto tensorList = op.getIndices();
  SmallVector<Value> tensorsTorchType;
  if (!getListConstructElements(tensorList, tensorsTorchType))
    return op.emitError(
        "unimplemented: the tensor list is not from list construct");
  auto tensors = getTypeConvertedValues(rewriter, op->getLoc(),
                                        getTypeConverter(), tensorsTorchType);

  // NOTE: Only the first index tensor in the list is handled.
  auto indices = tensors[0];
  auto indicesType = indices.getType().dyn_cast<RankedTensorType>();
  // Cast i64 indices to i32, since TOSA gather expects i32 index tensors.
  if (indicesType.getElementType() != rewriter.getIntegerType(32)) {
    indices = rewriter.create<tosa::CastOp>(
        op->getLoc(),
        RankedTensorType::get(indicesType.getShape(),
                              rewriter.getIntegerType(32)),
        indices);
  }

  auto outType = getTypeConverter()->convertType(op.getType());
  // Lower via the TF GatherND algorithm, with TF-style indices as input.
  auto result = tosa::convertGatherNdOp(rewriter, op, outType, input, indices);

  if (!result) {
    return rewriter.notifyMatchFailure(
        op, "failed to convert GatherNdOp for the index tensor");
  }
  rewriter.replaceOp(op, {result.value()});

  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
AtenWhereSelfOp op, OpAdaptor adaptor,
@@ -4131,6 +4177,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
    INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
    INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
    INSERT_ATENOP_PATTERN(AtenGatherOp);
    INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
    INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
    INSERT_ATENOP_PATTERN(AtenClampOp);
    INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
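The heavy lifting above is done by tosa::convertGatherNdOp, which reuses the TF-style GatherND legalization: the index tensor is cast to i32, each index column is multiplied by the flattened stride of the dimension it indexes, and a reduce_sum collapses the columns into one flat index per gathered slice, so a single tosa.gather plus reshapes produces the result. A rough NumPy model of that sequence for the test case below (illustrative only, not the library code):

import numpy as np

x = np.arange(2.0).reshape(1, 2)             # input (1, 2)
idx = np.array([0], dtype=np.int64)          # index tensor (1,)

values = x.reshape(1, 1, 2)                          # tosa.reshape: (N=1, K=1, C=2)
i32 = idx.astype(np.int32).reshape(1, 1)             # tosa.cast + tosa.reshape: (N=1, W=1)
strides = np.array([1], dtype=np.int32)              # tosa.const: flattened strides
flat = (i32 * strides).sum(axis=1, keepdims=True)    # tosa.mul + tosa.reduce_sum
gathered = np.take_along_axis(values, flat[..., None], axis=1)  # tosa.gather
result = gathered.reshape(1, 2)                      # tosa.reshape back to (1, 2)
assert np.array_equal(result, x[idx])
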
25 changes: 25 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -933,6 +933,31 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v
  return %0 : !torch.vtensor<[1,4,2],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.index.Tensor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],si64>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,2],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_0]] : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1],si64> -> tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK: %[[VAL_6:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [1, 1, 2]} : (tensor<1x2xf32>) -> tensor<1x1x2xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = [1, 1]} : (tensor<1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_8:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_7]], %[[VAL_8]]) {shift = 0 : i32} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.reduce_sum"(%[[VAL_9]]) {axis = 1 : i64} : (tensor<1x1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_10]]) {new_shape = [1, 1]} : (tensor<1x1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_12:.*]] = "tosa.gather"(%[[VAL_6]], %[[VAL_11]]) : (tensor<1x1x2xf32>, tensor<1x1xi32>) -> tensor<1x1x2xf32>
// CHECK: %[[VAL_13:.*]] = "tosa.reshape"(%[[VAL_12]]) {new_shape = [1, 2]} : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<1x2xf32> -> !torch.vtensor<[1,2],f32>
// CHECK: return %[[VAL_14]] : !torch.vtensor<[1,2],f32>
// CHECK: }
func.func @torch.aten.index.Tensor(%arg0: !torch.vtensor<[1],si64>, %arg1: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,2],f32> {
  %0 = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
  %1 = torch.aten.index.Tensor %arg1, %0 : !torch.vtensor<[1,2],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,2],f32>
  return %1 : !torch.vtensor<[1,2],f32>
}
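
As a point of reference, the lit tests in this file are driven by the RUN line at the top of basic.mlir (a torch-mlir-opt invocation running -convert-torch-to-tosa, piped into FileCheck), so the new case is checked together with the rest of the suite.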

// -----
// CHECK-LABEL: func.func @torch.aten.add$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>,
