Skip to content

Commit

Permalink
squashed commit
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Aug 16, 2024
1 parent 9109a1f commit a0a4bad
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 401 deletions.
114 changes: 105 additions & 9 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
#include "iree-amd-aie/IR/AMDAIEOps.h"

#include "iree-amd-aie/IR/AMDAIEDialect.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpDefinition.h"

#define GET_OP_CLASSES
Expand Down Expand Up @@ -52,7 +50,6 @@ LogicalResult ControlCodeOp::verify() {
// AMDAIE_CoreOp
//===----------------------------------------------------------------------===//


void CoreOp::build(OpBuilder &b, OperationState &result, AMDAIE::TileOp tileOp,
ValueRange inputDmas, ValueRange outputDmas) {
build(b, result, b.getIndexType(), tileOp, inputDmas, outputDmas, nullptr);
Expand Down Expand Up @@ -93,6 +90,70 @@ TileOp CoreOp::getTileOp() {
// AMDAIE_DmaCpyNdBaseOp
//===----------------------------------------------------------------------===//

namespace {
// Simplified from upstream MLIR's foldDynamicIndexList:
LogicalResult foldMixed(SmallVectorImpl<OpFoldResult> &ofrs) {
bool valuesChanged = false;
for (OpFoldResult &ofr : ofrs) {
if (ofr.is<Attribute>()) continue;
Attribute attr;
if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
ofr = attr;
valuesChanged = true;
}
}
return success(valuesChanged);
}

template <typename OpType, typename ReplacementBuilder>
// Based on upstream MLIR's
// OpWithOffsetSizesAndStridesConstantArgumentFolder
class DoublyStridedFolder final : public OpRewritePattern<OpType> {
public:
using OpRewritePattern<OpType>::OpRewritePattern;

LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override {
SmallVector<OpFoldResult> tgtMixedOffsets(op.getTargetMixedOffsets());
SmallVector<OpFoldResult> tgtMixedSizes(op.getTargetMixedSizes());
SmallVector<OpFoldResult> tgtMixedStrides(op.getTargetMixedStrides());
SmallVector<OpFoldResult> srcMixedOffsets(op.getSourceMixedOffsets());
SmallVector<OpFoldResult> srcMixedSizes(op.getSourceMixedSizes());
SmallVector<OpFoldResult> srcMixedStrides(op.getSourceMixedStrides());

// No constant operands were folded, just return;
if (failed(foldMixed(tgtMixedOffsets)) &&
failed(foldMixed(tgtMixedSizes)) &&
failed(foldMixed(tgtMixedStrides)) &&
failed(foldMixed(srcMixedOffsets)) &&
failed(foldMixed(srcMixedSizes)) && failed(foldMixed(srcMixedStrides)))
return failure();

ReplacementBuilder::replace(op, rewriter, tgtMixedOffsets, tgtMixedSizes,
tgtMixedStrides, srcMixedOffsets, srcMixedSizes,
srcMixedStrides);

return success();
}
};

template <typename T>
struct DmaCpyNdBaseOpReplacementBuilder {
static void replace(T dmaOp, PatternRewriter &rewriter,
ArrayRef<OpFoldResult> tgtMixedOffsets,
ArrayRef<OpFoldResult> tgtMixedSizes,
ArrayRef<OpFoldResult> tgtMixedStrides,
ArrayRef<OpFoldResult> srcMixedOffsets,
ArrayRef<OpFoldResult> srcMixedSizes,
ArrayRef<OpFoldResult> srcMixedStrides) {
rewriter.replaceOpWithNewOp<T>(dmaOp, dmaOp.getTarget(), tgtMixedOffsets,
tgtMixedSizes, tgtMixedStrides,
dmaOp.getSource(), srcMixedOffsets,
srcMixedSizes, srcMixedStrides);
}
};
} // namespace

// Build a DmaCpyNdOp with mixed static and dynamic entries.
void DmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value target,
ArrayRef<OpFoldResult> targetOffsets,
Expand Down Expand Up @@ -220,6 +281,12 @@ LogicalObjectFifoFromMemrefOp DmaCpyNdOp::getTargetObjectFifo() {
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getTarget().getDefiningOp());
};

void DmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DoublyStridedFolder<
DmaCpyNdOp, DmaCpyNdBaseOpReplacementBuilder<DmaCpyNdOp>>>(context);
}

// Build a CircularDmaCpyNdOp with mixed static and dynamic entries.
void CircularDmaCpyNdOp::build(
OpBuilder &b, OperationState &result, Value target,
Expand Down Expand Up @@ -347,6 +414,13 @@ LogicalObjectFifoFromMemrefOp CircularDmaCpyNdOp::getTargetObjectFifo() {
return dyn_cast<LogicalObjectFifoFromMemrefOp>(getTarget().getDefiningOp());
};

void CircularDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DoublyStridedFolder<
CircularDmaCpyNdOp,
DmaCpyNdBaseOpReplacementBuilder<CircularDmaCpyNdOp>>>(context);
}

//===----------------------------------------------------------------------===//
// AMDAIE_LogicalObjectFifoAccessOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -386,8 +460,7 @@ void LogicalObjectFifoFromMemrefOp::build(
for (auto [column, row] : tileLocations) {
auto getCol = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), column);
auto getRow = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), row);
auto tileOp =
b.create<AMDAIE::TileOp>(b.getUnknownLoc(), getCol, getRow);
auto tileOp = b.create<AMDAIE::TileOp>(b.getUnknownLoc(), getCol, getRow);
tiles.push_back(tileOp.getResult());
}
// For deterministic order.
Expand Down Expand Up @@ -449,8 +522,8 @@ void LogicalObjectFifoRelease::build(OpBuilder &b, mlir::OperationState &result,
// AMDAIE_NpuDmaCpyNdOp
//===----------------------------------------------------------------------===//

// Build a NpuDmaCpyNdOp with mixed static and dynamic entries and target and
// source BD IDs.
// Build a NpuDmaCpyNdOp with mixed static and dynamic entries and target
// and source BD IDs.
void NpuDmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value dma,
ArrayRef<OpFoldResult> targetOffsets,
ArrayRef<OpFoldResult> targetSizes,
Expand Down Expand Up @@ -577,6 +650,30 @@ bool NpuDmaCpyNdOp::hasDmaWaitOpUser() {
[](auto userOp) { return isa<NpuDmaWaitOp>(userOp); });
}

namespace {
struct NpuDmaCpyNdOpReplacementBuilder {
static void replace(NpuDmaCpyNdOp dmaOp, PatternRewriter &rewriter,
ArrayRef<OpFoldResult> tgtMixedOffsets,
ArrayRef<OpFoldResult> tgtMixedSizes,
ArrayRef<OpFoldResult> tgtMixedStrides,
ArrayRef<OpFoldResult> srcMixedOffsets,
ArrayRef<OpFoldResult> srcMixedSizes,
ArrayRef<OpFoldResult> srcMixedStrides) {
rewriter.replaceOpWithNewOp<NpuDmaCpyNdOp>(
dmaOp, dmaOp.getDma(), tgtMixedOffsets, tgtMixedSizes, tgtMixedStrides,
srcMixedOffsets, srcMixedSizes, srcMixedStrides, dmaOp.getTargetBdId(),
dmaOp.getSourceBdId());
}
};
} // namespace

void NpuDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results
.add<DoublyStridedFolder<NpuDmaCpyNdOp, NpuDmaCpyNdOpReplacementBuilder>>(
context);
}

//===----------------------------------------------------------------------===//
// AMDAIE_TileOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -645,5 +742,4 @@ LogicalResult WorkgroupOp::verify() {
}
return success();
}

} // namespace mlir::iree_compiler::AMDAIE
7 changes: 7 additions & 0 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,11 @@ def AMDAIE_NpuDmaCpyNdOp: AMDAIE_Op<"npu.dma_cpy_nd",
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceOffsets,
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceSizes,
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceStrides);

}];

let hasCanonicalizer = 1;

}

def AMDAIE_NpuDmaWaitOp: AMDAIE_Op<"npu.dma_wait", []> {
Expand Down Expand Up @@ -784,6 +788,7 @@ class AMDAIE_DmaCpyNdBaseOp<string mnemonic, list<Trait> traits = []> :
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceSizes,
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceStrides);
}];

}

def AMDAIE_DmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"dma_cpy_nd", []> {
Expand All @@ -809,6 +814,7 @@ def AMDAIE_DmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"dma_cpy_nd", []> {
}];

let hasVerifier = 0;
let hasCanonicalizer = 1;
}

def AMDAIE_CircularDmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"circular_dma_cpy_nd", [Pure]> {
Expand All @@ -833,6 +839,7 @@ def AMDAIE_CircularDmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"circular_dma_cpy_nd", [Pur
}];

let hasVerifier = 0;
let hasCanonicalizer = 1;
}

def AMDAIE_ReferenceToOp: AMDAIE_Op<"reference_to", [SameOperandsAndResultType]> {
Expand Down
Loading

0 comments on commit a0a4bad

Please sign in to comment.