Skip to content

Commit

Permalink
namespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Aug 16, 2024
1 parent 5f1696c commit 8e99191
Showing 1 changed file with 101 additions and 101 deletions.
202 changes: 101 additions & 101 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,70 @@ TileOp CoreOp::getTileOp() {
// AMDAIE_DmaCpyNdBaseOp
//===----------------------------------------------------------------------===//

namespace {
// Simplified from upstream MLIR's foldDynamicIndexList:
LogicalResult foldMixed(SmallVectorImpl<OpFoldResult> &ofrs) {
bool valuesChanged = false;
for (OpFoldResult &ofr : ofrs) {
if (ofr.is<Attribute>()) continue;
Attribute attr;
if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
ofr = attr;
valuesChanged = true;
}
}
return success(valuesChanged);
}

template <typename OpType, typename ReplacementBuilder>
// Based on upstream MLIR's
// OpWithOffsetSizesAndStridesConstantArgumentFolder
class DoublyStridedFolder final : public OpRewritePattern<OpType> {
public:
using OpRewritePattern<OpType>::OpRewritePattern;

LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override {
SmallVector<OpFoldResult> tgtMixedOffsets(op.getTargetMixedOffsets());
SmallVector<OpFoldResult> tgtMixedSizes(op.getTargetMixedSizes());
SmallVector<OpFoldResult> tgtMixedStrides(op.getTargetMixedStrides());
SmallVector<OpFoldResult> srcMixedOffsets(op.getSourceMixedOffsets());
SmallVector<OpFoldResult> srcMixedSizes(op.getSourceMixedSizes());
SmallVector<OpFoldResult> srcMixedStrides(op.getSourceMixedStrides());

// No constant operands were folded, just return;
if (failed(foldMixed(tgtMixedOffsets)) &&
failed(foldMixed(tgtMixedSizes)) &&
failed(foldMixed(tgtMixedStrides)) &&
failed(foldMixed(srcMixedOffsets)) &&
failed(foldMixed(srcMixedSizes)) && failed(foldMixed(srcMixedStrides)))
return failure();

ReplacementBuilder::replace(op, rewriter, tgtMixedOffsets, tgtMixedSizes,
tgtMixedStrides, srcMixedOffsets, srcMixedSizes,
srcMixedStrides);

return success();
}
};

template <typename T>
struct DmaCpyNdBaseOpReplacementBuilder {
static void replace(T dmaOp, PatternRewriter &rewriter,
ArrayRef<OpFoldResult> tgtMixedOffsets,
ArrayRef<OpFoldResult> tgtMixedSizes,
ArrayRef<OpFoldResult> tgtMixedStrides,
ArrayRef<OpFoldResult> srcMixedOffsets,
ArrayRef<OpFoldResult> srcMixedSizes,
ArrayRef<OpFoldResult> srcMixedStrides) {
rewriter.replaceOpWithNewOp<T>(dmaOp, dmaOp.getTarget(), tgtMixedOffsets,
tgtMixedSizes, tgtMixedStrides,
dmaOp.getSource(), srcMixedOffsets,
srcMixedSizes, srcMixedStrides);
}
};
} // namespace

// Build a DmaCpyNdOp with mixed static and dynamic entries.
void DmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value target,
ArrayRef<OpFoldResult> targetOffsets,
Expand Down Expand Up @@ -217,6 +281,12 @@ LogicalObjectFifoFromMemrefOp DmaCpyNdOp::getTargetObjectFifo() {
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getTarget().getDefiningOp());
};

void DmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DoublyStridedFolder<
DmaCpyNdOp, DmaCpyNdBaseOpReplacementBuilder<DmaCpyNdOp>>>(context);
}

// Build a CircularDmaCpyNdOp with mixed static and dynamic entries.
void CircularDmaCpyNdOp::build(
OpBuilder &b, OperationState &result, Value target,
Expand Down Expand Up @@ -344,6 +414,13 @@ LogicalObjectFifoFromMemrefOp CircularDmaCpyNdOp::getTargetObjectFifo() {
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getTarget().getDefiningOp());
};

void CircularDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DoublyStridedFolder<
CircularDmaCpyNdOp,
DmaCpyNdBaseOpReplacementBuilder<CircularDmaCpyNdOp>>>(context);
}

//===----------------------------------------------------------------------===//
// AMDAIE_LogicalObjectFifoAccessOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -393,107 +470,6 @@ void LogicalObjectFifoFromMemrefOp::build(
build(b, result, type, memref, tiles);
}

namespace {
// Simplified from upstream MLIR's foldDynamicIndexList:
LogicalResult foldMixed(SmallVectorImpl<OpFoldResult> &ofrs) {
bool valuesChanged = false;
for (OpFoldResult &ofr : ofrs) {
if (ofr.is<Attribute>()) continue;
Attribute attr;
if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
ofr = attr;
valuesChanged = true;
}
}
return success(valuesChanged);
}
} // namespace

template <typename OpType, typename ReplacementBuilder>
// Based on upstream MLIR's
// OpWithOffsetSizesAndStridesConstantArgumentFolder
class DoublyStridedFolder final : public OpRewritePattern<OpType> {
public:
using OpRewritePattern<OpType>::OpRewritePattern;

LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override {
SmallVector<OpFoldResult> tgtMixedOffsets(op.getTargetMixedOffsets());
SmallVector<OpFoldResult> tgtMixedSizes(op.getTargetMixedSizes());
SmallVector<OpFoldResult> tgtMixedStrides(op.getTargetMixedStrides());
SmallVector<OpFoldResult> srcMixedOffsets(op.getSourceMixedOffsets());
SmallVector<OpFoldResult> srcMixedSizes(op.getSourceMixedSizes());
SmallVector<OpFoldResult> srcMixedStrides(op.getSourceMixedStrides());

// No constant operands were folded, just return;
if (failed(foldMixed(tgtMixedOffsets)) &&
failed(foldMixed(tgtMixedSizes)) &&
failed(foldMixed(tgtMixedStrides)) &&
failed(foldMixed(srcMixedOffsets)) &&
failed(foldMixed(srcMixedSizes)) &&
failed(foldMixed(srcMixedStrides))) {
return failure();
}

ReplacementBuilder::replace(op, rewriter, tgtMixedOffsets, tgtMixedSizes,
tgtMixedStrides, srcMixedOffsets, srcMixedSizes,
srcMixedStrides);

return success();
}
};

struct NpuDmaCpyNdOpReplacementBuilder {
static void replace(NpuDmaCpyNdOp dmaOp, PatternRewriter &rewriter,
ArrayRef<OpFoldResult> tgtMixedOffsets,
ArrayRef<OpFoldResult> tgtMixedSizes,
ArrayRef<OpFoldResult> tgtMixedStrides,
ArrayRef<OpFoldResult> srcMixedOffsets,
ArrayRef<OpFoldResult> srcMixedSizes,
ArrayRef<OpFoldResult> srcMixedStrides) {
rewriter.replaceOpWithNewOp<NpuDmaCpyNdOp>(
dmaOp, dmaOp.getDma(), tgtMixedOffsets, tgtMixedSizes, tgtMixedStrides,
srcMixedOffsets, srcMixedSizes, srcMixedStrides, dmaOp.getTargetBdId(),
dmaOp.getSourceBdId());
}
};

void NpuDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results
.add<DoublyStridedFolder<NpuDmaCpyNdOp, NpuDmaCpyNdOpReplacementBuilder>>(
context);
}

template <typename T>
struct DmaCpyNdBaseOpReplacementBuilder {
static void replace(T dmaOp, PatternRewriter &rewriter,
ArrayRef<OpFoldResult> tgtMixedOffsets,
ArrayRef<OpFoldResult> tgtMixedSizes,
ArrayRef<OpFoldResult> tgtMixedStrides,
ArrayRef<OpFoldResult> srcMixedOffsets,
ArrayRef<OpFoldResult> srcMixedSizes,
ArrayRef<OpFoldResult> srcMixedStrides) {
rewriter.replaceOpWithNewOp<T>(dmaOp, dmaOp.getTarget(), tgtMixedOffsets,
tgtMixedSizes, tgtMixedStrides,
dmaOp.getSource(), srcMixedOffsets,
srcMixedSizes, srcMixedStrides);
}
};

void DmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DoublyStridedFolder<
DmaCpyNdOp, DmaCpyNdBaseOpReplacementBuilder<DmaCpyNdOp>>>(context);
}

void CircularDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DoublyStridedFolder<
CircularDmaCpyNdOp,
DmaCpyNdBaseOpReplacementBuilder<CircularDmaCpyNdOp>>>(context);
}

LogicalResult LogicalObjectFifoFromMemrefOp::canonicalize(
LogicalObjectFifoFromMemrefOp logicalObjectFifo,
PatternRewriter &rewriter) {
Expand Down Expand Up @@ -674,6 +650,30 @@ bool NpuDmaCpyNdOp::hasDmaWaitOpUser() {
[](auto userOp) { return isa<NpuDmaWaitOp>(userOp); });
}

namespace {
struct NpuDmaCpyNdOpReplacementBuilder {
static void replace(NpuDmaCpyNdOp dmaOp, PatternRewriter &rewriter,
ArrayRef<OpFoldResult> tgtMixedOffsets,
ArrayRef<OpFoldResult> tgtMixedSizes,
ArrayRef<OpFoldResult> tgtMixedStrides,
ArrayRef<OpFoldResult> srcMixedOffsets,
ArrayRef<OpFoldResult> srcMixedSizes,
ArrayRef<OpFoldResult> srcMixedStrides) {
rewriter.replaceOpWithNewOp<NpuDmaCpyNdOp>(
dmaOp, dmaOp.getDma(), tgtMixedOffsets, tgtMixedSizes, tgtMixedStrides,
srcMixedOffsets, srcMixedSizes, srcMixedStrides, dmaOp.getTargetBdId(),
dmaOp.getSourceBdId());
}
};
} // namespace

void NpuDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results
.add<DoublyStridedFolder<NpuDmaCpyNdOp, NpuDmaCpyNdOpReplacementBuilder>>(
context);
}

//===----------------------------------------------------------------------===//
// AMDAIE_TileOp
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 8e99191

Please sign in to comment.