Skip to content

Commit

Permalink
use approach based on upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Aug 16, 2024
1 parent 93b20b1 commit 1de7f26
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 67 deletions.
132 changes: 67 additions & 65 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
#include "iree-amd-aie/IR/AMDAIEOps.h"

#include "iree-amd-aie/IR/AMDAIEDialect.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpDefinition.h"

#define GET_OP_CLASSES
Expand Down Expand Up @@ -395,72 +393,76 @@ void LogicalObjectFifoFromMemrefOp::build(
build(b, result, type, memref, tiles);
}

LogicalResult NpuDmaCpyNdOp::canonicalize(NpuDmaCpyNdOp dmaOp,
PatternRewriter &rewriter) {
// First check if any of offsets, sizes or strides are constant operands which
// can be made static.
auto canFold = [&](ArrayRef<OpFoldResult> mixed,
ArrayRef<int64_t> statics) -> bool {
for (uint64_t i = 0; i < statics.size(); ++i) {
if (statics[i] == ShapedType::kDynamic) {
auto maybeConstant = getConstantIntValue(mixed[i]);
if (maybeConstant.has_value()) return true;
}
namespace {
// Simplified from upstream MLIR's foldDynamicIndexList:
LogicalResult foldMixed(SmallVectorImpl<OpFoldResult> &ofrs) {
bool valuesChanged = false;
for (OpFoldResult &ofr : ofrs) {
if (ofr.is<Attribute>()) continue;
Attribute attr;
if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
ofr = attr;
valuesChanged = true;
}
return false;
};

if (!canFold(dmaOp.getSourceMixedOffsets(), dmaOp.getSourceStaticOffsets()) &&
!canFold(dmaOp.getSourceMixedSizes(), dmaOp.getSourceStaticSizes()) &&
!canFold(dmaOp.getSourceMixedStrides(), dmaOp.getSourceStaticStrides()) &&
!canFold(dmaOp.getTargetMixedOffsets(), dmaOp.getTargetStaticOffsets()) &&
!canFold(dmaOp.getTargetMixedSizes(), dmaOp.getTargetStaticSizes()) &&
!canFold(dmaOp.getTargetMixedStrides(), dmaOp.getTargetStaticStrides())) {
return failure();
}

// Make the change, creating new static and dynamic dimensions for all.
auto getNew = [&](ArrayRef<OpFoldResult> mixed, ArrayRef<int64_t> statics)
-> std::tuple<SmallVector<int64_t>, SmallVector<Value>> {
SmallVector<int64_t> newStatics;
SmallVector<Value> newDynamics;
for (uint64_t i = 0; i < statics.size(); ++i) {
if (statics[i] == ShapedType::kDynamic) {
auto maybeConstant = getConstantIntValue(mixed[i]);
if (maybeConstant.has_value()) {
newStatics.push_back(maybeConstant.value());
} else {
newStatics.push_back(ShapedType::kDynamic);
newDynamics.push_back(mixed[i].get<Value>());
}
} else {
newStatics.push_back(statics[i]);
}
return success(valuesChanged);
}
} // namespace

template <typename OpType, typename ReplacementBuilder>
// Based on upstream MLIR's
// OpWithOffsetSizesAndStridesConstantArgumentFolder
class DoublyStridedFolder final : public OpRewritePattern<OpType> {
public:
using OpRewritePattern<OpType>::OpRewritePattern;

LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override {
SmallVector<OpFoldResult> tgtMixedOffsets(op.getTargetMixedOffsets());
SmallVector<OpFoldResult> tgtMixedSizes(op.getTargetMixedSizes());
SmallVector<OpFoldResult> tgtMixedStrides(op.getTargetMixedStrides());
SmallVector<OpFoldResult> srcMixedOffsets(op.getSourceMixedOffsets());
SmallVector<OpFoldResult> srcMixedSizes(op.getSourceMixedSizes());
SmallVector<OpFoldResult> srcMixedStrides(op.getSourceMixedStrides());

// No constant operands were folded, so there is nothing to rewrite.
if (failed(foldMixed(tgtMixedOffsets)) &&
failed(foldMixed(tgtMixedSizes)) &&
failed(foldMixed(tgtMixedStrides)) &&
failed(foldMixed(srcMixedOffsets)) &&
failed(foldMixed(srcMixedSizes)) &&
failed(foldMixed(srcMixedStrides))) {
return failure();
}
return {newStatics, newDynamics};
};

auto [srcOffsetStatic, srcOffsetValues] =
getNew(dmaOp.getSourceMixedOffsets(), dmaOp.getSourceStaticOffsets());
auto [srcSizeStatic, srcSizeValues] =
getNew(dmaOp.getSourceMixedSizes(), dmaOp.getSourceStaticSizes());
auto [srcStrideStatic, srcStrideValues] =
getNew(dmaOp.getSourceMixedStrides(), dmaOp.getSourceStaticStrides());

auto [tgtOffsetStatic, tgtOffsetValues] =
getNew(dmaOp.getTargetMixedOffsets(), dmaOp.getTargetStaticOffsets());
auto [tgtSizeStatic, tgtSizeValues] =
getNew(dmaOp.getTargetMixedSizes(), dmaOp.getTargetStaticSizes());
auto [tgtStrideStatic, tgtStrideValues] =
getNew(dmaOp.getTargetMixedStrides(), dmaOp.getTargetStaticStrides());

rewriter.replaceOpWithNewOp<AMDAIE::NpuDmaCpyNdOp>(
dmaOp, dmaOp.getDma(), tgtOffsetValues, tgtSizeValues, tgtStrideValues,
tgtOffsetStatic, tgtSizeStatic, tgtStrideStatic, srcOffsetValues,
srcSizeValues, srcStrideValues, srcOffsetStatic, srcSizeStatic,
srcStrideStatic, dmaOp.getTargetBdId(), dmaOp.getSourceBdId());

return success();
ReplacementBuilder::replace(op, rewriter, tgtMixedOffsets, tgtMixedSizes,
tgtMixedStrides, srcMixedOffsets, srcMixedSizes,
srcMixedStrides);

return success();
}
};

// Replacement policy plugged into DoublyStridedFolder for NpuDmaCpyNdOp.
// `replace` swaps `dmaOp` for a new NpuDmaCpyNdOp built from the supplied mixed
// (static-or-dynamic) target/source offsets, sizes and strides; the values
// returned by getDma(), getTargetBdId() and getSourceBdId() are taken from the
// op being replaced.
struct NpuDmaCpyNdOpReplacementBuilder {
  static void replace(NpuDmaCpyNdOp dmaOp, PatternRewriter &rewriter,
                      ArrayRef<OpFoldResult> tgtMixedOffsets,
                      ArrayRef<OpFoldResult> tgtMixedSizes,
                      ArrayRef<OpFoldResult> tgtMixedStrides,
                      ArrayRef<OpFoldResult> srcMixedOffsets,
                      ArrayRef<OpFoldResult> srcMixedSizes,
                      ArrayRef<OpFoldResult> srcMixedStrides) {
    // Everything not described by the mixed lists is carried over unchanged.
    auto dma = dmaOp.getDma();
    auto targetBdId = dmaOp.getTargetBdId();
    auto sourceBdId = dmaOp.getSourceBdId();

    rewriter.replaceOpWithNewOp<NpuDmaCpyNdOp>(
        dmaOp, dma,
        // Target access pattern.
        tgtMixedOffsets, tgtMixedSizes, tgtMixedStrides,
        // Source access pattern.
        srcMixedOffsets, srcMixedSizes, srcMixedStrides,
        // Buffer-descriptor ids, target first.
        targetBdId, sourceBdId);
  }
};

// Registers the canonicalization patterns of NpuDmaCpyNdOp: a single
// DoublyStridedFolder instantiated with NpuDmaCpyNdOpReplacementBuilder as its
// replacement policy.
void NpuDmaCpyNdOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  using ConstantOperandFolder =
      DoublyStridedFolder<NpuDmaCpyNdOp, NpuDmaCpyNdOpReplacementBuilder>;
  results.add<ConstantOperandFolder>(context);
}

LogicalResult LogicalObjectFifoFromMemrefOp::canonicalize(
Expand Down
12 changes: 10 additions & 2 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,7 @@ def AMDAIE_NpuDmaCpyNdOp: AMDAIE_Op<"npu.dma_cpy_nd",
}];

// Ensure that dimensions of offsets/sizes/strides that can be static, are.
// TODO(newling) make this a canonicalization for all doubly strided ops.
let hasCanonicalizeMethod = 1;
let hasCanonicalizer = 1;

}

Expand Down Expand Up @@ -790,6 +789,7 @@ class AMDAIE_DmaCpyNdBaseOp<string mnemonic, list<Trait> traits = []> :
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceSizes,
::llvm::SmallVector<::mlir::OpFoldResult>& newSourceStrides);
}];

}

def AMDAIE_DmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"dma_cpy_nd", []> {
Expand All @@ -815,6 +815,10 @@ def AMDAIE_DmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"dma_cpy_nd", []> {
}];

let hasVerifier = 0;

// Ensure that dimensions of offsets/sizes/strides that can be static, are.
// TODO(newling): enable the canonicalizer below for this op as well.
// let hasCanonicalizer = 1;
}

def AMDAIE_CircularDmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"circular_dma_cpy_nd", [Pure]> {
Expand All @@ -839,6 +843,10 @@ def AMDAIE_CircularDmaCpyNdOp: AMDAIE_DmaCpyNdBaseOp<"circular_dma_cpy_nd", [Pur
}];

let hasVerifier = 0;

// Ensure that dimensions of offsets/sizes/strides that can be static, are.
// TODO(newling): enable the canonicalizer below for this op as well.
// let hasCanonicalizer = 1;
}

def AMDAIE_ReferenceToOp: AMDAIE_Op<"reference_to", [SameOperandsAndResultType]> {
Expand Down

0 comments on commit 1de7f26

Please sign in to comment.