[mlir][tosa] Fold PadOp to tensor operations #132700

Merged · 1 commit · Apr 8, 2025
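In short: the new canonicalizations merge the explicit padding of a producing tosa.pad into the implicit pad attribute of a consuming avg_pool2d, max_pool2d, conv2d, or depthwise_conv2d. A schematic before/after (illustrative names and shapes, assembly abbreviated; not taken from the patch):

```mlir
// Before: an explicit tosa.pad feeds a conv2d that carries no padding.
%padded = tosa.pad %input, %pad_shape, %pad_const : (...) -> tensor<1x10x10x3xf32>
%out = tosa.conv2d %padded, %weight, %bias, %izp, %wzp {pad = array<i64: 0, 0, 0, 0>, ...}

// After: the padding moves into conv2d's pad attribute, and the tosa.pad,
// now unused, is erased as dead code.
%out = tosa.conv2d %input, %weight, %bias, %izp, %wzp {pad = array<i64: 1, 1, 1, 1>, ...}
```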
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -107,6 +107,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
LogicalResult verifyOutputZeroPoint(int64_t zp);
}];

let hasCanonicalizer = 1;
let hasVerifier = 1;
}

@@ -153,6 +154,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
}];

let builders = [Tosa_ConvOpQuantInfoBuilder];

let hasCanonicalizer = 1;
let hasVerifier = 1;
}

@@ -244,6 +247,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
}];

let builders = [Tosa_ConvOpQuantInfoBuilder];

let hasCanonicalizer = 1;
let hasVerifier = 1;
}

309 changes: 274 additions & 35 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -39,6 +39,280 @@ using namespace mlir::tosa;
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Tensor Data Engine Operators.
//===----------------------------------------------------------------------===//

// Check that the pad constant of the tosa.pad matches the zero point of the
// consuming tensor operation.
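// For example (hypothetical values): folding into an op whose input zero
// point is -128 requires an integer pad constant of -128, while a
// floating-point pad constant must be exactly 0.0f.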
bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
// Check that padConst is a constant value and a scalar tensor
DenseElementsAttr padConstAttr;
if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
(padConstAttr.size() != 1)) {
return false;
}

// Check that floating point pad is zero
if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
return padConstVal == 0.0f;
}

// Check that the zp and padConst align for the integer (quantized) case
if (auto padConstIntAttr =
mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
DenseIntElementsAttr zpAttr;
// Check that zp is a constant value and a scalar tensor
if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
return false;
}

// Check equality
int64_t zpVal = (*zpAttr.begin()).getSExtValue();
int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
return zpVal == padConstVal;
}

// Bail out on unsupported types
return false;
}

namespace {
template <typename OpTy>
struct PoolPadFoldAdaptor;

template <>
struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
using OpTy = tosa::AvgPool2dOp;
static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
const llvm::ArrayRef<int64_t> kernel = op.getKernel();
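// TOSA pooling requires each pad entry to be strictly smaller than the
// corresponding kernel dimension; newPad is {top, bottom, left, right} and
// kernel is {height, width}.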
if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
return false;
return true;
}
static bool checkPadConstCompliance(OpTy op, Value padConst) {
return checkMatchingPadConstAndZp(padConst, op.getInputZp());
}
static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
Value padInput, ArrayRef<int64_t> newPad) {
rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
op.getAccType());
}
};

template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
using OpTy = tosa::MaxPool2dOp;
static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
const llvm::ArrayRef<int64_t> kernel = op.getKernel();
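// Same restriction as for avg_pool2d: the folded padding must stay
// strictly below the kernel dimensions.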
if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
return false;
return true;
}
static bool checkPadConstCompliance(OpTy, Value padConst) {
// Check that padConst is a constant value and a scalar tensor
DenseElementsAttr padConstAttr;
if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
padConstAttr.size() != 1) {
return false;
}

// The pad constant must be the type's lowest value so that merging the
// padding cannot change the max-pool result.
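// For example: -3.40282347E+38 for f32, -128 for signed i8, and 0 for an
// unsigned integer type.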
if (auto padConstFpAttr =
mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
const APFloat padConstVal = *padConstFpAttr.begin();
const APFloat lowestVal =
APFloat::getLargest(padConstVal.getSemantics(), true);
return padConstVal == lowestVal;
} else if (auto padConstIntAttr =
mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
const APInt padConstVal = *padConstIntAttr.begin();
const unsigned int bitWidth = padConstVal.getBitWidth();
const APInt lowestVal =
padConstIntAttr.getElementType().isUnsignedInteger()
? APInt::getZero(bitWidth)
: APInt::getSignedMinValue(bitWidth);
return padConstVal == lowestVal;
}

// Bail out on unsupported types
return false;
}
static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
Value padInput, ArrayRef<int64_t> newPad) {
rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
op, op.getType(), padInput, op.getKernel(), op.getStride(),
rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
}
};

template <typename OpTy>
struct ConvPadFoldAdaptor {
static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
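// Convolutions accept arbitrary padding; only the 8K level limit checked
// by the pattern itself applies.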
return true;
}
static bool checkPadConstCompliance(OpTy op, Value padConst) {
return checkMatchingPadConstAndZp(padConst, op.getInputZp());
}
static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
Value padInput, ArrayRef<int64_t> newPad) {
rewriter.replaceOpWithNewOp<OpTy>(
op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
op.getDilationAttr(), op.getAccType(), op.getLocalBound());
}
};

// This pattern attempts to fold a `tosa.pad` operator into a following tensor
// operation such as `tosa.conv2d` by merging the explicit padding of the pad
// operator into the implicit padding of the tensor operation. The explicit
// `tosa.pad` operator can then be removed as dead code once it has no other
// uses.
template <typename OpTy, typename AdaptorTy>
struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(OpTy tensorOp,
PatternRewriter &rewriter) const override {
// Check producer is a tosa::PadOp
auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
if (!padOp)
return rewriter.notifyMatchFailure(tensorOp,
"Producer must be a tosa::PadOp.");

// Validate that the tensor operation has sane padding
const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
return rewriter.notifyMatchFailure(
tensorOp, "Tensor operation padding shall have 4 elements.");

// Validate tosa::PadOp padding
DenseIntElementsAttr padOpPadding;
if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
return rewriter.notifyMatchFailure(
tensorOp,
"The `padding` input specified on the tosa::PadOp must be constant.");
}
// N_before, N_after, H_before, H_after, W_before, W_after, C_before,
// C_after
if (padOpPadding.size() != 8)
return rewriter.notifyMatchFailure(tensorOp,
"Pad padding should have 8 elements.");
int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();

if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
return rewriter.notifyMatchFailure(
tensorOp, "Folding padding in N or C dimensions is not supported.");

// Fold padding from Pad into the tensor operation
// 4 elements - pad_top, pad_bottom, pad_left, pad_right
SmallVector<int64_t> foldedPad(tensorOpPad.size());
foldedPad[0] = padHBefore + tensorOpPad[0];
foldedPad[1] = padHAfter + tensorOpPad[1];
foldedPad[2] = padWBefore + tensorOpPad[2];
foldedPad[3] = padWAfter + tensorOpPad[3];
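// For example (hypothetical values): a tosa.pad adding {1, 1} on H and
// {2, 2} on W, feeding an operation with pad {0, 0, 1, 1}, folds to
// {1, 1, 3, 3}.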

// Check kernel related restrictions
if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
return rewriter.notifyMatchFailure(
tensorOp, "Padding size not aligned with kernel restrictions.");
}

// Check padding constant restrictions
if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
return rewriter.notifyMatchFailure(
tensorOp,
"Padding constant is not aligned with operator zero-point.");
}

// For now, reject folded padding that exceeds the TOSA 8K level limit (8192).
if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
return rewriter.notifyMatchFailure(
tensorOp, "Padding size more than the 8K level limit.");
}

// Create operator
AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
foldedPad);

return success();
}
};
} // namespace

void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
context);
}

void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
context);
}

void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
context);
}

struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput();
Value output = op.getOutput();
ShapedType inputType = llvm::cast<ShapedType>(input.getType());
ShapedType outputType = llvm::cast<ShapedType>(output.getType());

if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
return failure();
}

// If both the input and output spatial dimensions (H and W) are 1x1, the
// pooling is a no-op.
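// For example, a tosa.max_pool2d from tensor<1x1x1x8xf32> to
// tensor<1x1x1x8xf32> simply forwards its input.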
ArrayRef<int64_t> outputShape = outputType.getShape();
if (outputShape[1] != 1 || outputShape[2] != 1) {
return failure();
}

ArrayRef<int64_t> inputShape = inputType.getShape();
if (inputShape[1] != 1 || inputShape[2] != 1) {
return failure();
}

rewriter.replaceOp(op, input);
return success();
}
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<MaxPool2dIsNoOp,
FoldPadToTensorOp<tosa::MaxPool2dOp,
PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
context);
}

//===----------------------------------------------------------------------===//
// Data Layout / Memory Reinterpretation.
//===----------------------------------------------------------------------===//

struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

@@ -175,41 +449,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
using OpRewritePattern::OpRewritePattern;

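These folds fire under the standard canonicalizer (mlir-opt --canonicalize). A minimal regression-test sketch in the style of the existing TOSA canonicalize tests (hypothetical function name; assembly approximate for the current dialect version):

```mlir
// RUN: mlir-opt --canonicalize %s | FileCheck %s

// CHECK-LABEL: @fold_pad_into_max_pool
// CHECK-NOT: tosa.pad
// CHECK: tosa.max_pool2d
// CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
func.func @fold_pad_into_max_pool(%input: tensor<1x8x8x3xf32>) -> tensor<1x9x9x3xf32> {
  %pad_shape = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
  // The pad constant is the lowest f32 so the fold preserves max semantics.
  %pad_const = "tosa.const"() {values = dense<-3.40282347E+38> : tensor<1xf32>} : () -> tensor<1xf32>
  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x10x10x3xf32>
  %pool = tosa.max_pool2d %padded {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x10x10x3xf32>) -> tensor<1x9x9x3xf32>
  return %pool : tensor<1x9x9x3xf32>
}
```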