Skip to content

Commit cf69d3e

Browse files
committed
feat(linalg): add a way to pass controlFn to
`foldIntoPackUnpackPatterns`
1 parent 8dde3f4 commit cf69d3e

File tree

2 files changed

+80
-8
lines changed

2 files changed

+80
-8
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1893,10 +1893,18 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
18931893
/// convert to a `linalg.dot`.
18941894
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
18951895

1896+
/// Function type which is used to control folding operations like `tensor.pad`
1897+
/// and `tensor.extract_slice` into to of linalg.pack/unpack ops.
1898+
using ControlFoldIntoPackUnpackFn = std::function<bool(OpOperand *opOperand)>;
1899+
inline bool defaultControlFoldIntoPackUnpackFn(OpOperand *opOperand) {
1900+
return true;
1901+
};
18961902
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
18971903
/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
18981904
/// respectively.
1899-
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
1905+
void populateFoldIntoPackAndUnpackPatterns(
1906+
RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn =
1907+
defaultControlFoldIntoPackUnpackFn);
19001908

19011909
/// Populates `patterns` with patterns that fold operations like `linalg.pack`
19021910
/// and `linalg.unpack` into `tensor.empty`.

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Linalg/IR/Linalg.h"
10+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1011
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1112
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1213
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
197198
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
198199
/// the pad op has zero low paddings, or if `pack` has no padding values.
199200
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
200-
using OpRewritePattern<PackOp>::OpRewritePattern;
201+
public:
202+
FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
203+
: OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
201204

202205
LogicalResult matchAndRewrite(PackOp packOp,
203206
PatternRewriter &rewriter) const override {
@@ -206,6 +209,9 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
206209
if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
207210
return failure();
208211

212+
if (!controlFn(&packOp.getSourceMutable()))
213+
return failure();
214+
209215
Value constantPaddingValue = padOp.getConstantPaddingValue();
210216
if (!constantPaddingValue)
211217
return failure();
@@ -220,20 +226,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
220226
packOp.getOuterDimsPerm());
221227
return success();
222228
}
229+
230+
private:
231+
ControlFoldIntoPackUnpackFn controlFn;
223232
};
224233

225234
/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
226235
/// has extract_slice semantics.
227236
struct FoldUnpackWithExtractSliceOp
228237
: public OpRewritePattern<tensor::ExtractSliceOp> {
229-
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
238+
public:
239+
FoldUnpackWithExtractSliceOp(MLIRContext *context,
240+
ControlFoldIntoPackUnpackFn controlFn)
241+
: OpRewritePattern<tensor::ExtractSliceOp>(context),
242+
controlFn(std::move(controlFn)) {}
230243

231244
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
232245
PatternRewriter &rewriter) const override {
233246
auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
234247
if (!unpackOp)
235248
return failure();
236249

250+
// User controlled folding function.
251+
if (!controlFn(&sliceOp.getSourceMutable()))
252+
return failure();
253+
237254
if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
238255
return rewriter.notifyMatchFailure(
239256
sliceOp, "rank-reduced folding is not supported");
@@ -255,6 +272,9 @@ struct FoldUnpackWithExtractSliceOp
255272
unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
256273
return success();
257274
}
275+
276+
private:
277+
ControlFoldIntoPackUnpackFn controlFn;
258278
};
259279

260280
// Applies 'permutation' on 'inVec' and stores the result in resVec.
@@ -284,7 +304,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
284304
/// semantics.
285305
struct FoldProducerPackWithConsumerLinalgTransposeOp
286306
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
287-
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
307+
308+
public:
309+
FoldProducerPackWithConsumerLinalgTransposeOp(
310+
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
311+
: OpInterfaceRewritePattern<linalg::LinalgOp>(context),
312+
controlFn(std::move(controlFn)) {}
288313

289314
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
290315
PatternRewriter &rewriter) const override {
@@ -293,6 +318,9 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
293318
if (!packOp)
294319
return failure();
295320

321+
if (!controlFn(&linalgOp->getOpOperand(0)))
322+
return failure();
323+
296324
FailureOr<SmallVector<int64_t>> maybePerm =
297325
getTransposeOpPermutation(linalgOp);
298326
if (failed(maybePerm))
@@ -331,20 +359,30 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
331359

332360
return success();
333361
}
362+
363+
private:
364+
ControlFoldIntoPackUnpackFn controlFn;
334365
};
335366

336367
/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
337368
/// semantics.
338369
struct FoldConsumerPackWithProducerLinalgTransposeOp
339370
: public OpRewritePattern<PackOp> {
340-
using OpRewritePattern<PackOp>::OpRewritePattern;
371+
372+
public:
373+
FoldConsumerPackWithProducerLinalgTransposeOp(
374+
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
375+
: OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
341376

342377
LogicalResult matchAndRewrite(PackOp packOp,
343378
PatternRewriter &rewriter) const override {
344379
auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
345380
if (!linalgOp)
346381
return failure();
347382

383+
if (!controlFn(&packOp.getSourceMutable()))
384+
return failure();
385+
348386
FailureOr<SmallVector<int64_t>> maybePerm =
349387
getTransposeOpPermutation(linalgOp);
350388
if (failed(maybePerm))
@@ -375,13 +413,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
375413

376414
return success();
377415
}
416+
417+
private:
418+
ControlFoldIntoPackUnpackFn controlFn;
378419
};
379420

380421
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
381422
/// transpose semantics.
382423
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
383424
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
384-
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
425+
426+
public:
427+
FoldProducerUnPackWithConsumerLinalgTransposeOp(
428+
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
429+
: OpInterfaceRewritePattern<linalg::LinalgOp>(context),
430+
controlFn(std::move(controlFn)) {}
385431

386432
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
387433
PatternRewriter &rewriter) const override {
@@ -390,6 +436,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
390436
if (!unPackOp)
391437
return failure();
392438

439+
if (!controlFn(&linalgOp->getOpOperand(0)))
440+
return failure();
441+
393442
FailureOr<SmallVector<int64_t>> maybePerm =
394443
getTransposeOpPermutation(linalgOp);
395444
if (failed(maybePerm))
@@ -416,6 +465,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
416465

417466
return success();
418467
}
468+
469+
private:
470+
ControlFoldIntoPackUnpackFn controlFn;
419471
};
420472

421473
/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
@@ -424,12 +476,20 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
424476
: public OpRewritePattern<UnPackOp> {
425477
using OpRewritePattern<UnPackOp>::OpRewritePattern;
426478

479+
public:
480+
FoldConsumerUnPackWithProducerLinalgTransposeOp(
481+
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
482+
: OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
483+
427484
LogicalResult matchAndRewrite(UnPackOp unPackOp,
428485
PatternRewriter &rewriter) const override {
429486
auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
430487
if (!linalgOp)
431488
return failure();
432489

490+
if (!controlFn(&unPackOp.getSourceMutable()))
491+
return failure();
492+
433493
FailureOr<SmallVector<int64_t>> maybePerm =
434494
getTransposeOpPermutation(linalgOp);
435495
if (failed(maybePerm))
@@ -474,6 +534,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
474534

475535
return success();
476536
}
537+
538+
private:
539+
ControlFoldIntoPackUnpackFn controlFn;
477540
};
478541

479542
/// tensor.empty does not define any tensor contents, so an unpadded pack
@@ -521,13 +584,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
521584

522585
} // namespace
523586

524-
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
587+
void populateFoldIntoPackAndUnpackPatterns(
588+
RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
525589
patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
526590
FoldProducerPackWithConsumerLinalgTransposeOp,
527591
FoldConsumerPackWithProducerLinalgTransposeOp,
528592
FoldConsumerUnPackWithProducerLinalgTransposeOp,
529593
FoldProducerUnPackWithConsumerLinalgTransposeOp>(
530-
patterns.getContext());
594+
patterns.getContext(), controlFn);
531595
}
532596

533597
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {

0 commit comments

Comments
 (0)