@@ -39,6 +39,280 @@ using namespace mlir::tosa;
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

+ //===----------------------------------------------------------------------===//
+ // Tensor Data Engine Operators.
+ //===----------------------------------------------------------------------===//
+
+ // Check that the zero point of the tensor and padding operations are aligned.
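+ // A pad can only be absorbed when its fill value matches what the consumer
+ // op already assumes for implicit padding: 0.0f in the float case, or the
+ // consumer's input zero point in the quantized (integer) case.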
+ bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
+   // Check that padConst is a constant value and a scalar tensor
+   DenseElementsAttr padConstAttr;
+   if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+       (padConstAttr.size() != 1)) {
+     return false;
+   }
+
+   // Check that floating point pad is zero
+   if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+     float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
+     return padConstVal == 0.0f;
+   }
+
+   // Check that the zp and padConst align for the integer (quantized) case
+   if (auto padConstIntAttr =
+           mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+     DenseIntElementsAttr zpAttr;
+     // Check that zp is a constant value and a scalar tensor
+     if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
+       return false;
+     }
+
+     // Check equality
+     int64_t zpVal = (*zpAttr.begin()).getSExtValue();
+     int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
+     return zpVal == padConstVal;
+   }
+
+   // Bail-out on unsupported type
+   return false;
+ }
+
+ namespace {
+ template <typename OpTy>
+ struct PoolPadFoldAdaptor;
+
+ template <>
+ struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
+   using OpTy = tosa::AvgPool2dOp;
+   static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+     const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+     if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+         newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+       return false;
+     return true;
+   }
+   static bool checkPadConstCompliance(OpTy op, Value padConst) {
+     return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+   }
+   static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                   Value padInput, ArrayRef<int64_t> newPad) {
+     rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+         op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
+         op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
+         op.getAccType());
+   }
+ };
+
+ template <>
+ struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
+   using OpTy = tosa::MaxPool2dOp;
+   static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+     const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+     if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+         newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+       return false;
+     return true;
+   }
+   static bool checkPadConstCompliance(OpTy, Value padConst) {
+     // Check that padConst is a constant value and a scalar tensor
+     DenseElementsAttr padConstAttr;
+     if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+         padConstAttr.size() != 1) {
+       return false;
+     }
+
+     // Pad constant needs to be the type's minimum value to be able to merge
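+     // (a lowest-value pad element can never win a max reduction, so it is
+     // equivalent to the implicit padding of max_pool2d).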
+     if (auto padConstFpAttr =
+             mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+       const APFloat padConstVal = *padConstFpAttr.begin();
+       const APFloat lowestVal =
+           APFloat::getLargest(padConstVal.getSemantics(), true);
+       return padConstVal == lowestVal;
+     } else if (auto padConstIntAttr =
+                    mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+       const APInt padConstVal = *padConstIntAttr.begin();
+       const unsigned int bitWidth = padConstVal.getBitWidth();
+       const APInt lowestVal =
+           padConstIntAttr.getElementType().isUnsignedInteger()
+               ? APInt::getZero(bitWidth)
+               : APInt::getSignedMinValue(bitWidth);
+       return padConstVal == lowestVal;
+     }
+
+     // Bail-out on unsupported type
+     return false;
+   }
+   static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                   Value padInput, ArrayRef<int64_t> newPad) {
+     rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+         op, op.getType(), padInput, op.getKernel(), op.getStride(),
+         rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
+   }
+ };
+
+ template <typename OpTy>
+ struct ConvPadFoldAdaptor {
+   static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
+     return true;
+   }
+   static bool checkPadConstCompliance(OpTy op, Value padConst) {
+     return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+   }
+   static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                   Value padInput, ArrayRef<int64_t> newPad) {
+     rewriter.replaceOpWithNewOp<OpTy>(
+         op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
+         op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
+         op.getDilationAttr(), op.getAccType(), op.getLocalBound());
+   }
+ };
+
+ // Pattern attempts to fold a `tosa.pad` operator into a following tensor
+ // operation like `tosa.conv2d` by merging the padding associated with the
+ // pad operator directly into the implicit padding of the tensor operation.
+ // This helps eliminate the explicit padding operator if it becomes unused.
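+ //
+ // Illustrative example (operand and attribute syntax abbreviated, names are
+ // hypothetical):
+ //   %p = tosa.pad %x {padding = [0, 0, 1, 1, 2, 2, 0, 0]}
+ //   %y = tosa.conv2d %p, %w, %b {pad = [0, 0, 0, 0], ...}
+ // is rewritten to:
+ //   %y = tosa.conv2d %x, %w, %b {pad = [1, 1, 2, 2], ...}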
+ template <typename OpTy, typename AdaptorTy>
+ struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
+   using OpRewritePattern<OpTy>::OpRewritePattern;
+
+   LogicalResult matchAndRewrite(OpTy tensorOp,
+                                 PatternRewriter &rewriter) const override {
+     // Check producer is a tosa::PadOp
+     auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
+     if (!padOp)
+       return rewriter.notifyMatchFailure(tensorOp,
+                                          "Producer must be a tosa::PadOp.");
+
+     // Validate that tensor operation has sane padding
+     const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
+     if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
+       return rewriter.notifyMatchFailure(
+           tensorOp, "Tensor operation padding shall have 4 elements.");
+
+     // Validate tosa::PadOp padding
+     DenseIntElementsAttr padOpPadding;
+     if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
+       return rewriter.notifyMatchFailure(
+           tensorOp,
+           "The `padding` input specified on the tosa::PadOp must be constant.");
+     }
+     // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
+     // C_after
+     if (padOpPadding.size() != 8)
+       return rewriter.notifyMatchFailure(tensorOp,
+                                          "Pad padding should have 8 elements.");
+     int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
+     int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
+     int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
+     int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
+     int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
+     int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
+     int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
+     int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
+
+     if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
+       return rewriter.notifyMatchFailure(
+           tensorOp, "Folding padding in N or C dimensions is not supported.");
+
+     // Fold padding from Pad into the tensor operation
+     // 4 elements - pad_top, pad_bottom, pad_left, pad_right
+     SmallVector<int64_t> foldedPad(tensorOpPad.size());
+     foldedPad[0] = padHBefore + tensorOpPad[0];
+     foldedPad[1] = padHAfter + tensorOpPad[1];
+     foldedPad[2] = padWBefore + tensorOpPad[2];
+     foldedPad[3] = padWAfter + tensorOpPad[3];
+
+     // Check kernel related restrictions
+     if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
+       return rewriter.notifyMatchFailure(
+           tensorOp, "Padding size not aligned with kernel restrictions.");
+     }
+
+     // Check padding constant restrictions
+     if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
+       return rewriter.notifyMatchFailure(
+           tensorOp,
+           "Padding constant is not aligned with operator zero-point.");
+     }
+
+     // Check that padding doesn't grow more than 8K level (8192) for now
+     if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
+       return rewriter.notifyMatchFailure(
+           tensorOp, "Padding size more than the 8K level limit.");
+     }
+
+     // Create operator
+     AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
+                                    foldedPad);
+
+     return success();
+   }
+ };
+ } // namespace
+
+ void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+   results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
+                                 PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
+       context);
+ }
+
+ void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+   results.add<
+       FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
+       context);
+ }
+
+ void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                     MLIRContext *context) {
+   results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
+                                 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
+       context);
+ }
+
+ struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+   using OpRewritePattern::OpRewritePattern;
+
+   LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                 PatternRewriter &rewriter) const override {
+     Value input = op.getInput();
+     Value output = op.getOutput();
+     ShapedType inputType = llvm::cast<ShapedType>(input.getType());
+     ShapedType outputType = llvm::cast<ShapedType>(output.getType());
+
+     if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+       return failure();
+     }
+
+     // If both the input and output spatial shapes (H, W) are 1x1, then this
+     // is a no op.
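+     // (the pool reduces a single spatial element per channel, so the result
+     // equals the input).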
+     ArrayRef<int64_t> outputShape = outputType.getShape();
+     if (outputShape[1] != 1 || outputShape[2] != 1) {
+       return failure();
+     }
+
+     ArrayRef<int64_t> inputShape = inputType.getShape();
+     if (inputShape[1] != 1 || inputShape[2] != 1) {
+       return failure();
+     }
+
+     rewriter.replaceOp(op, input);
+     return success();
+   }
+ };
+
+ void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+   results.add<MaxPool2dIsNoOp,
+               FoldPadToTensorOp<tosa::MaxPool2dOp,
+                                 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
+       context);
+ }
+
+ //===----------------------------------------------------------------------===//
+ // Data Layout / Memory Reinterpretation.
+ //===----------------------------------------------------------------------===//
+

struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
@@ -175,41 +449,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

- struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
-   using OpRewritePattern::OpRewritePattern;
-
-   LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                 PatternRewriter &rewriter) const override {
-     Value input = op.getInput();
-     Value output = op.getOutput();
-     ShapedType inputType = llvm::cast<ShapedType>(input.getType());
-     ShapedType outputType = llvm::cast<ShapedType>(output.getType());
-
-     if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
-       return failure();
-     }
-
-     // If the output and input shapes are 1x1, then this is a no op.
-     ArrayRef<int64_t> outputShape = outputType.getShape();
-     if (outputShape[1] != 1 || outputShape[2] != 1) {
-       return failure();
-     }
-
-     ArrayRef<int64_t> inputShape = inputType.getShape();
-     if (inputShape[1] != 1 || inputShape[2] != 1) {
-       return failure();
-     }
-
-     rewriter.replaceOp(op, input);
-     return success();
-   }
- };
-
- void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                               MLIRContext *context) {
-   results.add<MaxPool2dIsNoOp>(context);
- }
-

struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;