Commit 9c38b2e

[mlir][tosa] Fold PadOp to tensor operations (#132700)
1 parent 1365b5b commit 9c38b2e
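
In effect, the new patterns rewrite an explicit tosa.pad feeding one of these
operators into the operator's own pad attribute. A minimal before/after sketch
of the rewrite (hypothetical IR reconstructed from the pattern's requirements;
exact TOSA operand and attribute spellings vary between dialect versions):

  // Before: explicit pad of H and W by 1 on each side, fill value 0.0.
  %shape = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
  %zero = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  %padded = tosa.pad %input, %shape, %zero : (tensor<1x32x32x8xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x34x34x8xf32>
  %out = tosa.conv2d %padded, %weight, %bias, %izp, %wzp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x34x34x8xf32>, tensor<16x3x3x8xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>

  // After canonicalization: the pad is absorbed into conv2d's implicit padding.
  %out = tosa.conv2d %input, %weight, %bias, %izp, %wzp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x3x3x8xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>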

File tree: 3 files changed, +431 −35 lines

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Lines changed: 5 additions & 0 deletions

@@ -107,6 +107,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
     LogicalResult verifyOutputZeroPoint(int64_t zp);
   }];
 
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -153,6 +154,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   }];
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -244,6 +247,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
   }];
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
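In TableGen, `let hasCanonicalizer = 1` makes ODS declare a
getCanonicalizationPatterns hook on the generated op class (MaxPool2dOp
already declared one for its MaxPool2dIsNoOp pattern); the definitions added
to TosaCanonicalizations.cpp below supply the actual patterns.
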
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Lines changed: 274 additions & 35 deletions

@@ -39,6 +39,280 @@ using namespace mlir::tosa;
 // Operator Canonicalizers.
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// Tensor Data Engine Operators.
+//===----------------------------------------------------------------------===//
+
+// Check that the zero point of the tensor and padding operations are aligned.
+bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
+  // Check that padConst is a constant value and a scalar tensor
+  DenseElementsAttr padConstAttr;
+  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+      (padConstAttr.size() != 1)) {
+    return false;
+  }
+
+  // Check that the floating-point pad value is zero
+  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+    float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
+    return padConstVal == 0.0f;
+  }
+
+  // Check that the zp and padConst align for the integer (quantized) case
+  if (auto padConstIntAttr =
+          mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+    DenseIntElementsAttr zpAttr;
+    // Check that zp is a constant value and a scalar tensor
+    if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
+      return false;
+    }
+
+    // Check equality
+    int64_t zpVal = (*zpAttr.begin()).getSExtValue();
+    int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
+    return zpVal == padConstVal;
+  }
+
+  // Bail out on unsupported type
+  return false;
+}
+
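For the integer (quantized) path, the fold is value-preserving only when the
pad fills with exactly the operator's input zero point, since the consuming
op's implicit padding is defined in terms of that zero point. A hypothetical
i8 pairing that passes the check (a pad_const of dense<0> would block the fold
when the zero point is -128):

  %izp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
  %cst = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
  // pad_const (-128) equals the input zero point, so checkMatchingPadConstAndZp succeeds.
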
+namespace {
+template <typename OpTy>
+struct PoolPadFoldAdaptor;
+
+template <>
+struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
+  using OpTy = tosa::AvgPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
+        op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
+        op.getAccType());
+  }
+};
+
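The kernel check above mirrors the TOSA rule that pooling padding must stay
smaller than the kernel in each dimension (newPad is ordered top, bottom,
left, right against kernel = [y, x]): with kernel = [3, 3], a folded pad of 2
on every side is still accepted, while any entry of 3 or more rejects the
rewrite.
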
+template <>
+struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
+  using OpTy = tosa::MaxPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy, Value padConst) {
+    // Check that padConst is a constant value and a scalar tensor
+    DenseElementsAttr padConstAttr;
+    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+        padConstAttr.size() != 1) {
+      return false;
+    }
+
+    // The pad value needs to be the minimum value to be able to merge
+    if (auto padConstFpAttr =
+            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+      const APFloat padConstVal = *padConstFpAttr.begin();
+      const APFloat lowestVal =
+          APFloat::getLargest(padConstVal.getSemantics(), true);
+      return padConstVal == lowestVal;
+    } else if (auto padConstIntAttr =
+                   mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+      const APInt padConstVal = *padConstIntAttr.begin();
+      const unsigned int bitWidth = padConstVal.getBitWidth();
+      const APInt lowestVal =
+          padConstIntAttr.getElementType().isUnsignedInteger()
+              ? APInt::getZero(bitWidth)
+              : APInt::getSignedMinValue(bitWidth);
+      return padConstVal == lowestVal;
+    }
+
+    // Bail out on unsupported type
+    return false;
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+        op, op.getType(), padInput, op.getKernel(), op.getStride(),
+        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
+  }
+};
+
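tosa.max_pool2d carries no zero point, so the adaptor above only merges a pad
whose fill value can never win the max reduction: the most negative finite
float (APFloat::getLargest with the negative flag), the signed minimum
integer, or zero for unsigned types. A hypothetical f32 pad constant that
passes the check:

  %lowest = "tosa.const"() {values = dense<-3.40282347E+38> : tensor<1xf32>} : () -> tensor<1xf32>
  // The lowest finite f32: padded elements can never beat real data in the max.
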
+template <typename OpTy>
+struct ConvPadFoldAdaptor {
+  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<OpTy>(
+        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
+        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
+        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
+  }
+};
+
+// Pattern attempts to fold a `tosa.pad` operator to a following tensor
+// operation like `tosa.conv2d` by merging the padding associated with the
+// pad operator directly to the implicit padding of the tensor operation.
+// This helps eliminate the explicit padding operator if unused.
+template <typename OpTy, typename AdaptorTy>
+struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy tensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Check that the producer is a tosa::PadOp
+    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Producer must be a tosa::PadOp.");
+
+    // Validate that the tensor operation has sane padding
+    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
+    if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Tensor operation padding shall have 4 elements.");
+
+    // Validate tosa::PadOp padding
+    DenseIntElementsAttr padOpPadding;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
+    // C_after
+    if (padOpPadding.size() != 8)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Pad padding should have 8 elements.");
+    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
+    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
+    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
+    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
+    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
+    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
+    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
+    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
+
+    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Folding padding in N or C dimensions is not supported.");
+
+    // Fold padding from the Pad into the tensor operation
+    // 4 elements - pad_top, pad_bottom, pad_left, pad_right
+    SmallVector<int64_t> foldedPad(tensorOpPad.size());
+    foldedPad[0] = padHBefore + tensorOpPad[0];
+    foldedPad[1] = padHAfter + tensorOpPad[1];
+    foldedPad[2] = padWBefore + tensorOpPad[2];
+    foldedPad[3] = padWAfter + tensorOpPad[3];
+
+    // Check kernel-related restrictions
+    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size not aligned with kernel restrictions.");
+    }
+
+    // Check padding constant restrictions
+    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "Padding constant is not aligned with operator zero-point.");
+    }
+
+    // Check that padding doesn't grow past the 8K level (8192) for now
+    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size more than the 8K level limit.");
+    }
+
+    // Create the new operator
+    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
+                                   foldedPad);
+
+    return success();
+  }
+};
+} // namespace
+
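As a worked example of the merge: a tosa.pad with padding
[0, 0, 1, 2, 3, 4, 0, 0] (N, H, W, C before/after pairs) feeding a conv2d
whose pad is [1, 1, 0, 0] (top, bottom, left, right) folds to
pad = [1+1, 2+1, 3+0, 4+0] = [2, 3, 3, 4]. Any padding in N or C, a failed
adaptor check, or a folded entry above 8192 (the TOSA 8K level limit) leaves
both operations untouched.
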
+void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
+                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
+      context);
+}
+
+void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<
+      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
+      context);
+}
+
+void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
+                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
+      context);
+}
+
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getInput();
+    Value output = op.getOutput();
+    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
+    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
+
+    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // If the output and input shapes are 1x1, then this is a no op.
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    if (outputShape[1] != 1 || outputShape[2] != 1) {
+      return failure();
+    }
+
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    if (inputShape[1] != 1 || inputShape[2] != 1) {
+      return failure();
+    }
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
+void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<MaxPool2dIsNoOp,
+              FoldPadToTensorOp<tosa::MaxPool2dOp,
+                                PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
+      context);
+}
+
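With these hooks registered, the folds run under the standard canonicalization
pass (e.g. `mlir-opt --canonicalize`); once its only user is rewritten, the
now-dead tosa.pad is cleaned up by the canonicalizer's dead-code elimination.
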
+//===----------------------------------------------------------------------===//
+// Data Layout / Memory Reinterpretation.
+//===----------------------------------------------------------------------===//
+
 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
 
@@ -175,41 +449,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
 }
 
-struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-    Value output = op.getOutput();
-    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
-    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
-
-    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
-      return failure();
-    }
-
-    // If the output and input shapes are 1x1, then this is a no op.
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    if (outputShape[1] != 1 || outputShape[2] != 1) {
-      return failure();
-    }
-
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    if (inputShape[1] != 1 || inputShape[2] != 1) {
-      return failure();
-    }
-
-    rewriter.replaceOp(op, input);
-    return success();
-  }
-};
-
-void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<MaxPool2dIsNoOp>(context);
-}
-
 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   using OpRewritePattern::OpRewritePattern;
 