Commit d408934

[MLIR][TORCH] Add E2E support for aten.squeeze op
This commit adds lowering of the `aten.squeeze` op into the `linalg.TensorCollapseShape` op. Size-1 dynamic dimensions are not handled as part of this commit.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
1 parent 36afa4a commit d408934
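
For reference, `torch.squeeze` with no `dim` argument drops every size-1 dimension; the new lowering targets this behavior for statically known unit dimensions. A minimal PyTorch sketch (illustrative, not part of the commit):

import torch

x = torch.rand(1, 7, 1, 3, 1)
print(torch.squeeze(x).shape)   # torch.Size([7, 3]) -- all unit dims removed

y = torch.rand(4, 2, 3)
print(torch.squeeze(y).shape)   # torch.Size([4, 2, 3]) -- no unit dims, identity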

9 files changed: 275 additions, 1 deletion

e2e_testing/torchscript/main.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@
 from . import matmul
 from . import view
 from . import scalar
+from . import squeeze
 
 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

e2e_testing/torchscript/squeeze.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class SqueezeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 7, 1, 3, 1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeStaticModule())
+def SqueezeModule_static(module, tu: TestUtils):
+    module.forward(tu.rand(1, 7, 1, 3, 1))
+
+
+# ==============================================================================
+
+
+class SqueezeDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, -1, 1, 384, -1, 1, 1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeDynamicModule())
+def SqueezeModule_dynamic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 8, 1, 384, 12, 1, 1))
+
+
+# ==============================================================================
+
+
+class SqueezeNoUnitDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeNoUnitDimModule())
+def SqueezeModule_noUnitDim(module, tu: TestUtils):
+    module.forward(tu.rand(4, 2, 3))
+
+
+# ==============================================================================
+
+
+class SqueezeAllUnitDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeAllUnitDimModule())
+def SqueezeModule_allUnitDim(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1))
+
+
+# ==============================================================================
+
+
+class SqueezeBroadcastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return a * b.squeeze()
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeBroadcastModule())
+def SqueezeModule_broadcast(module, tu: TestUtils):
+    module.forward(tu.rand(4, 3), tu.rand())
+
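
The broadcast test exercises the 0-d case: squeezing a 0-d tensor is a no-op (handled by the new folder), and the result then broadcasts against `a`. A quick PyTorch check of that behavior (illustrative sketch, not from the commit):

import torch

a = torch.rand(4, 3)
b = torch.rand(())                  # 0-d tensor, like tu.rand()
print(b.squeeze().shape)            # torch.Size([]) -- unchanged
print((a * b.squeeze()).shape)      # torch.Size([4, 3]) -- broadcasts over a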

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 14 additions & 0 deletions
@@ -1452,6 +1452,20 @@ def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
   let assemblyFormat = "$self `,` $dim attr-dict `:` type($self) `,` type($dim) `->` type($result)";
 }
 
+def Torch_AtenSqueezeOp : Torch_Op<"aten.squeeze", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::squeeze : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
+  let hasFolder = 1;
+}
+
 def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
     AllowsTypeRefinement
   ]> {

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 93 additions & 0 deletions
@@ -2388,6 +2388,97 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
 };
 } // namespace
 
+namespace {
+class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenSqueezeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op.getLoc();
+    Value input = adaptor.self();
+    auto inputType = input.getType().cast<RankedTensorType>();
+    int64_t inputRank = inputType.getRank();
+    TypeConverter *typeConverter = getTypeConverter();
+    auto resultType =
+        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+    int64_t resultRank = resultType.getRank();
+
+    if (inputRank == 0) {
+      return rewriter.notifyMatchFailure(
+          op, "zero input rank should have been handled by the folder");
+    }
+
+    // In case the operand tensor type is statically shaped with all dimensions
+    // being unit extent, it will be collapsed to a 0-D tensor.
+    if (resultRank == 0) {
+      SmallVector<ReassociationIndices> reassociation;
+      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
+          op, resultType, input, reassociation);
+      return success();
+    }
+
+    // All the static size-1 dimensions at the beginning (going from higher to
+    // lower dimensions) will be collapsed into the first dynamic or first
+    // non-size-1 static dimension. All the other static size-1 dimensions will
+    // be collapsed into their previous dynamic or non-size-1 static dimension.
+    SmallVector<ReassociationIndices> reassociation(resultRank);
+    bool isSqueezed = false;
+    int64_t headOnesCount = 0;
+    while (headOnesCount < inputRank &&
+           inputType.getDimSize(headOnesCount) == 1) {
+      isSqueezed = true;
+      reassociation[0].push_back(headOnesCount++);
+    }
+
+    // TODO: Add support for size-1 dynamic dimensions.
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
+    int64_t j = -1;
+    for (auto i : llvm::seq<int64_t>(headOnesCount, inputRank)) {
+      if (inputType.isDynamicDim(i)) {
+        // Make sure that a size-1 dynamic dimension does not exist.
+        Value dimSize = getDimOp(rewriter, loc, input, i);
+        Value dimSizeNotOne = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::ne, dimSize, one);
+        rewriter.create<AssertOp>(
+            loc, dimSizeNotOne,
+            rewriter.getStringAttr(
+                "unimplemented: size 1 dynamic dimension is not supported"));
+        ++j;
+      } else if (inputType.getDimSize(i) != 1) {
+        ++j;
+      } else {
+        // `isSqueezed` checks if the operand tensor type contains at least one
+        // unit dimension.
+        isSqueezed = true;
+      }
+      if (j == resultRank)
+        break;
+      reassociation[j].push_back(i);
+    }
+
+    // Make sure that the result type rank is compatible with the squeezed size.
+    if (j != resultRank - 1)
+      return rewriter.notifyMatchFailure(
+          op, "expected output size mismatches with the result type rank");
+
+    if (isSqueezed) {
+      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
+          op, resultType, input, reassociation);
+
+    } else {
+      // If the operand tensor type does not have any unit dimension,
+      // `aten.squeeze` will behave as an identity operation.
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
+    }
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
 public:

@@ -3057,6 +3148,8 @@ class ConvertTorchToLinalg
                          AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
                          AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
+    target.addIllegalOp<AtenSqueezeOp>();
+    patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
     target.addIllegalOp<AtenUnsqueezeOp>();
     patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
     target.addIllegalOp<AtenConv2dOp>();
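
A minimal Python sketch of the reassociation grouping the pattern builds, assuming a fully static shape (the helper name `squeeze_reassociation` is illustrative, not from the commit): leading size-1 dimensions fold into the first kept dimension, and every later size-1 dimension folds into the kept dimension before it.

def squeeze_reassociation(shape):
    # Group input dims so each group collapses into one dim of the result.
    result_rank = sum(1 for s in shape if s != 1)
    if result_rank == 0:
        return []  # fully unit-sized input collapses to a 0-d tensor
    groups = [[] for _ in range(result_rank)]
    j = -1
    for i, size in enumerate(shape):
        if size != 1:
            j += 1
        # Leading 1s land in group 0; later 1s join the previous kept dim.
        groups[max(j, 0)].append(i)
    return groups

print(squeeze_reassociation((1, 7, 1, 3, 1)))  # [[0, 1, 2], [3, 4]]

When no dimension is squeezed, the pattern instead emits a `tensor.cast`, so the grouping is only materialized when at least one unit dimension is removed.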

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 12 additions & 0 deletions
@@ -450,6 +450,18 @@ OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
   return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b);
 }
 
+//===----------------------------------------------------------------------===//
+// AtenSqueezeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
+  if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
+    if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
+      return getOperand();
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenDimOp
 //===----------------------------------------------------------------------===//

lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ class RewriteViewLikeSubgraph
       Operation *op = workList.pop_back_val();
       if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
         copyToValueTensorOps.push_back(copyToValueTensor);
-      } else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
+      } else if (isa<AtenSqueezeOp, AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
                      AtenTransposeIntOp, TensorStaticInfoCastOp,
                      AtenBroadcastToOp, AtenToDtypeOp, AtenContiguousOp,
                      AtenPermuteOp, AtenViewOp, AtenExpandOp,

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 24 additions & 0 deletions
@@ -300,6 +300,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitAtenLerpTensorOp(lerpTensor, operands);
     } else if (auto flatten = dyn_cast<AtenFlattenUsingIntsOp>(op)) {
       return visitAtenFlattenUsingIntsOp(flatten, operands);
+    } else if (auto squeeze = dyn_cast<AtenSqueezeOp>(op)) {
+      return visitAtenSqueezeOp(squeeze, operands);
     } else if (auto unsqueeze = dyn_cast<AtenUnsqueezeOp>(op)) {
       return visitAtenUnsqueezeOp(unsqueeze, operands);
     } else if (auto arange = dyn_cast<AtenArangeOp>(op)) {

@@ -466,6 +468,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       AtenFlattenUsingIntsOp op,
       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult
+  visitAtenSqueezeOp(AtenSqueezeOp op,
+                     ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
   visitAtenUnsqueezeOp(AtenUnsqueezeOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 
@@ -880,6 +885,25 @@ ChangeResult TypeAnalyzer::visitAtenFlattenUsingIntsOp(
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAtenSqueezeOp(
+    AtenSqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto operand = operands[0]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+  knowledge.dtype = operand.dtype;
+  if (operand.hasSizes) {
+    int64_t inputRank = operand.sizes.size();
+    knowledge.hasSizes = true;
+    // `knowledge.sizes` will be empty when either `inputRank` is 0 or operand
+    // tensor type is statically shaped with all dimensions being unit.
+    // Note: size-1 dynamic dimensions are not supported yet.
+    for (auto i = 0; i < inputRank; i++)
+      if (operand.sizes[i] != 1)
+        knowledge.sizes.push_back(operand.sizes[i]);
+  }
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
     AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto operand = operands[0]->getValue();
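
The shape-transfer rule added in `visitAtenSqueezeOp` amounts to dropping every statically known size-1 dimension while keeping dynamic ones; a hedged Python sketch, with dynamic sizes written as -1 as in the test annotations:

def squeeze_result_sizes(sizes):
    # Keep dynamic (-1) and non-unit dims; drop static size-1 dims.
    # Size-1 dynamic dims are not modeled, mirroring the commit's TODO.
    return [s for s in sizes if s != 1]

print(squeeze_result_sizes([1, -1, 1, 384, -1, 1, 1]))  # [-1, 384, -1]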

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -521,6 +521,7 @@ def emit_with_mutating_variants(key, **kwargs):
 
     # Misc tensor ops.
     emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
+    emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
     emit("aten::dim : (Tensor) -> (int)", has_folder=True)
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)

test/Dialect/Torch/canonicalize.mlir

Lines changed: 8 additions & 0 deletions
@@ -594,3 +594,11 @@ func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor,
   %1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
   return %1 : !torch.tensor
 }
+
+// CHECK-LABEL:  func @torch.aten.squeeze$zero_rank(
+// CHECK-SAME:     %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
+// CHECK-NEXT:     return %[[ARG]] : !torch.tensor<[],f32>
+func @torch.aten.squeeze$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
+  %0 = torch.aten.squeeze %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
+  return %0 : !torch.tensor<[],f32>
+}
