From 762fa08065257ada81a1c1ee97477eb0299cc705 Mon Sep 17 00:00:00 2001
From: James Newling
Date: Wed, 25 Sep 2024 12:41:31 -0700
Subject: [PATCH] tidy up

---
 .../Transforms/AMDAIEConvertToDma.cpp         | 92 ++++++++++---------
 1 file changed, 47 insertions(+), 45 deletions(-)

diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEConvertToDma.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEConvertToDma.cpp
index d61bf158d..8a688c11d 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEConvertToDma.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEConvertToDma.cpp
@@ -39,9 +39,9 @@ LogicalResult updateFromPack(IREE::LinalgExt::PackOp packOp,
   ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();

   assert(offsets.size() == sizes.size() && sizes.size() == strides.size() &&
-         "offsets, sizes, and strides must have the same size");
+         "offsets, sizes, and strides must have the same size.");
   for (int64_t dim : innerDimsPos) {
-    assert(dim < sizes.size() && "innerDimsPos must be within sizes");
+    assert(dim < sizes.size() && "innerDimsPos must be within sizes.");
   }

   SmallVector<OpFoldResult> innerSizes;
@@ -53,8 +53,7 @@ LogicalResult updateFromPack(IREE::LinalgExt::PackOp packOp,
     innerSizes.push_back(getAsIndexOpFoldResult(ctx, innerTiles[i]));
     std::optional<int64_t> maybeSize =
         getConstantIntValue(sizes[innerDimsPos[i]]);
-    if (!maybeSize.has_value())
-      packOp->emitOpError("requires a constant size here.");
+    assert(maybeSize.has_value() && "size expected to be constant here.");
     int64_t size = maybeSize.value();
     if (size % innerTiles[i] != 0) {
       auto message = llvm::formatv(
@@ -155,44 +154,47 @@ LogicalResult updateFromUnPack(IREE::LinalgExt::UnPackOp unPackOp,
   return success();
 }

+static bool isAllocation(Operation *op) {
+  return op && (isa<memref::AllocOp>(op) ||
+                isa<IREE::HAL::InterfaceBindingSubspanOp>(op));
+}
+
+/// Initialize offsets, sizes, and strides from an allocation operation.
 LogicalResult setFromAlloc(Operation *op, SmallVector<OpFoldResult> &offsets,
                            SmallVector<OpFoldResult> &sizes,
                            SmallVector<OpFoldResult> &strides) {
-  assert(isa<memref::AllocOp>(op) ||
-         isa<IREE::HAL::InterfaceBindingSubspanOp>(op) &&
-             "op must be a memref.alloc or hal.interface.binding.subspan "
-             "operation");
+  assert(isAllocation(op) &&
+         "expected memref.alloc or hal.interface.binding.subspan ");
   MemRefType memRefType = cast<MemRefType>(op->getResult(0).getType());
   MLIRContext *ctx = memRefType.getContext();
   auto [stridesI64, baseOffset] = getStridesAndOffset(memRefType);
+  strides = getAsIndexOpFoldResult(ctx, stridesI64);
+
   if (baseOffset != 0) {
     auto message = llvm::formatv(
-        "has non-zero base offset {0} is not currently supported.", baseOffset);
+        "has non-zero offset {0} which is currently unsupported.", baseOffset);
     return op->emitOpError(message);
   }
-  strides = getAsIndexOpFoldResult(ctx, stridesI64);
+  offsets.resize(strides.size(), getAsIndexOpFoldResult(ctx, 0));
+
   ArrayRef<int64_t> sizesI64 = memRefType.getShape();
   if (llvm::any_of(sizesI64, [](int64_t size) {
         return ShapedType::isDynamic(size);
       })) {
-    return op->emitOpError("has dynamic size, which is not supported in dma.");
+    return op->emitOpError("has dynamic size, which is unsupported in DMA.");
   }
   sizes = getAsIndexOpFoldResult(ctx, sizesI64);
-
-  // Alloc Op has no offsets.
-  for (int i = 0; i < sizes.size(); i++) {
-    offsets.push_back(getAsIndexOpFoldResult(ctx, 0));
-  }
   return success();
 }

 /// Return a + b.
-SmallVector<OpFoldResult> getSum(OpBuilder &builder, Location loc,
-                                 ArrayRef<OpFoldResult> a,
-                                 ArrayRef<OpFoldResult> b) {
-  assert(a.size() == b.size() && "a and b not same size");
+SmallVector<OpFoldResult> getIndexOpFoldResultSum(OpBuilder &builder,
+                                                  Location loc,
+                                                  ArrayRef<OpFoldResult> lhs,
+                                                  ArrayRef<OpFoldResult> rhs) {
+  assert(lhs.size() == rhs.size() && "a and b not same size");
   SmallVector<OpFoldResult> sum;
-  sum.reserve(a.size());
+  sum.reserve(lhs.size());

   auto getConstant = [&](int64_t v) -> Value {
     return builder.create<arith::ConstantOp>(
@@ -206,15 +208,15 @@ SmallVector<OpFoldResult> getSum(OpBuilder &builder, Location loc,
         .getResult();
   };

-  for (uint64_t i = 0; i < a.size(); ++i) {
+  for (uint64_t i = 0; i < lhs.size(); ++i) {
     IntegerAttr aAttr;
-    if (auto aAttr_ = dyn_cast<Attribute>(a[i])) {
+    if (auto aAttr_ = dyn_cast<Attribute>(lhs[i])) {
      aAttr = dyn_cast<IntegerAttr>(aAttr_);
       assert(aAttr && "Expected an IntegerAttr");
     }

     IntegerAttr bAttr;
-    if (auto bAttr_ = dyn_cast<Attribute>(b[i])) {
+    if (auto bAttr_ = dyn_cast<Attribute>(rhs[i])) {
       bAttr = dyn_cast<IntegerAttr>(bAttr_);
       assert(bAttr && "Expected an IntegerAttr");
     }
@@ -223,20 +225,19 @@ SmallVector<OpFoldResult> getSum(OpBuilder &builder, Location loc,
     if (aAttr && bAttr) {
       sum.push_back(getAsIndexOpFoldResult(builder.getContext(),
                                            aAttr.getInt() + bAttr.getInt()));
     } else if (!aAttr && !bAttr) {
-      sum.push_back(
-          builder
-              .create<arith::AddIOp>(loc, cast<Value>(a[i]), cast<Value>(b[i]))
-              .getResult());
+      sum.push_back(builder
+                        .create<arith::AddIOp>(loc, cast<Value>(lhs[i]),
+                                               cast<Value>(rhs[i]))
+                        .getResult());
     } else if (!aAttr && bAttr) {
-      sum.push_back(add(cast<Value>(a[i]), bAttr));
+      sum.push_back(add(cast<Value>(lhs[i]), bAttr));
     } else if (aAttr && !bAttr) {
-      sum.push_back(add(cast<Value>(b[i]), aAttr));
+      sum.push_back(add(cast<Value>(rhs[i]), aAttr));
     } else {
       assert(false && "unreachable");
     }
   }

-  assert(sum.size() == a.size() && "sum and a not same size");
   return sum;
 }

@@ -296,13 +297,15 @@ OpFoldResult getLinearCombination(OpBuilder &builder, Location loc,
   return combination;
 }

+
+/// Update the offsets, sizes, and strides from a collapse shape operation.
 LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
                                       SmallVector<OpFoldResult> &offsets,
                                       SmallVector<OpFoldResult> &sizes,
                                       SmallVector<OpFoldResult> &strides) {
   auto reassociationIndices = collapseOp.getReassociationIndices();
-  ArrayRef<int64_t> resultShape = collapseOp.getType().getShape();
   ArrayRef<int64_t> inputShape = collapseOp.getSrcType().getShape();
+  ArrayRef<int64_t> resultShape = collapseOp.getType().getShape();
   uint64_t resultRank = resultShape.size();
   MLIRContext *ctx = collapseOp.getContext();

@@ -314,14 +317,14 @@ LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
   }
   strides.resize(resultRank);

-  // Set sizes to output shape, ensuring that all dims are static.
+  // Set sizes to output shape, and check that all dims are static.
   sizes.clear();
   for (int64_t dim : resultShape) {
     if (dim == ShapedType::kDynamic) {
       return collapseOp.emitOpError(
           "has a dynamic shape which is currently unsupported.");
     }
-    sizes.push_back(getAsIndexOpFoldResult(collapseOp.getContext(), dim));
+    sizes.push_back(getAsIndexOpFoldResult(ctx, dim));
   }

   // Offsets - merge reassocation groups.
@@ -348,6 +351,7 @@ LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
   return success();
 }

+/// Update the offsets, sizes, and strides from an expand shape operation.
 LogicalResult updateFromExpandShape(memref::ExpandShapeOp expandShapeOp,
                                     SmallVector<OpFoldResult> &offsets,
                                     SmallVector<OpFoldResult> &sizes,
@@ -375,9 +379,10 @@ LogicalResult updateFromExpandShape(memref::ExpandShapeOp expandShapeOp,
     }
   }

-  // Offsets. For now we don't do any modular arithmetic, in theory we need to
-  // split the offset amongst the reassociation indices, but for now I'm just
-  // putting the offset on the inner most dimension.
+  // Offsets. For now we don't do any arithmetic to split the offset across
+  // dimensions. In theory we need to split the offset amongst the
+  // reassociation indices, but for now I'm just putting the offset on the
+  // innermost dimension.
   SmallVector<OpFoldResult> newOffsets(resultShape.size());
   for (int i = 0; i < resultShape.size(); i++) {
     newOffsets[i] = getAsIndexOpFoldResult(ctx, 0);
   }
@@ -394,6 +399,8 @@ LogicalResult updateFromExpandShape(memref::ExpandShapeOp expandShapeOp,
   return success();
 }

+
+/// Update the offsets, sizes, and strides from a subview operation.
 LogicalResult updateFromSubView(memref::SubViewOp subviewOp,
                                 SmallVector<OpFoldResult> &offsets,
                                 SmallVector<OpFoldResult> &sizes,
                                 SmallVector<OpFoldResult> &strides) {
   OpBuilder builder(subviewOp.getContext());
   builder.setInsertionPoint(subviewOp);
-  offsets =
-      getSum(builder, subviewOp.getLoc(), offsets, subviewOp.getMixedOffsets());
+  offsets = getIndexOpFoldResultSum(builder, subviewOp.getLoc(), offsets,
+                                    subviewOp.getMixedOffsets());
   sizes = subviewOp.getMixedSizes();
   if (llvm::any_of(sizes, [](OpFoldResult size) {
@@ -435,7 +442,7 @@ LogicalResult updateFromSubView(memref::SubViewOp subviewOp,
 }

 /// Provide the offsets, sizes and strides of the inputs to `operandOp`.
-/// This function update `operandOp`, setting it to the allocation operation
+/// This function updates `operandOp`, setting it to the allocation operation
 /// that it originates from.
 LogicalResult setDmaInputs(Operation *&operandOp,
                            SmallVector<OpFoldResult> &offsets,
@@ -446,11 +453,6 @@ LogicalResult setDmaInputs(Operation *&operandOp,
   if (!operandOp)
     assert(false && "operandOp must be non-null");

-  auto isAllocation = [](Operation *op) {
-    return op && (isa<memref::AllocOp>(op) ||
-                  isa<IREE::HAL::InterfaceBindingSubspanOp>(op));
-  };
-
   // Get the sequence of memref operations going from an allocation to
   // `operandOp`
   SmallVector<Operation *> chain;