Add aten.unflatten.int support and its torch-to-tosa lowering #2509

Merged: 4 commits, Oct 14, 2023
Changes from 2 commits
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7537,6 +7537,30 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
}];
}

def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchListOfTorchIntType:$sizes
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUnflattenIntOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenUnflattenIntOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenDimOp : Torch_Op<"aten.dim", [
AllowsTypeRefinement,
HasValueSemantics,
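The new op mirrors PyTorch's Tensor.unflatten, which splits a single dimension into the given sizes while leaving the other dimensions untouched. A minimal sketch of the intended semantics, assuming standard PyTorch behavior (not part of this patch):

import torch

x = torch.randn(1, 6, 4)
# Split dimension 1 (size 6) into sizes (2, 3); the other dimensions are unchanged.
y = x.unflatten(1, (2, 3))
print(y.shape)  # torch.Size([1, 2, 3, 4])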
55 changes: 55 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2525,6 +2525,60 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenUnflattenIntOp>::matchAndRewrite(
AtenUnflattenIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

// Not a ranked tensor type
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
if (!selfType || !selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op,
"Only ranked tensor types with static shapes are currently supported");

int64_t selfRank = selfType.getRank();
int64_t dim;

if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");

SmallVector<int64_t> sizes;
if (!matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizes)))
return rewriter.notifyMatchFailure(
op, "Only constant sizes are currently supported");

if (selfRank > 0 && !isValidDim(dim, selfRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

SmallVector<int64_t> newShape;
for (auto s :
llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
int64_t idx = s.index();
if (idx < dim || idx > dim) {
newShape.push_back(s.value());
} else {
auto sum = 1;
for (auto newDims : sizes) {
newShape.push_back(newDims);
sum *= newDims;
}
if (sum != s.value())
return rewriter.notifyMatchFailure(op,
"sizes mismatch with original dim");
}
}

auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape),
selfType.getElementType());

rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, getTypeConverter()->convertType(newType), adaptor.getSelf(),
rewriter.getDenseI64ArrayAttr(newShape));

return success();
}

template <>
LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
AtenPermuteOp op, OpAdaptor adaptor,
@@ -5050,6 +5104,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
INSERT_ATENOP_PATTERN(AtenUnflattenIntOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenLog2Op);
INSERT_ATENOP_PATTERN(AtenThresholdOp);
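The lowering folds aten.unflatten.int into a single tosa.reshape whose new shape splices the requested sizes in place of the unflattened dimension, after checking that their product equals that dimension's extent. A minimal Python sketch of that shape computation, where unflatten_shape is a hypothetical helper for illustration only:

def unflatten_shape(shape, dim, sizes):
    # The product of the requested sizes must equal the size of the dimension
    # being unflattened, mirroring the sizes check in the lowering above.
    prod = 1
    for s in sizes:
        prod *= s
    assert prod == shape[dim], "sizes mismatch with original dim"
    # Splice the new sizes in place of the unflattened dimension.
    return shape[:dim] + list(sizes) + shape[dim + 1:]

print(unflatten_shape([1, 6, 4], 1, [2, 3]))  # [1, 2, 3, 4]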
14 changes: 14 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7205,6 +7205,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.unflatten.int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.slice.t %arg0, %none, %arg1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" %1 = torch.aten.add.t %0, %arg2 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" %2 = torch.aten.add.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %3 = torch.aten.slice.t %arg0, %2, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
" %4 = torch.aten.add.t %1, %3 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.linear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -8580,6 +8590,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.unflatten.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.flip\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
8 changes: 4 additions & 4 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -199,10 +199,10 @@ bool Torch::isViewLikeOp(Operation *op) {
// that it does not return a view and treat those as having value
// semantics.
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
AtenExpandOp, AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
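AtenUnflattenIntOp is added to isViewLikeOp because, for contiguous inputs, unflatten returns a view that aliases the original storage rather than a copy. A quick illustration of that aliasing, assuming standard PyTorch behavior:

import torch

t = torch.arange(6)
v = t.unflatten(0, (2, 3))  # a (2, 3) view of t, no data copied
v[0, 0] = 99
print(t[0])                 # tensor(99): the write is visible through t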
@@ -623,6 +623,9 @@ def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int])
def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]:
return upstream_shape_functions.flatten(self, start_dim, end_dim)

def aten〇unflatten〇int〡shape(self: List[int], dim: int, sizes: List[int]) -> List[int]:
return self[:dim] + sizes + self[dim + 1:]

def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]:
return upstream_shape_functions.linear(input, weight, bias)

@@ -1656,6 +1659,11 @@ def aten〇flatten〇using_ints〡dtype(self_rank_dtype: Tuple[int, int], start_
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, sizes=[1]))
def aten〇unflatten〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, sizes: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0]))
def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
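The shape function added above reduces to self[:dim] + sizes + self[dim + 1:]. A short worked example, illustrative only, that matches the new TOSA lit test below:

self_shape = [1, 6, 4]
dim, sizes = 1, [2, 3]
new_shape = self_shape[:dim] + sizes + self_shape[dim + 1:]
print(new_shape)  # [1, 2, 3, 4]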
@@ -516,6 +516,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)")
emit("aten::dim : (Tensor) -> (int)", has_folder=True)
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
22 changes: 22 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -556,6 +556,28 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor

// -----

// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,6,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
// CHECK: %[[VAL:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,6,4],f32> -> tensor<1x6x4xf32>
// CHECK: %[[VAL_1:.*]] = torch.constant.int 1
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL]] {new_shape = array<i64: 1, 2, 3, 4>} : (tensor<1x6x4xf32>) -> tensor<1x2x3x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,3,4],f32>
// CHECK: }
func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3,4],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[1,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
return %1 : !torch.vtensor<[1,2,3,4],f32>
}

// -----

// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,