7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " mlir/Dialect/Linalg/IR/Linalg.h"
10
+ #include " mlir/Dialect/Linalg/Transforms/Transforms.h"
10
11
#include " mlir/Dialect/Tensor/IR/Tensor.h"
11
12
#include " mlir/Dialect/Tensor/Transforms/Transforms.h"
12
13
#include " mlir/Dialect/Utils/IndexingUtils.h"
@@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
197
198
// / Fold a `pad` -> `pack` into `pack` if they have the same padding values and
198
199
// / the pad op has zero low paddings, or if `pack` has no padding values.
199
200
struct FoldPadWithPackOp : public OpRewritePattern <PackOp> {
200
- using OpRewritePattern<PackOp>::OpRewritePattern;
201
+ public:
202
+ FoldPadWithPackOp (MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
203
+ : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
201
204
202
205
LogicalResult matchAndRewrite (PackOp packOp,
203
206
PatternRewriter &rewriter) const override {
@@ -206,6 +209,9 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
206
209
if (!padOp || padOp.getNofold () || !padOp.hasZeroLowPad ())
207
210
return failure ();
208
211
212
+ if (!controlFn (&packOp.getSourceMutable ()))
213
+ return failure ();
214
+
209
215
Value constantPaddingValue = padOp.getConstantPaddingValue ();
210
216
if (!constantPaddingValue)
211
217
return failure ();
@@ -220,20 +226,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
220
226
packOp.getOuterDimsPerm ());
221
227
return success ();
222
228
}
229
+
230
+ private:
231
+ ControlFoldIntoPackUnpackFn controlFn;
223
232
};
224
233
225
234
// / Fold a `unpack` -> `extract_slice` into the `unpack` since it already
226
235
// / has extract_slice semantics.
227
236
struct FoldUnpackWithExtractSliceOp
228
237
: public OpRewritePattern<tensor::ExtractSliceOp> {
229
- using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
238
+ public:
239
+ FoldUnpackWithExtractSliceOp (MLIRContext *context,
240
+ ControlFoldIntoPackUnpackFn controlFn)
241
+ : OpRewritePattern<tensor::ExtractSliceOp>(context),
242
+ controlFn (std::move(controlFn)) {}
230
243
231
244
LogicalResult matchAndRewrite (tensor::ExtractSliceOp sliceOp,
232
245
PatternRewriter &rewriter) const override {
233
246
auto unpackOp = sliceOp.getSource ().getDefiningOp <UnPackOp>();
234
247
if (!unpackOp)
235
248
return failure ();
236
249
250
+ // User controlled folding function.
251
+ if (!controlFn (&sliceOp.getSourceMutable ()))
252
+ return failure ();
253
+
237
254
if (sliceOp.getResultType ().getRank () != unpackOp.getDestType ().getRank ()) {
238
255
return rewriter.notifyMatchFailure (
239
256
sliceOp, " rank-reduced folding is not supported" );
@@ -255,6 +272,9 @@ struct FoldUnpackWithExtractSliceOp
255
272
unpackOp.getMixedTiles (), unpackOp.getOuterDimsPerm ());
256
273
return success ();
257
274
}
275
+
276
+ private:
277
+ ControlFoldIntoPackUnpackFn controlFn;
258
278
};
259
279
260
280
// Applies 'permutation' on 'inVec' and stores the result in resVec.
@@ -284,7 +304,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
284
304
// / semantics.
285
305
struct FoldProducerPackWithConsumerLinalgTransposeOp
286
306
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
287
- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
307
+
308
+ public:
309
+ FoldProducerPackWithConsumerLinalgTransposeOp (
310
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
311
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
312
+ controlFn (std::move(controlFn)) {}
288
313
289
314
LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp,
290
315
PatternRewriter &rewriter) const override {
@@ -293,6 +318,9 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
293
318
if (!packOp)
294
319
return failure ();
295
320
321
+ if (!controlFn (&linalgOp->getOpOperand (0 )))
322
+ return failure ();
323
+
296
324
FailureOr<SmallVector<int64_t >> maybePerm =
297
325
getTransposeOpPermutation (linalgOp);
298
326
if (failed (maybePerm))
@@ -331,20 +359,30 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
331
359
332
360
return success ();
333
361
}
362
+
363
+ private:
364
+ ControlFoldIntoPackUnpackFn controlFn;
334
365
};
335
366
336
367
// / Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
337
368
// / semantics.
338
369
struct FoldConsumerPackWithProducerLinalgTransposeOp
339
370
: public OpRewritePattern<PackOp> {
340
- using OpRewritePattern<PackOp>::OpRewritePattern;
371
+
372
+ public:
373
+ FoldConsumerPackWithProducerLinalgTransposeOp (
374
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
375
+ : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
341
376
342
377
LogicalResult matchAndRewrite (PackOp packOp,
343
378
PatternRewriter &rewriter) const override {
344
379
auto linalgOp = packOp.getSource ().getDefiningOp <linalg::LinalgOp>();
345
380
if (!linalgOp)
346
381
return failure ();
347
382
383
+ if (!controlFn (&packOp.getSourceMutable ()))
384
+ return failure ();
385
+
348
386
FailureOr<SmallVector<int64_t >> maybePerm =
349
387
getTransposeOpPermutation (linalgOp);
350
388
if (failed (maybePerm))
@@ -375,13 +413,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
375
413
376
414
return success ();
377
415
}
416
+
417
+ private:
418
+ ControlFoldIntoPackUnpackFn controlFn;
378
419
};
379
420
380
421
// / Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
381
422
// / transpose semantics.
382
423
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
383
424
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
384
- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
425
+
426
+ public:
427
+ FoldProducerUnPackWithConsumerLinalgTransposeOp (
428
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
429
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
430
+ controlFn (std::move(controlFn)) {}
385
431
386
432
LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp,
387
433
PatternRewriter &rewriter) const override {
@@ -390,6 +436,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
390
436
if (!unPackOp)
391
437
return failure ();
392
438
439
+ if (!controlFn (&linalgOp->getOpOperand (0 )))
440
+ return failure ();
441
+
393
442
FailureOr<SmallVector<int64_t >> maybePerm =
394
443
getTransposeOpPermutation (linalgOp);
395
444
if (failed (maybePerm))
@@ -416,6 +465,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
416
465
417
466
return success ();
418
467
}
468
+
469
+ private:
470
+ ControlFoldIntoPackUnpackFn controlFn;
419
471
};
420
472
421
473
// / Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
@@ -424,12 +476,20 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
424
476
: public OpRewritePattern<UnPackOp> {
425
477
using OpRewritePattern<UnPackOp>::OpRewritePattern;
426
478
479
+ public:
480
+ FoldConsumerUnPackWithProducerLinalgTransposeOp (
481
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
482
+ : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
483
+
427
484
LogicalResult matchAndRewrite (UnPackOp unPackOp,
428
485
PatternRewriter &rewriter) const override {
429
486
auto linalgOp = unPackOp.getSource ().getDefiningOp <linalg::LinalgOp>();
430
487
if (!linalgOp)
431
488
return failure ();
432
489
490
+ if (!controlFn (&unPackOp.getSourceMutable ()))
491
+ return failure ();
492
+
433
493
FailureOr<SmallVector<int64_t >> maybePerm =
434
494
getTransposeOpPermutation (linalgOp);
435
495
if (failed (maybePerm))
@@ -474,6 +534,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
474
534
475
535
return success ();
476
536
}
537
+
538
+ private:
539
+ ControlFoldIntoPackUnpackFn controlFn;
477
540
};
478
541
479
542
// / tensor.empty does not define any tensor contents, so an unpadded pack
@@ -521,13 +584,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
521
584
522
585
} // namespace
523
586
524
- void populateFoldIntoPackAndUnpackPatterns (RewritePatternSet &patterns) {
587
+ void populateFoldIntoPackAndUnpackPatterns (
588
+ RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
525
589
patterns.insert <FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
526
590
FoldProducerPackWithConsumerLinalgTransposeOp,
527
591
FoldConsumerPackWithProducerLinalgTransposeOp,
528
592
FoldConsumerUnPackWithProducerLinalgTransposeOp,
529
593
FoldProducerUnPackWithConsumerLinalgTransposeOp>(
530
- patterns.getContext ());
594
+ patterns.getContext (), controlFn );
531
595
}
532
596
533
597
void populateSimplifyPackAndUnpackPatterns (RewritePatternSet &patterns) {
0 commit comments