Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Canonicalizers for doubly strided ops such as npu.dma_cpy_nd #680

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 105 additions & 9 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
#include "iree-amd-aie/IR/AMDAIEOps.h"

#include "iree-amd-aie/IR/AMDAIEDialect.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpDefinition.h"

#define GET_OP_CLASSES
Expand Down Expand Up @@ -52,7 +50,6 @@ LogicalResult ControlCodeOp::verify() {
// AMDAIE_CoreOp
//===----------------------------------------------------------------------===//


void CoreOp::build(OpBuilder &b, OperationState &result, AMDAIE::TileOp tileOp,
ValueRange inputDmas, ValueRange outputDmas) {
build(b, result, b.getIndexType(), tileOp, inputDmas, outputDmas, nullptr);
Expand Down Expand Up @@ -93,6 +90,70 @@ TileOp CoreOp::getTileOp() {
// AMDAIE_DmaCpyNdBaseOp
//===----------------------------------------------------------------------===//

namespace {
// Simplified from upstream MLIR's foldDynamicIndexList:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Simplified from upstream MLIR's foldDynamicIndexList:
/// Simplified from upstream MLIR's foldDynamicIndexList:

Nit, but IREE uses /// for function/class comments.

/// Rewrites `ofrs` in place: every entry that currently holds a `Value`
/// matching `m_Constant` is replaced by the matched constant attribute.
/// Entries that already hold an attribute are left untouched.
///
/// Returns success if at least one entry was replaced, failure otherwise.
LogicalResult foldMixed(SmallVectorImpl<OpFoldResult> &ofrs) {
  bool foldedAny = false;
  for (OpFoldResult &entry : ofrs) {
    // Only entries that still hold a Value are candidates for folding.
    if (!entry.is<Attribute>()) {
      Attribute constant;
      const bool isConstant =
          matchPattern(entry.get<Value>(), m_Constant(&constant));
      if (isConstant) {
        entry = constant;
        foldedAny = true;
      }
    }
  }
  return foldedAny ? success() : failure();
}

template <typename OpType, typename ReplacementBuilder>
// Based on upstream MLIR's
// OpWithOffsetSizesAndStridesConstantArgumentFolder
Comment on lines +109 to +110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Based on upstream MLIR's
// OpWithOffsetSizesAndStridesConstantArgumentFolder
/// Based on upstream MLIR's
/// OpWithOffsetSizesAndStridesConstantArgumentFolder

Nit, but IREE uses /// for function/class comments.

/// Rewrite pattern that folds constant offset/size/stride operands of an op
/// with a target and a source access pattern into attribute entries.
///
/// `OpType` must provide `get{Target,Source}Mixed{Offsets,Sizes,Strides}()`.
/// `ReplacementBuilder` must provide a static `replace(op, rewriter, ...)`
/// taking the six folded lists (target offsets/sizes/strides, then source
/// offsets/sizes/strides) and replacing `op` accordingly.
class DoublyStridedFolder final : public OpRewritePattern<OpType> {
 public:
  using OpRewritePattern<OpType>::OpRewritePattern;

  /// Folds constants in all six mixed lists of `op`. Returns failure, leaving
  /// `op` untouched, when no list contains a foldable entry; otherwise hands
  /// the folded lists to `ReplacementBuilder::replace` and returns success.
  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> tgtMixedOffsets(op.getTargetMixedOffsets());
    SmallVector<OpFoldResult> tgtMixedSizes(op.getTargetMixedSizes());
    SmallVector<OpFoldResult> tgtMixedStrides(op.getTargetMixedStrides());
    SmallVector<OpFoldResult> srcMixedOffsets(op.getSourceMixedOffsets());
    SmallVector<OpFoldResult> srcMixedSizes(op.getSourceMixedSizes());
    SmallVector<OpFoldResult> srcMixedStrides(op.getSourceMixedStrides());

    // Fold every list unconditionally. Chaining the calls with `&&` would
    // short-circuit as soon as one list folds, leaving the remaining lists
    // unfolded and requiring additional rewrites (and replacement ops) to
    // finish the job.
    bool foldedAny = false;
    foldedAny |= succeeded(foldMixed(tgtMixedOffsets));
    foldedAny |= succeeded(foldMixed(tgtMixedSizes));
    foldedAny |= succeeded(foldMixed(tgtMixedStrides));
    foldedAny |= succeeded(foldMixed(srcMixedOffsets));
    foldedAny |= succeeded(foldMixed(srcMixedSizes));
    foldedAny |= succeeded(foldMixed(srcMixedStrides));

    // No constant operands were folded, so there is nothing to rewrite.
    if (!foldedAny) return failure();

    ReplacementBuilder::replace(op, rewriter, tgtMixedOffsets, tgtMixedSizes,
                                tgtMixedStrides, srcMixedOffsets, srcMixedSizes,
                                srcMixedStrides);

    return success();
  }
};

/// Replacement builder for ops of type `T` that are rebuilt from a target, a
/// source, and the six mixed offset/size/stride lists.
template <typename T>
struct DmaCpyNdBaseOpReplacementBuilder {
  /// Replaces `dmaOp` with a new `T` that keeps the original target and source
  /// but uses the provided mixed offsets, sizes and strides.
  static void replace(T dmaOp, PatternRewriter &rewriter,
                      ArrayRef<OpFoldResult> tgtMixedOffsets,
                      ArrayRef<OpFoldResult> tgtMixedSizes,
                      ArrayRef<OpFoldResult> tgtMixedStrides,
                      ArrayRef<OpFoldResult> srcMixedOffsets,
                      ArrayRef<OpFoldResult> srcMixedSizes,
                      ArrayRef<OpFoldResult> srcMixedStrides) {
    // The endpoints are carried over unchanged from the op being replaced.
    auto target = dmaOp.getTarget();
    auto source = dmaOp.getSource();
    rewriter.replaceOpWithNewOp<T>(dmaOp,
                                   target, tgtMixedOffsets, tgtMixedSizes,
                                   tgtMixedStrides,
                                   source, srcMixedOffsets, srcMixedSizes,
                                   srcMixedStrides);
  }
};
} // namespace

// Build a DmaCpyNdOp with mixed static and dynamic entries.
void DmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value target,
ArrayRef<OpFoldResult> targetOffsets,
Expand Down Expand Up @@ -220,6 +281,12 @@ LogicalObjectFifoFromMemrefOp DmaCpyNdOp::getTargetObjectFifo() {
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getTarget().getDefiningOp());
};

/// Registers the constant-folding pattern for the offsets, sizes and strides
/// of `DmaCpyNdOp`.
void DmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  using Replacer = DmaCpyNdBaseOpReplacementBuilder<DmaCpyNdOp>;
  using Folder = DoublyStridedFolder<DmaCpyNdOp, Replacer>;
  results.add<Folder>(context);
}

// Build a CircularDmaCpyNdOp with mixed static and dynamic entries.
void CircularDmaCpyNdOp::build(
OpBuilder &b, OperationState &result, Value target,
Expand Down Expand Up @@ -347,6 +414,13 @@ LogicalObjectFifoFromMemrefOp CircularDmaCpyNdOp::getTargetObjectFifo() {
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getTarget().getDefiningOp());
};

/// Registers the constant-folding pattern for the offsets, sizes and strides
/// of `CircularDmaCpyNdOp`.
void CircularDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  using Replacer = DmaCpyNdBaseOpReplacementBuilder<CircularDmaCpyNdOp>;
  using Folder = DoublyStridedFolder<CircularDmaCpyNdOp, Replacer>;
  results.add<Folder>(context);
}

//===----------------------------------------------------------------------===//
// AMDAIE_LogicalObjectFifoAccessOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -386,8 +460,7 @@ void LogicalObjectFifoFromMemrefOp::build(
for (auto [column, row] : tileLocations) {
auto getCol = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), column);
auto getRow = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), row);
auto tileOp =
b.create<AMDAIE::TileOp>(b.getUnknownLoc(), getCol, getRow);
auto tileOp = b.create<AMDAIE::TileOp>(b.getUnknownLoc(), getCol, getRow);
tiles.push_back(tileOp.getResult());
}
// For deterministic order.
Expand Down Expand Up @@ -449,8 +522,8 @@ void LogicalObjectFifoRelease::build(OpBuilder &b, mlir::OperationState &result,
// AMDAIE_NpuDmaCpyNdOp
//===----------------------------------------------------------------------===//

// Build a NpuDmaCpyNdOp with mixed static and dynamic entries and target and
// source BD IDs.
// Build a NpuDmaCpyNdOp with mixed static and dynamic entries and target
// and source BD IDs.
Comment on lines +525 to +526
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Build a NpuDmaCpyNdOp with mixed static and dynamic entries and target
// and source BD IDs.
/// Build a NpuDmaCpyNdOp with mixed static and dynamic entries and target
/// and source BD IDs.

void NpuDmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value dma,
ArrayRef<OpFoldResult> targetOffsets,
ArrayRef<OpFoldResult> targetSizes,
Expand Down Expand Up @@ -577,6 +650,30 @@ bool NpuDmaCpyNdOp::hasDmaWaitOpUser() {
[](auto userOp) { return isa<NpuDmaWaitOp>(userOp); });
}

namespace {
/// Replacement builder for `NpuDmaCpyNdOp`, which is rebuilt from its DMA
/// operand, the six mixed offset/size/stride lists, and its target and source
/// BD IDs.
struct NpuDmaCpyNdOpReplacementBuilder {
  /// Replaces `dmaOp` with a new `NpuDmaCpyNdOp` that keeps the original DMA
  /// operand and BD IDs but uses the provided mixed offsets, sizes and
  /// strides.
  static void replace(NpuDmaCpyNdOp dmaOp, PatternRewriter &rewriter,
                      ArrayRef<OpFoldResult> tgtMixedOffsets,
                      ArrayRef<OpFoldResult> tgtMixedSizes,
                      ArrayRef<OpFoldResult> tgtMixedStrides,
                      ArrayRef<OpFoldResult> srcMixedOffsets,
                      ArrayRef<OpFoldResult> srcMixedSizes,
                      ArrayRef<OpFoldResult> srcMixedStrides) {
    // Everything except the access patterns is carried over unchanged.
    auto dma = dmaOp.getDma();
    auto targetBdId = dmaOp.getTargetBdId();
    auto sourceBdId = dmaOp.getSourceBdId();
    rewriter.replaceOpWithNewOp<NpuDmaCpyNdOp>(
        dmaOp, dma,
        tgtMixedOffsets, tgtMixedSizes, tgtMixedStrides,
        srcMixedOffsets, srcMixedSizes, srcMixedStrides,
        targetBdId, sourceBdId);
  }
};
}  // namespace

/// Registers the constant-folding pattern for the offsets, sizes and strides
/// of `NpuDmaCpyNdOp`.
void NpuDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  using Folder =
      DoublyStridedFolder<NpuDmaCpyNdOp, NpuDmaCpyNdOpReplacementBuilder>;
  results.add<Folder>(context);
}

//===----------------------------------------------------------------------===//
// AMDAIE_TileOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -645,5 +742,4 @@ LogicalResult WorkgroupOp::verify() {
}
return success();
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,11 @@ def AMDAIE_NpuDmaCpyNdOp: AMDAIE_Op<"npu.dma_cpy_nd",
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceOffsets,
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceSizes,
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceStrides);

}];

let hasCanonicalizer = 1;

}

def AMDAIE_NpuDmaWaitOp: AMDAIE_Op<"npu.dma_wait", []> {
Expand Down Expand Up @@ -784,6 +788,7 @@ class AMDAIE_DmaCpyNdBaseOp<string mnemonic, list<Trait> traits = []> :
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceSizes,
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceStrides);
}];

}

def AMDAIE_DmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"dma_cpy_nd", []> {
Expand All @@ -809,6 +814,7 @@ def AMDAIE_DmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"dma_cpy_nd", []> {
}];

let hasVerifier = 0;
let hasCanonicalizer = 1;
}

def AMDAIE_CircularDmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"circular_dma_cpy_nd", [Pure]> {
Expand All @@ -833,6 +839,7 @@ def AMDAIE_CircularDmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"circular_dma_cpy_nd", [Pur
}];

let hasVerifier = 0;
let hasCanonicalizer = 1;
}

def AMDAIE_ReferenceToOp: AMDAIE_Op<"reference_to", [SameOperandsAndResultType]> {
Expand Down
Loading
Loading