Skip to content

Commit

Permalink
[ConvertToDma] Add options to transpose dma dimensions on target side (#…
Browse files Browse the repository at this point in the history
…812)

Pack/unpack ops change the data layout, so after they are converted to dma
ops, the dma addressing dimensions are expanded/collapsed and
transposed. Previously, all dimension transpositions were applied on the
source side of the dma ops. This PR adds an option for the
transposition to happen on the target side instead.

In applications, we can choose whether the transposition happens on the
source or the target side for pack or unpack ops based on performance,
hardware dma requirements, etc. The motivation comes from [this
discussion](#764 (comment)),
and this PR moves the dma optimization logic to an early pass where the
dma ops are converted.

Note that the default options are not changed in this PR (the new option
will be enabled in a separate PR with other changes for dma optimization),
but I have tested all four combinations locally to make sure the generated
dma ops are correct and work end-to-end. The options can be changed, for
example, as follows:

```
AMDAIEConvertToDmaOptions dmaOptions;
dmaOptions.packTransposeOnSource = false;
dmaOptions.unpackTransposeOnSource = true;
passManager.addPass(createAMDAIEConvertToDmaPass(dmaOptions));
```
  • Loading branch information
yzhang93 authored Oct 2, 2024
1 parent b7f8fc4 commit 5b816a5
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,28 @@

#define DEBUG_TYPE "iree-amdaie-convert-to-dma"



namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Applies packing to a given input.
LogicalResult packDmaInputs(IREE::LinalgExt::PackOp packOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
MLIRContext *ctx = packOp.getContext();
/// Applies dma transposition on the side that has lower number of dimensions,
/// which means the source side for pack ops and the destination side for unpack
/// ops.
template <typename PackOrUnpackOp>
LogicalResult dmaTransposeOnLowerNumDims(PackOrUnpackOp packOrUnpackOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
MLIRContext *ctx = packOrUnpackOp.getContext();

llvm::ArrayRef<int64_t> permutation = packOp.getOuterDimsPerm();
llvm::ArrayRef<int64_t> innerTiles = packOp.getStaticInnerTiles();
llvm::ArrayRef<int64_t> permutation = packOrUnpackOp.getOuterDimsPerm();
llvm::ArrayRef<int64_t> innerTiles = packOrUnpackOp.getStaticInnerTiles();

SmallVector<OpFoldResult> innerSizes;
SmallVector<OpFoldResult> innerStrides;
SmallVector<OpFoldResult> innerOffsets;

auto innerDimsPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> innerDimsPos = packOrUnpackOp.getInnerDimsPos();

for (int i = 0; i < innerTiles.size(); i++) {
// Calculate new sizes.
Expand All @@ -52,7 +53,7 @@ LogicalResult packDmaInputs(IREE::LinalgExt::PackOp packOp,
"in dimension {0}, the tile size {1} does not divide the tensor size "
"{2}. Imperfect/partial tiling is currently not supported.",
i, innerTiles[i], size.value());
return packOp->emitOpError(message);
return packOrUnpackOp->emitOpError(message);
}

sizes[innerDimsPos[i]] =
Expand All @@ -71,33 +72,38 @@ LogicalResult packDmaInputs(IREE::LinalgExt::PackOp packOp,
innerOffsets.push_back(offsets[innerDimsPos[i]]);
offsets[innerDimsPos[i]] = getAsIndexOpFoldResult(ctx, 0);
}

// Apply permutations to the outer dims if provided.
if (!permutation.empty()) {
applyPermutationToVector(strides, permutation);
applyPermutationToVector(sizes, permutation);
applyPermutationToVector(offsets, permutation);
}

// Merge the dims.
sizes.insert(sizes.end(), innerSizes.begin(), innerSizes.end());
strides.insert(strides.end(), innerStrides.begin(), innerStrides.end());
offsets.insert(offsets.end(), innerOffsets.begin(), innerOffsets.end());
return success();
}

/// Applies unpacking to a given input.
LogicalResult unPackDmaInputs(IREE::LinalgExt::UnPackOp unPackOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
MLIRContext *ctx = unPackOp.getContext();
/// Applies dma transposition on the side which has higher number of dimensions,
/// which means the destination side for pack ops and the source side for unpack
/// ops.
template <typename PackOrUnpackOp>
LogicalResult dmaTransposeOnHigherNumDims(PackOrUnpackOp packOrUnpackOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
MLIRContext *ctx = packOrUnpackOp.getContext();

llvm::ArrayRef<int64_t> permutation = unPackOp.getOuterDimsPerm();
llvm::ArrayRef<int64_t> innerTiles = unPackOp.getStaticInnerTiles();
llvm::ArrayRef<int64_t> permutation = packOrUnpackOp.getOuterDimsPerm();
llvm::ArrayRef<int64_t> innerTiles = packOrUnpackOp.getStaticInnerTiles();

SmallVector<OpFoldResult> innerSizes;
SmallVector<OpFoldResult> innerStrides;
SmallVector<OpFoldResult> innerOffsets;
auto innerDimsPos = unPackOp.getInnerDimsPos();
ArrayRef<int64_t> innerDimsPos = packOrUnpackOp.getInnerDimsPos();

int numOuterDims = sizes.size() - innerTiles.size();
SmallVector<OpFoldResult> outerOffsets = SmallVector<OpFoldResult>(
Expand All @@ -116,29 +122,30 @@ LogicalResult unPackDmaInputs(IREE::LinalgExt::UnPackOp unPackOp,
applyPermutationToVector(outerSizes, inversePermutation);
applyPermutationToVector(outerOffsets, inversePermutation);
}
// Do the unpacking on the Outer dims.

// Initialize the indexing of each outer dim.
llvm::SmallDenseMap<int64_t, int64_t> outerDimsIndexMap;
// Intialize the indexing of each outer dim.
for (int i = 0; i < numOuterDims; i++) {
outerDimsIndexMap[i] = i;
}

// Update outer dim sizes/strides/offsets.
for (int i = 0; i < innerTiles.size(); i++) {
// Insert inner dims adjacent to there corresponding outer dims.
outerSizes.insert(
outerSizes.begin() + outerDimsIndexMap[innerDimsPos[i]] + 1,
getAsIndexOpFoldResult(ctx, innerTiles[i]));
outerStrides.insert(
outerStrides.begin() + outerDimsIndexMap[innerDimsPos[i]] + 1,
strides[numOuterDims + i]);
outerOffsets.insert(
outerOffsets.begin() + outerDimsIndexMap[innerDimsPos[i]] + 1,
offsets[numOuterDims + i]);
// Insert inner dims adjacent to their corresponding outer dims.
int insertionIndex = outerDimsIndexMap[innerDimsPos[i]] + 1;
outerSizes.insert(outerSizes.begin() + insertionIndex,
getAsIndexOpFoldResult(ctx, innerTiles[i]));
outerStrides.insert(outerStrides.begin() + insertionIndex,
strides[numOuterDims + i]);
outerOffsets.insert(outerOffsets.begin() + insertionIndex,
offsets[numOuterDims + i]);
// Update the map as all the dimensions inner to the innerDimsPos[i] are now
// shifted by 1.
for (int j = innerDimsPos[i] + 1; j < numOuterDims; j++) {
outerDimsIndexMap[j]++;
}
}

// Make the outer dims as the final returned dims
offsets = outerOffsets;
strides = outerStrides;
Expand All @@ -147,7 +154,7 @@ LogicalResult unPackDmaInputs(IREE::LinalgExt::UnPackOp unPackOp,
}

/// Examines an input/output of a pack/unpack op and provides the
/// corresponding offsets, sizes and strides required by the dma op
/// corresponding offsets, sizes and strides required by the dma op.
LogicalResult setDmaInputs(Operation *&operandOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
Expand Down Expand Up @@ -232,25 +239,6 @@ LogicalResult setDmaInputs(Operation *&operandOp,
"and SubViewOp as inputs.");
}

/// Get the inputs from the pack/unpack op 'op'. Return failure if 'op' is not
/// a pack/unpack op, or if 'op' is determined unlowerable to a DMA operation.
LogicalResult processInputs(Operation *op, SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
if (auto packOp = dyn_cast<IREE::LinalgExt::PackOp>(op)) {
if (failed(packDmaInputs(packOp, offsets, sizes, strides))) {
return failure();
}
} else if (auto unPackOp = dyn_cast<IREE::LinalgExt::UnPackOp>(op)) {
if (failed(unPackDmaInputs(unPackOp, offsets, sizes, strides))) {
return failure();
}
} else {
return failure();
}
return success();
}

/// Rewrite the pack/unpack op 'op' as a DMA operation. The function arguments
/// 'input', 'output', and 'innerTiles' are the input, output, and inner tile
/// of 'op'. If 'op' is not a pack/unpack op, or if it determined to not
Expand All @@ -260,8 +248,11 @@ LogicalResult processInputs(Operation *op, SmallVector<OpFoldResult> &offsets,
/// obtained from 'op' inside this function if it were templatized, but
/// I've factorized out that logic to reduce the total amount of templatized
/// code.
LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *op, Value input,
Value output, llvm::ArrayRef<int64_t> innerTiles) {
template <typename PackOrUnpackOp>
LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,
Value output, llvm::ArrayRef<int64_t> innerTiles,
bool packTransposeOnSource,
bool unpackTransposeOnSource) {
if (llvm::any_of(innerTiles,
[](int64_t size) { return ShapedType::isDynamic(size); })) {
op->emitError("has a non-static shape: not yet supported by this pass.");
Expand All @@ -283,10 +274,6 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *op, Value input,
return failure();
}

if (!succeeded(processInputs(op, srcOffsets, srcShape, srcBaseStrides))) {
return failure();
}

// Prepare destination DMA inputs.
SmallVector<OpFoldResult> dstOffsets;
SmallVector<OpFoldResult> dstBaseStrides;
Expand All @@ -295,6 +282,29 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *op, Value input,
return failure();
}

// Update dma source or destination addressing based on the side for dma
// transposition and pack/unpack operations.
if (packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
if (!succeeded(dmaTransposeOnLowerNumDims(op, srcOffsets, srcShape,
srcBaseStrides)))
return failure();
} else if (!packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
if (!succeeded(dmaTransposeOnHigherNumDims(op, dstOffsets, dstShape,
dstBaseStrides)))
return failure();
} else if (unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
if (!succeeded(dmaTransposeOnHigherNumDims(op, srcOffsets, srcShape,
srcBaseStrides)))
return failure();
} else if (!unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
if (!succeeded(dmaTransposeOnLowerNumDims(op, dstOffsets, dstShape,
dstBaseStrides)))
return failure();
} else {
op->emitError("unhandled option for dma addressing update.");
return failure();
}

// Create logical objectFifos from source and destination memrefs.
Value srcVal = sourceOp->getResult(0);
Value dstVal = dstOp->getResult(0);
Expand All @@ -317,11 +327,14 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *op, Value input,
}

template <typename PackOrUnpackOp>
LogicalResult rewriteAsDma(PackOrUnpackOp op, IRRewriter &rewriter) {
LogicalResult rewriteAsDma(PackOrUnpackOp op, IRRewriter &rewriter,
bool packTransposeOnSource,
bool unpackTransposeOnSource) {
Value input = op.getInput();
Value output = op.getOutput();
llvm::ArrayRef<int64_t> innerTiles = op.getStaticInnerTiles();
return rewriteAsDma(rewriter, op, input, output, innerTiles);
return rewriteAsDma(rewriter, op, input, output, innerTiles,
packTransposeOnSource, unpackTransposeOnSource);
}

/// Convert a linalg.copy operation on 2 memrefs to an equivalent pack/unpack
Expand Down Expand Up @@ -375,6 +388,9 @@ class AMDAIEConvertToDmaPass

AMDAIEConvertToDmaPass() = default;
AMDAIEConvertToDmaPass(const AMDAIEConvertToDmaPass &pass){};
AMDAIEConvertToDmaPass(const AMDAIEConvertToDmaOptions &options)
: AMDAIEConvertToDmaBase(options) {}

void runOnOperation() override;
};

Expand All @@ -387,32 +403,35 @@ void AMDAIEConvertToDmaPass::runOnOperation() {
// step. This is easy to implement, but not the most direct lowering, so
// we might want to revisit this.
WalkResult convertCopiesWalkResult =
getOperation()->walk([&rewriter](linalg::CopyOp copyOp) {
getOperation()->walk([&](linalg::CopyOp copyOp) {
if (failed(copyToPack(rewriter, copyOp)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (convertCopiesWalkResult.wasInterrupted()) return signalPassFailure();

auto walkResult = getOperation()->walk(
[&rewriter](IREE::LinalgExt::PackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter))) {
auto walkResult =
getOperation()->walk([&, this](IREE::LinalgExt::PackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter, packTransposeOnSource,
unpackTransposeOnSource))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) signalPassFailure();
walkResult = getOperation()->walk(
[&rewriter](IREE::LinalgExt::UnPackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter))) {
[&, this](IREE::LinalgExt::UnPackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter, packTransposeOnSource,
unpackTransposeOnSource))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) signalPassFailure();
}

std::unique_ptr<Pass> createAMDAIEConvertToDmaPass() {
return std::make_unique<AMDAIEConvertToDmaPass>();
std::unique_ptr<Pass> createAMDAIEConvertToDmaPass(
AMDAIEConvertToDmaOptions options) {
return std::make_unique<AMDAIEConvertToDmaPass>(options);
}
} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ std::unique_ptr<Pass> createAMDAIEPackAndTransposePass(

/// Create pass to lower copy/pack/unpack ops to AMDAIE DMA ops operating on
/// logical objectFifos.
std::unique_ptr<Pass> createAMDAIEConvertToDmaPass();
std::unique_ptr<Pass> createAMDAIEConvertToDmaPass(
AMDAIEConvertToDmaOptions options = {});

/// Create a pass to pad MatmulOp.
std::unique_ptr<Pass> createAMDAIEPadPass(AMDAIEPadOptions options = {});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,12 @@ def AMDAIEConvertToDma :

}];
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEConvertToDmaPass()";
let options = [
Option<"packTransposeOnSource", "pack-transpose-on-source", "bool", /*default=*/"true",
"Option to set transposed dma dimensions on source or target side for pack ops">,
Option<"unpackTransposeOnSource", "unpack-transpose-on-source", "bool", /*default=*/"true",
"Option to set transposed dma dimensions on source or target side for unpack ops">
];
}

def AMDAIEPad :
Expand Down
Loading

0 comments on commit 5b816a5

Please sign in to comment.