diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp index beb31dc03..a6a65ac19 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp @@ -90,6 +90,70 @@ TileOp CoreOp::getTileOp() { // AMDAIE_DmaCpyNdBaseOp //===----------------------------------------------------------------------===// +namespace { +// Simplified from upstream MLIR's foldDynamicIndexList: +LogicalResult foldMixed(SmallVectorImpl &ofrs) { + bool valuesChanged = false; + for (OpFoldResult &ofr : ofrs) { + if (ofr.is()) continue; + Attribute attr; + if (matchPattern(ofr.get(), m_Constant(&attr))) { + ofr = attr; + valuesChanged = true; + } + } + return success(valuesChanged); +} + +template +// Based on upstream MLIR's +// OpWithOffsetSizesAndStridesConstantArgumentFolder +class DoublyStridedFolder final : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override { + SmallVector tgtMixedOffsets(op.getTargetMixedOffsets()); + SmallVector tgtMixedSizes(op.getTargetMixedSizes()); + SmallVector tgtMixedStrides(op.getTargetMixedStrides()); + SmallVector srcMixedOffsets(op.getSourceMixedOffsets()); + SmallVector srcMixedSizes(op.getSourceMixedSizes()); + SmallVector srcMixedStrides(op.getSourceMixedStrides()); + + // No constant operands were folded, just return; + if (failed(foldMixed(tgtMixedOffsets)) && + failed(foldMixed(tgtMixedSizes)) && + failed(foldMixed(tgtMixedStrides)) && + failed(foldMixed(srcMixedOffsets)) && + failed(foldMixed(srcMixedSizes)) && failed(foldMixed(srcMixedStrides))) + return failure(); + + ReplacementBuilder::replace(op, rewriter, tgtMixedOffsets, tgtMixedSizes, + tgtMixedStrides, srcMixedOffsets, srcMixedSizes, + srcMixedStrides); + + return success(); + } +}; + +template +struct DmaCpyNdBaseOpReplacementBuilder { + static void replace(T dmaOp, PatternRewriter &rewriter, + ArrayRef tgtMixedOffsets, + ArrayRef tgtMixedSizes, + ArrayRef tgtMixedStrides, + ArrayRef srcMixedOffsets, + ArrayRef srcMixedSizes, + ArrayRef srcMixedStrides) { + rewriter.replaceOpWithNewOp(dmaOp, dmaOp.getTarget(), tgtMixedOffsets, + tgtMixedSizes, tgtMixedStrides, + dmaOp.getSource(), srcMixedOffsets, + srcMixedSizes, srcMixedStrides); + } +}; +} // namespace + // Build a DmaCpyNdOp with mixed static and dynamic entries. void DmaCpyNdOp::build(OpBuilder &b, OperationState &result, Value target, ArrayRef targetOffsets, @@ -217,6 +281,12 @@ LogicalObjectFifoFromMemrefOp DmaCpyNdOp::getTargetObjectFifo() { return dyn_cast(getTarget().getDefiningOp()); }; +void DmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add>>(context); +} + // Build a CircularDmaCpyNdOp with mixed static and dynamic entries. void CircularDmaCpyNdOp::build( OpBuilder &b, OperationState &result, Value target, @@ -344,6 +414,13 @@ LogicalObjectFifoFromMemrefOp CircularDmaCpyNdOp::getTargetObjectFifo() { return dyn_cast(getTarget().getDefiningOp()); }; +void CircularDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add>>(context); +} + //===----------------------------------------------------------------------===// // AMDAIE_LogicalObjectFifoAccessOp //===----------------------------------------------------------------------===// @@ -393,107 +470,6 @@ void LogicalObjectFifoFromMemrefOp::build( build(b, result, type, memref, tiles); } -namespace { -// Simplified from upstream MLIR's foldDynamicIndexList: -LogicalResult foldMixed(SmallVectorImpl &ofrs) { - bool valuesChanged = false; - for (OpFoldResult &ofr : ofrs) { - if (ofr.is()) continue; - Attribute attr; - if (matchPattern(ofr.get(), m_Constant(&attr))) { - ofr = attr; - valuesChanged = true; - } - } - return success(valuesChanged); -} -} // namespace - -template -// Based on upstream MLIR's -// OpWithOffsetSizesAndStridesConstantArgumentFolder -class DoublyStridedFolder final : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpType op, - PatternRewriter &rewriter) const override { - SmallVector tgtMixedOffsets(op.getTargetMixedOffsets()); - SmallVector tgtMixedSizes(op.getTargetMixedSizes()); - SmallVector tgtMixedStrides(op.getTargetMixedStrides()); - SmallVector srcMixedOffsets(op.getSourceMixedOffsets()); - SmallVector srcMixedSizes(op.getSourceMixedSizes()); - SmallVector srcMixedStrides(op.getSourceMixedStrides()); - - // No constant operands were folded, just return; - if (failed(foldMixed(tgtMixedOffsets)) && - failed(foldMixed(tgtMixedSizes)) && - failed(foldMixed(tgtMixedStrides)) && - failed(foldMixed(srcMixedOffsets)) && - failed(foldMixed(srcMixedSizes)) && - failed(foldMixed(srcMixedStrides))) { - return failure(); - } - - ReplacementBuilder::replace(op, rewriter, tgtMixedOffsets, tgtMixedSizes, - tgtMixedStrides, srcMixedOffsets, srcMixedSizes, - srcMixedStrides); - - return success(); - } -}; - -struct NpuDmaCpyNdOpReplacementBuilder { - static void replace(NpuDmaCpyNdOp dmaOp, PatternRewriter &rewriter, - ArrayRef tgtMixedOffsets, - ArrayRef tgtMixedSizes, - ArrayRef tgtMixedStrides, - ArrayRef srcMixedOffsets, - ArrayRef srcMixedSizes, - ArrayRef srcMixedStrides) { - rewriter.replaceOpWithNewOp( - dmaOp, dmaOp.getDma(), tgtMixedOffsets, tgtMixedSizes, tgtMixedStrides, - srcMixedOffsets, srcMixedSizes, srcMixedStrides, dmaOp.getTargetBdId(), - dmaOp.getSourceBdId()); - } -}; - -void NpuDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results - .add>( - context); -} - -template -struct DmaCpyNdBaseOpReplacementBuilder { - static void replace(T dmaOp, PatternRewriter &rewriter, - ArrayRef tgtMixedOffsets, - ArrayRef tgtMixedSizes, - ArrayRef tgtMixedStrides, - ArrayRef srcMixedOffsets, - ArrayRef srcMixedSizes, - ArrayRef srcMixedStrides) { - rewriter.replaceOpWithNewOp(dmaOp, dmaOp.getTarget(), tgtMixedOffsets, - tgtMixedSizes, tgtMixedStrides, - dmaOp.getSource(), srcMixedOffsets, - srcMixedSizes, srcMixedStrides); - } -}; - -void DmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add>>(context); -} - -void CircularDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add>>(context); -} - LogicalResult LogicalObjectFifoFromMemrefOp::canonicalize( LogicalObjectFifoFromMemrefOp logicalObjectFifo, PatternRewriter &rewriter) { @@ -674,6 +650,30 @@ bool NpuDmaCpyNdOp::hasDmaWaitOpUser() { [](auto userOp) { return isa(userOp); }); } +namespace { +struct NpuDmaCpyNdOpReplacementBuilder { + static void replace(NpuDmaCpyNdOp dmaOp, PatternRewriter &rewriter, + ArrayRef tgtMixedOffsets, + ArrayRef tgtMixedSizes, + ArrayRef tgtMixedStrides, + ArrayRef srcMixedOffsets, + ArrayRef srcMixedSizes, + ArrayRef srcMixedStrides) { + rewriter.replaceOpWithNewOp( + dmaOp, dmaOp.getDma(), tgtMixedOffsets, tgtMixedSizes, tgtMixedStrides, + srcMixedOffsets, srcMixedSizes, srcMixedStrides, dmaOp.getTargetBdId(), + dmaOp.getSourceBdId()); + } +}; +} // namespace + +void NpuDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results + .add>( + context); +} + //===----------------------------------------------------------------------===// // AMDAIE_TileOp //===----------------------------------------------------------------------===//