Skip to content

Commit

Permalink
Add aten.unflatten.int support and its torch-to-tosa lowering (llvm#2509
Browse files Browse the repository at this point in the history
)

Add aten.unflatten.int op
Add its torch-to-tosa lowering
Update the TorchToTosa/basic.mlir tests

To test e2e tosa lowering:

`python -m e2e_testing.main -v -c=tosa`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
  • Loading branch information
zezhang and Ze Zhang authored Oct 14, 2023
1 parent 9b5a4af commit e649e06
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 4 deletions.
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"UnflattenStaticModule_basic",
}

TORCHDYNAMO_XFAIL_SET = {
Expand Down Expand Up @@ -1056,6 +1057,7 @@
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",
"FlattenStaticModule_basic",
"UnflattenStaticModule_basic",
"FlattenRank0Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"SquareModule_basic",
Expand Down
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7537,6 +7537,30 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
}];
}

def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchListOfTorchIntType:$sizes
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUnflattenIntOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenUnflattenIntOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenDimOp : Torch_Op<"aten.dim", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
55 changes: 55 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2525,6 +2525,60 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenUnflattenIntOp>::matchAndRewrite(
AtenUnflattenIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

// Not a ranked tensor type
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
if (!selfType || !selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op,
"Only ranked tensor types with static shapes are currently supported");

int64_t selfRank = selfType.getRank();
int64_t dim;

if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");

SmallVector<int64_t> sizes;
if (!matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizes)))
return rewriter.notifyMatchFailure(
op, "Only constant sizes are currently supported");

if (selfRank > 0 && !isValidDim(dim, selfRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

SmallVector<int64_t> newShape;
for (auto s :
llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
int64_t idx = s.index();
if (idx < dim || idx > dim) {
newShape.push_back(s.value());
} else {
auto sum = 1;
for (auto newDims : sizes) {
newShape.push_back(newDims);
sum *= newDims;
}
if (sum != s.value())
return rewriter.notifyMatchFailure(op,
"sizes mismatch with original dim");
}
}

auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape),
selfType.getElementType());

rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, getTypeConverter()->convertType(newType), adaptor.getSelf(),
rewriter.getDenseI64ArrayAttr(newShape));

return success();
}

template <>
LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
AtenPermuteOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -5050,6 +5104,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
INSERT_ATENOP_PATTERN(AtenUnflattenIntOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenLog2Op);
INSERT_ATENOP_PATTERN(AtenThresholdOp);
Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7205,6 +7205,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.unflatten.int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.slice.t %arg0, %none, %arg1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" %1 = torch.aten.add.t %0, %arg2 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" %2 = torch.aten.add.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %3 = torch.aten.slice.t %arg0, %2, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
" %4 = torch.aten.add.t %1, %3 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.linear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -8580,6 +8590,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.unflatten.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.flip\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
8 changes: 4 additions & 4 deletions lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ bool Torch::isViewLikeOp(Operation *op) {
// that it does not return a view and treat those as having value
// semantics.
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
AtenExpandOp, AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,9 @@ def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int])
def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]:
return upstream_shape_functions.flatten(self, start_dim, end_dim)

def aten〇unflatten〇int〡shape(self: List[int], dim: int, sizes: List[int]) -> List[int]:
return self[:dim] + sizes + self[dim + 1:]

def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]:
return upstream_shape_functions.linear(input, weight, bias)

Expand Down Expand Up @@ -1656,6 +1659,11 @@ def aten〇flatten〇using_ints〡dtype(self_rank_dtype: Tuple[int, int], start_
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, sizes=[1]))
def aten〇unflatten〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, sizes: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0]))
def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)")
emit("aten::dim : (Tensor) -> (int)", has_folder=True)
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
Expand Down
22 changes: 22 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,28 @@ def AddmmModule_differentRankBroadcastable(module, tu: TestUtils):
# ==============================================================================


class UnflattenStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([1, 6, 4], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.unflatten(x, 1, (2, 3))


@register_test_case(module_factory=lambda: UnflattenStaticModule())
def UnflattenStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 6, 4))


# ==============================================================================


class FlattenStaticModule(torch.nn.Module):

def __init__(self):
Expand Down
22 changes: 22 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,28 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor

// -----

// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,6,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
// CHECK: %[[VAL:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,6,4],f32> -> tensor<1x6x4xf32>
// CHECK: %[[VAL_1:.*]] = torch.constant.int 1
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL]] {new_shape = array<i64: 1, 2, 3, 4>} : (tensor<1x6x4xf32>) -> tensor<1x2x3x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,3,4],f32>
// CHECK: }
func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3,4],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[1,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
return %1 : !torch.vtensor<[1,2,3,4],f32>
}

// -----

// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,
Expand Down

0 comments on commit e649e06

Please sign in to comment.