diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp index 410413ce1..3c6bbd50b 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp @@ -7,11 +7,9 @@ #include "iree-amd-aie/IR/AMDAIEOps.h" #include "iree-amd-aie/IR/AMDAIEDialect.h" -#include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Transform/IR/TransformOps.h" -#include "mlir/IR/DialectImplementation.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/OpDefinition.h" #define GET_OP_CLASSES @@ -395,72 +393,76 @@ void LogicalObjectFifoFromMemrefOp::build( build(b, result, type, memref, tiles); } -LogicalResult NpuDmaCpyNdOp::canonicalize(NpuDmaCpyNdOp dmaOp, - PatternRewriter &rewriter) { - // First check if any of offsets, sizes or strides are constant operands which - // can be made static. - auto canFold = [&](ArrayRef mixed, - ArrayRef statics) -> bool { - for (uint64_t i = 0; i < statics.size(); ++i) { - if (statics[i] == ShapedType::kDynamic) { - auto maybeConstant = getConstantIntValue(mixed[i]); - if (maybeConstant.has_value()) return true; - } +namespace { +// Simplified from upstream MLIR's foldDynamicIndexList: +LogicalResult foldMixed(SmallVectorImpl &ofrs) { + bool valuesChanged = false; + for (OpFoldResult &ofr : ofrs) { + if (ofr.is()) continue; + Attribute attr; + if (matchPattern(ofr.get(), m_Constant(&attr))) { + ofr = attr; + valuesChanged = true; } - return false; - }; - - if (!canFold(dmaOp.getSourceMixedOffsets(), dmaOp.getSourceStaticOffsets()) && - !canFold(dmaOp.getSourceMixedSizes(), dmaOp.getSourceStaticSizes()) && - !canFold(dmaOp.getSourceMixedStrides(), dmaOp.getSourceStaticStrides()) && - !canFold(dmaOp.getTargetMixedOffsets(), dmaOp.getTargetStaticOffsets()) && - !canFold(dmaOp.getTargetMixedSizes(), dmaOp.getTargetStaticSizes()) && - !canFold(dmaOp.getTargetMixedStrides(), dmaOp.getTargetStaticStrides())) { - return failure(); } - - // Make the change, creating new static and dynamic dimensions for all. - auto getNew = [&](ArrayRef mixed, ArrayRef statics) - -> std::tuple, SmallVector> { - SmallVector newStatics; - SmallVector newDynamics; - for (uint64_t i = 0; i < statics.size(); ++i) { - if (statics[i] == ShapedType::kDynamic) { - auto maybeConstant = getConstantIntValue(mixed[i]); - if (maybeConstant.has_value()) { - newStatics.push_back(maybeConstant.value()); - } else { - newStatics.push_back(ShapedType::kDynamic); - newDynamics.push_back(mixed[i].get()); - } - } else { - newStatics.push_back(statics[i]); - } + return success(valuesChanged); +} +} // namespace + +template +// Based on upstream MLIR's +// OpWithOffsetSizesAndStridesConstantArgumentFolder +class DoublyStridedFolder final : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override { + SmallVector tgtMixedOffsets(op.getTargetMixedOffsets()); + SmallVector tgtMixedSizes(op.getTargetMixedSizes()); + SmallVector tgtMixedStrides(op.getTargetMixedStrides()); + SmallVector srcMixedOffsets(op.getSourceMixedOffsets()); + SmallVector srcMixedSizes(op.getSourceMixedSizes()); + SmallVector srcMixedStrides(op.getSourceMixedStrides()); + + // No constant operands were folded, just return; + if (failed(foldMixed(tgtMixedOffsets)) && + failed(foldMixed(tgtMixedSizes)) && + failed(foldMixed(tgtMixedStrides)) && + failed(foldMixed(srcMixedOffsets)) && + failed(foldMixed(srcMixedSizes)) && + failed(foldMixed(srcMixedStrides))) { + return failure(); } - return {newStatics, newDynamics}; - }; - - auto [srcOffsetStatic, srcOffsetValues] = - getNew(dmaOp.getSourceMixedOffsets(), dmaOp.getSourceStaticOffsets()); - auto [srcSizeStatic, srcSizeValues] = - getNew(dmaOp.getSourceMixedSizes(), dmaOp.getSourceStaticSizes()); - auto [srcStrideStatic, srcStrideValues] = - getNew(dmaOp.getSourceMixedStrides(), dmaOp.getSourceStaticStrides()); - - auto [tgtOffsetStatic, tgtOffsetValues] = - getNew(dmaOp.getTargetMixedOffsets(), dmaOp.getTargetStaticOffsets()); - auto [tgtSizeStatic, tgtSizeValues] = - getNew(dmaOp.getTargetMixedSizes(), dmaOp.getTargetStaticSizes()); - auto [tgtStrideStatic, tgtStrideValues] = - getNew(dmaOp.getTargetMixedStrides(), dmaOp.getTargetStaticStrides()); - - rewriter.replaceOpWithNewOp( - dmaOp, dmaOp.getDma(), tgtOffsetValues, tgtSizeValues, tgtStrideValues, - tgtOffsetStatic, tgtSizeStatic, tgtStrideStatic, srcOffsetValues, - srcSizeValues, srcStrideValues, srcOffsetStatic, srcSizeStatic, - srcStrideStatic, dmaOp.getTargetBdId(), dmaOp.getSourceBdId()); - return success(); + ReplacementBuilder::replace(op, rewriter, tgtMixedOffsets, tgtMixedSizes, + tgtMixedStrides, srcMixedOffsets, srcMixedSizes, + srcMixedStrides); + + return success(); + } +}; + +struct NpuDmaCpyNdOpReplacementBuilder { + static void replace(NpuDmaCpyNdOp dmaOp, PatternRewriter &rewriter, + ArrayRef tgtMixedOffsets, + ArrayRef tgtMixedSizes, + ArrayRef tgtMixedStrides, + ArrayRef srcMixedOffsets, + ArrayRef srcMixedSizes, + ArrayRef srcMixedStrides) { + rewriter.replaceOpWithNewOp( + dmaOp, dmaOp.getDma(), tgtMixedOffsets, tgtMixedSizes, tgtMixedStrides, + srcMixedOffsets, srcMixedSizes, srcMixedStrides, dmaOp.getTargetBdId(), + dmaOp.getSourceBdId()); + } +}; + +void NpuDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results + .add>( + context); } LogicalResult LogicalObjectFifoFromMemrefOp::canonicalize( diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td b/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td index bc5321dcf..9b152ac2b 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td @@ -381,8 +381,7 @@ def AMDAIE_NpuDmaCpyNdOp: AMDAIE_Op<"npu.dma_cpy_nd", }]; // Ensure that dimensions of offsets/sizes/strides that can be static, are. - // TODO(newling) make this a canonicalization for all doubly strided ops. - let hasCanonicalizeMethod = 1; + let hasCanonicalizer = 1; } @@ -790,6 +789,7 @@ class AMDAIE_DmaCpyNdBaseOp traits = []> : ::llvm::SmallVector<::mlir::OpFoldResult>& newSourceSizes, ::llvm::SmallVector<::mlir::OpFoldResult>& newSourceStrides); }]; + } def AMDAIE_DmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"dma_cpy_nd", []> { @@ -815,6 +815,10 @@ def AMDAIE_DmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"dma_cpy_nd", []> { }]; let hasVerifier = 0; + + // Ensure that dimensions of offsets/sizes/strides that can be static, are. + // TODO(newling) + // let hasCanonicalizer = 1; } def AMDAIE_CircularDmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"circular_dma_cpy_nd", [Pure]> { @@ -839,6 +843,10 @@ def AMDAIE_CircularDmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"circular_dma_cpy_nd", [Pur }]; let hasVerifier = 0; + + // Ensure that dimensions of offsets/sizes/strides that can be static, are. + // TODO(newling) + // let hasCanonicalizer = 1; } def AMDAIE_ReferenceToOp: AMDAIE_Op<"reference_to", [SameOperandsAndResultType]> {