Skip to content

feat(linalg): add a way to pass controlFn to foldIntoPackUnpackPatterns #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1893,10 +1893,18 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

/// Function type which is used to control folding operations like `tensor.pad`
/// and `tensor.extract_slice` into linalg.pack/unpack ops.
using ControlFoldIntoPackUnpackFn = std::function<bool(OpOperand *opOperand)>;
inline bool defaultControlFoldIntoPackUnpackFn(OpOperand *opOperand) {
return true;
};
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
/// respectively.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
void populateFoldIntoPackAndUnpackPatterns(
RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn =
defaultControlFoldIntoPackUnpackFn);

/// Populates `patterns` with patterns that fold operations like `linalg.pack`
/// and `linalg.unpack` into `tensor.empty`.
Expand Down
83 changes: 76 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
Expand Down Expand Up @@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
public:
FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
: OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
Expand All @@ -206,6 +209,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
return failure();

// User controlled folding function.
if (!controlFn(&packOp.getSourceMutable()))
return failure();

Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue)
return failure();
Expand All @@ -220,20 +227,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
packOp.getOuterDimsPerm());
return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
struct FoldUnpackWithExtractSliceOp
: public OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
public:
FoldUnpackWithExtractSliceOp(MLIRContext *context,
ControlFoldIntoPackUnpackFn controlFn)
: OpRewritePattern<tensor::ExtractSliceOp>(context),
controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
if (!unpackOp)
return failure();

// User controlled folding function.
if (!controlFn(&sliceOp.getSourceMutable()))
return failure();

if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
return rewriter.notifyMatchFailure(
sliceOp, "rank-reduced folding is not supported");
Expand All @@ -255,6 +273,9 @@ struct FoldUnpackWithExtractSliceOp
unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

// Applies 'permutation' on 'inVec' and stores the result in resVec.
Expand Down Expand Up @@ -284,7 +305,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

public:
FoldProducerPackWithConsumerLinalgTransposeOp(
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
: OpInterfaceRewritePattern<linalg::LinalgOp>(context),
controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
Expand All @@ -293,6 +319,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
if (!packOp)
return failure();

// User controlled folding function.
if (!controlFn(&linalgOp->getOpOperand(0)))
return failure();

FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
Expand Down Expand Up @@ -331,20 +361,31 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp

return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldConsumerPackWithProducerLinalgTransposeOp
: public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;

public:
FoldConsumerPackWithProducerLinalgTransposeOp(
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
: OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
if (!linalgOp)
return failure();

// User controlled folding function.
if (!controlFn(&packOp.getSourceMutable()))
return failure();

FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
Expand Down Expand Up @@ -375,13 +416,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp

return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

public:
FoldProducerUnPackWithConsumerLinalgTransposeOp(
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
: OpInterfaceRewritePattern<linalg::LinalgOp>(context),
controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
Expand All @@ -390,6 +439,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
if (!unPackOp)
return failure();

// User controlled folding function.
if (!controlFn(&linalgOp->getOpOperand(0)))
return failure();

FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
Expand All @@ -416,6 +469,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp

return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
Expand All @@ -424,12 +480,21 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
: public OpRewritePattern<UnPackOp> {
using OpRewritePattern<UnPackOp>::OpRewritePattern;

public:
FoldConsumerUnPackWithProducerLinalgTransposeOp(
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
: OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
if (!linalgOp)
return failure();

// User controlled folding function.
if (!controlFn(&unPackOp.getSourceMutable()))
return failure();

FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
Expand Down Expand Up @@ -474,6 +539,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp

return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

/// tensor.empty does not define any tensor contents, so an unpadded pack
Expand Down Expand Up @@ -521,13 +589,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {

} // namespace

void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
void populateFoldIntoPackAndUnpackPatterns(
RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
FoldProducerPackWithConsumerLinalgTransposeOp,
FoldConsumerPackWithProducerLinalgTransposeOp,
FoldConsumerUnPackWithProducerLinalgTransposeOp,
FoldProducerUnPackWithConsumerLinalgTransposeOp>(
patterns.getContext());
patterns.getContext(), controlFn);
}

void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
Expand Down