From 86ed9276e628d0b97b72766d1f0369a4d7b18db1 Mon Sep 17 00:00:00 2001
From: James Newling
Date: Thu, 26 Sep 2024 21:53:56 -0700
Subject: [PATCH] address review comments

---
 .../Transforms/AMDAIEConvertToDma.cpp         | 151 ++++++++++++------
 .../Transforms/test/convert_to_dma.mlir       |  26 +--
 2 files changed, 112 insertions(+), 65 deletions(-)

diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEConvertToDma.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEConvertToDma.cpp
index 4a93ee31e..29ca0cc6d 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEConvertToDma.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEConvertToDma.cpp
@@ -39,7 +39,7 @@ LogicalResult updateFromPack(IREE::LinalgExt::PackOp packOp,
   ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
 
   assert(offsets.size() == sizes.size() && sizes.size() == strides.size() &&
-         "offsets, sizes, and strides must have the same size,");
+         "offsets, sizes, and strides must have the same size.");
   for (int64_t dim : innerDimsPos) {
     assert(dim < sizes.size() && "innerDimsPos must be within sizes.");
   }
@@ -53,7 +53,9 @@ LogicalResult updateFromPack(IREE::LinalgExt::PackOp packOp,
     innerSizes.push_back(getAsIndexOpFoldResult(ctx, innerTiles[i]));
     std::optional<int64_t> maybeSize =
         getConstantIntValue(sizes[innerDimsPos[i]]);
-    assert(maybeSize.has_value() && "size expected to be constant here.");
+    if (!maybeSize.has_value()) {
+      return packOp->emitOpError("requires all constant sizes.");
+    }
     int64_t size = maybeSize.value();
     if (size % innerTiles[i] != 0) {
       auto message = llvm::formatv(
@@ -67,12 +69,13 @@
     // The tiled dim inherits the stride from the corresponding outer dim and
     // the outer dims stride gets multiplied by the size of the tile.
     innerStrides.push_back(strides[innerDimsPos[i]]);
-    std::optional<int64_t> stride =
+    std::optional<int64_t> maybeStride =
         getConstantIntValue(strides[innerDimsPos[i]]);
-    if (!stride.has_value())
-      packOp->emitOpError("requires a constant stride here.");
+    if (!maybeStride.has_value())
+      return packOp->emitOpError("requires a constant stride here.");
+    int64_t stride = maybeStride.value();
     strides[innerDimsPos[i]] =
-        getAsIndexOpFoldResult(ctx, stride.value() * innerTiles[i]);
+        getAsIndexOpFoldResult(ctx, stride * innerTiles[i]);
 
     // The tiled dim inherits the offset from the corresponding outer dim and
     // the outer dim offset is set to zero.
@@ -164,7 +167,7 @@ LogicalResult setFromAlloc(Operation *op, SmallVector<OpFoldResult> &offsets,
                            SmallVector<OpFoldResult> &sizes,
                            SmallVector<OpFoldResult> &strides) {
   assert(isAllocation(op) &&
-         "expected memref.alloc or hal.interface.binding.subspan ");
+         "expected memref.alloc or hal.interface.binding.subspan.");
 
   MemRefType memRefType = cast<MemRefType>(op->getResult(0).getType());
   MLIRContext *ctx = memRefType.getContext();
@@ -187,12 +190,15 @@ LogicalResult setFromAlloc(Operation *op, SmallVector<OpFoldResult> &offsets,
   return success();
 }
 
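The arithmetic in updateFromPack is easy to sanity-check with plain integers. Below is a small, MLIR-free C++ sketch of the size/stride update for a single pack; the 100x100 shape and 10x10 tiles are hypothetical, and it assumes (as the divisibility check above suggests) that each outer size is divided by its tile size.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical 100x100 source with row-major strides, packed with 10x10
  // tiles on both dimensions.
  std::vector<int64_t> sizes = {100, 100};
  std::vector<int64_t> strides = {100, 1};
  std::vector<int64_t> innerTiles = {10, 10};
  std::vector<int64_t> innerDimsPos = {0, 1};

  std::vector<int64_t> innerSizes;
  std::vector<int64_t> innerStrides;
  for (size_t i = 0; i < innerTiles.size(); ++i) {
    int64_t dim = innerDimsPos[i];
    assert(sizes[dim] % innerTiles[i] == 0 && "tile must divide size");
    // The tiled dim gets the tile size; the outer dim shrinks by that factor.
    innerSizes.push_back(innerTiles[i]);
    sizes[dim] /= innerTiles[i];
    // The tiled dim inherits the outer stride; the outer stride is
    // multiplied by the tile size.
    innerStrides.push_back(strides[dim]);
    strides[dim] *= innerTiles[i];
  }
  // Outer dims now: sizes [10, 10], strides [1000, 10].
  // Inner dims: sizes [10, 10], strides [100, 1].
  std::cout << sizes[0] << 'x' << sizes[1] << ", strides " << strides[0]
            << ',' << strides[1] << '\n';
}
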
-/// Return a + b.
+/// Return `lhs` + `rhs`, where `lhs` and `rhs` are OpFoldResults of integers.
+///
+/// The implementation considers the 4 cases of
+/// (`lhs`, `rhs`) in {Attribute, Value} x {Attribute, Value}.
 SmallVector<OpFoldResult> getIndexOpFoldResultSum(OpBuilder &builder,
                                                   Location loc,
                                                   ArrayRef<OpFoldResult> lhs,
                                                   ArrayRef<OpFoldResult> rhs) {
-  assert(lhs.size() == rhs.size() && "a and b not same size");
+  assert(lhs.size() == rhs.size() && "lhs and rhs must have the same size.");
   SmallVector<OpFoldResult> sum;
   sum.reserve(lhs.size());
 
@@ -212,13 +218,13 @@ SmallVector<OpFoldResult> getIndexOpFoldResultSum(OpBuilder &builder,
     IntegerAttr aAttr;
     if (auto aAttr_ = dyn_cast<Attribute>(lhs[i])) {
       aAttr = dyn_cast<IntegerAttr>(aAttr_);
-      assert(aAttr && "Expected an IntegerAttr");
+      assert(aAttr && "Expected IntegerAttr.");
     }
     IntegerAttr bAttr;
     if (auto bAttr_ = dyn_cast<Attribute>(rhs[i])) {
       bAttr = dyn_cast<IntegerAttr>(bAttr_);
-      assert(bAttr && "Expected an IntegerAttr");
+      assert(bAttr && "Expected IntegerAttr.");
     }
 
     if (aAttr && bAttr) {
@@ -241,7 +247,11 @@ SmallVector<OpFoldResult> getIndexOpFoldResultSum(OpBuilder &builder,
   return sum;
 }
 
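The case analysis in getIndexOpFoldResultSum can be illustrated in isolation with a stand-alone C++ sketch. Here FoldResult is a hypothetical stand-in for OpFoldResult: an int64_t plays the role of a constant IntegerAttr, a string plays the role of a dynamic Value, and "emitting an arith.addi" is modeled by building a string.

#include <cstdint>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for OpFoldResult: constant or named dynamic value.
using FoldResult = std::variant<int64_t, std::string>;

FoldResult add(const FoldResult &a, const FoldResult &b) {
  // Constant + constant folds to a constant, mirroring the IntegerAttr case.
  if (std::holds_alternative<int64_t>(a) && std::holds_alternative<int64_t>(b))
    return std::get<int64_t>(a) + std::get<int64_t>(b);
  // Otherwise "emit" an addi on the two operands.
  auto str = [](const FoldResult &r) {
    return std::holds_alternative<int64_t>(r)
               ? std::to_string(std::get<int64_t>(r))
               : std::get<std::string>(r);
  };
  return "addi(" + str(a) + ", " + str(b) + ")";
}

int main() {
  std::vector<FoldResult> lhs = {int64_t{2}, std::string("%arg0")};
  std::vector<FoldResult> rhs = {int64_t{3}, int64_t{5}};
  // Prints: 5, then addi(%arg0, 5).
  for (size_t i = 0; i < lhs.size(); ++i) {
    FoldResult s = add(lhs[i], rhs[i]);
    if (const int64_t *c = std::get_if<int64_t>(&s))
      std::cout << *c << '\n';
    else
      std::cout << std::get<std::string>(s) << '\n';
  }
}
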
-/// Return sum_{i} values[i] * coeffs[i].
+/// Return sum_{i} values[i] * coeffs[i], where
+///
+/// - values are OpFoldResults (i.e. each element in `values` is
+///   either an mlir::Value or an mlir::Attribute)
+/// - coeffs are integers.
 OpFoldResult getLinearCombination(OpBuilder &builder, Location loc,
                                   ArrayRef<OpFoldResult> values,
                                   ArrayRef<int64_t> coeffs) {
@@ -255,49 +265,57 @@ OpFoldResult getLinearCombination(OpBuilder &builder, Location loc,
   };
 
   // Initialize the linear combination to 0.
-  OpFoldResult combination = builder.getIndexAttr(0);
+  OpFoldResult lc = builder.getIndexAttr(0);
+  // For each of the (value, coeff) pairs, add the product to the linear
+  // combination, updating `lc` in each iteration. The implementation is
+  // careful not to create constant zero values.
   for (uint64_t dim = 0; dim < coeffs.size(); ++dim) {
+    // Four cases are considered:
+    // 1) `values[dim]` is an attribute (constant) and `lc` is also an
+    //    attribute (constant)
+    // 2) `values[dim]` is an attribute (constant) and `lc` is a Value
+    //    (non-constant)
+    // 3) `values[dim]` is a Value (non-constant) and `lc` is an attribute
+    //    (constant)
+    // 4) `values[dim]` is a Value (non-constant) and `lc` is also a Value
+    //    (non-constant)
     if (auto valueAttr = dyn_cast<Attribute>(values[dim])) {
       int64_t term = coeffs[dim] * cast<IntegerAttr>(valueAttr).getInt();
-
-      // Case where both `combination` and `value` are constant.
-      if (auto combinationAttr = dyn_cast<Attribute>(combination)) {
-        combination = getAsIndexOpFoldResult(
-            ctx, term + cast<IntegerAttr>(combinationAttr).getInt());
+      // Case 1.
+      if (auto lcAttr = dyn_cast<Attribute>(lc)) {
+        lc = getAsIndexOpFoldResult(ctx,
+                                    term + cast<IntegerAttr>(lcAttr).getInt());
       }
-
-      // Case where `combination` is not constant, `value` is constant.
+      // Case 2.
       else if (term != 0) {
-        combination = builder
-                          .create<arith::AddIOp>(loc, cast<Value>(combination),
-                                                 getConstant(term))
-                          .getResult();
+        lc = builder
+                 .create<arith::AddIOp>(loc, cast<Value>(lc), getConstant(term))
+                 .getResult();
      }
     } else {
       Value term = builder.create<arith::MulIOp>(loc, cast<Value>(values[dim]),
                                                  getConstant(coeffs[dim]));
-      // Case where `combination` is constant, `value` is not constant.
-      if (auto combinationAttr = dyn_cast<Attribute>(combination)) {
-        int64_t c = cast<IntegerAttr>(combinationAttr).getInt();
+      // Case 3.
+      if (auto lcAttr = dyn_cast<Attribute>(lc)) {
+        int64_t c = cast<IntegerAttr>(lcAttr).getInt();
         if (c != 0) {
-          combination = builder.create<arith::AddIOp>(loc, getConstant(c), term)
-                            .getResult();
+          lc = builder.create<arith::AddIOp>(loc, getConstant(c), term)
+                   .getResult();
         } else {
-          combination = term;
+          lc = term;
         }
-      } else {
-        // Case where neither `combination` nor `value` is constant.
-        Value combinationVal = cast<Value>(combination);
-        combination = builder.create<arith::AddIOp>(loc, combinationVal, term)
-                          .getResult();
+      }
+      // Case 4.
+      else {
+        lc = builder.create<arith::AddIOp>(loc, cast<Value>(lc), term)
+                 .getResult();
       }
     }
   }
-  return combination;
+  return lc;
 }
-
 /// Update the offsets, sizes, and strides from a collapse shape operation.
 LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
                                       SmallVector<OpFoldResult> &offsets,
@@ -310,6 +328,11 @@ LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
   MLIRContext *ctx = collapseOp.getContext();
 
   // Set strides to inner-most stride in each reassocation group.
+  //
+  // Example: Consider a 2x3x5x7 tensor with strides [70,35,7,1]. If this
+  // is collapsed to shape 6x35, the strides are [35,1]. The reassociation
+  // groups are [0,1] and [2,3], and so we have just taken the inner-most
+  // stride in each group.
   for (auto reassociation : llvm::enumerate(reassociationIndices)) {
     uint64_t index = reassociation.index();
     uint64_t dim = reassociation.value().back();
@@ -327,15 +350,21 @@ LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
     sizes.push_back(getAsIndexOpFoldResult(ctx, dim));
   }
 
-  // Offsets - merge reassocation groups.
+  // Offsets - merge reassociation groups by taking linear combinations of the
+  // offsets with local strides. Using the example of the shape 2x3x5x7 being
+  // collapsed to 6x35: if the initial offsets are [a,b,c,d], the collapsed
+  // offsets are [a*3 + b, c*7 + d].
   SmallVector<OpFoldResult> collapsedOffsets;
   for (auto reassociation : llvm::enumerate(reassociationIndices)) {
     auto dims = reassociation.value();
+
+    // The strides within the group:
     SmallVector<int64_t> localStrides(dims.size(), 1);
     for (uint64_t i = 1; i < dims.size(); ++i) {
       uint64_t dim = dims.size() - i - 1;
       localStrides[dim] = localStrides[dim + 1] * inputShape[dims[dim + 1]];
     }
+
     OpBuilder builder(ctx);
     builder.setInsertionPoint(collapseOp);
     OpFoldResult combination = getLinearCombination(
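The collapse-shape offset rule can be checked numerically. The following sketch reproduces the local-stride computation and the linear combination with plain integers, for the 2x3x5x7 to 6x35 example used in the comments above; the offsets [1, 2, 3, 4] are hypothetical.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Collapse a 2x3x5x7 view to 6x35: offsets [a,b,c,d] -> [a*3+b, c*7+d].
  std::vector<int64_t> inputShape = {2, 3, 5, 7};
  std::vector<int64_t> offsets = {1, 2, 3, 4};
  std::vector<std::vector<int64_t>> groups = {{0, 1}, {2, 3}};

  std::vector<int64_t> collapsedOffsets;
  for (const auto &dims : groups) {
    // Strides local to the group; the innermost dim of the group has stride 1.
    std::vector<int64_t> localStrides(dims.size(), 1);
    for (size_t i = 1; i < dims.size(); ++i) {
      size_t d = dims.size() - i - 1;
      localStrides[d] = localStrides[d + 1] * inputShape[dims[d + 1]];
    }
    // Linear combination of the group's offsets with the local strides.
    int64_t combined = 0;
    for (size_t d = 0; d < dims.size(); ++d)
      combined += offsets[dims[d]] * localStrides[d];
    collapsedOffsets.push_back(combined);
  }
  // Prints 5 (= 1*3 + 2) and 25 (= 3*7 + 4).
  for (int64_t o : collapsedOffsets) std::cout << o << '\n';
}
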
@@ -360,20 +389,29 @@ LogicalResult updateFromExpandShape(memref::ExpandShapeOp expandShapeOp,
   auto reassociationIndices = expandShapeOp.getReassociationIndices();
   ArrayRef<int64_t> resultShape = expandShapeOp.getType().getShape();
 
-  // Sizes.
+  // Set the sizes to the output shape, and check that all dims are static.
   SmallVector<OpFoldResult> newSizes(resultShape.size());
   for (int i = 0; i < resultShape.size(); i++) {
+    if (resultShape[i] == ShapedType::kDynamic) {
+      return expandShapeOp.emitOpError(
+          "has a dynamic shape which is currently unsupported.");
+    }
     newSizes[i] = getAsIndexOpFoldResult(ctx, resultShape[i]);
   }
 
-  // Strides.
+  // Strides. Using the example of expanding from a shape of 6x35 to 2x3x5x7,
+  // where the initial strides are [50, 1], the new strides will be
+  // [150, 50, 7, 1].
   SmallVector<OpFoldResult> newStrides(resultShape.size());
   for (auto reassociation : llvm::enumerate(reassociationIndices)) {
-    auto index = reassociation.index();
+    uint64_t index = reassociation.index();
     auto dims = reassociation.value();
-    int64_t cum = getConstantIntValue(strides[index]).value();
+    OpFoldResult stride = strides[index];
+    if (!isa<Attribute>(stride)) {
+      return expandShapeOp.emitOpError("cannot operate on a dynamic stride.");
+    }
+    int64_t cum = getConstantIntValue(stride).value();
     for (uint64_t i = 0; i < dims.size(); i++) {
-      auto d = dims[dims.size() - i - 1];
+      uint64_t d = dims[dims.size() - i - 1];
       newStrides[d] = getAsIndexOpFoldResult(ctx, cum);
       cum *= resultShape[d];
     }
   }
 
   // Offsets. For now we don't do any arithmetic to split the offset across
   // dimensions, in theory we need to split the offset amongst the reassociation
-  // indices, but for now I'm just putting the offset on the inner most
+  // indices, but for now we're just putting the offset on the inner-most
   // dimension.
+  //
+  // Example: suppose we're expanding from 6x35 to 2x3x5x7, and the initial
+  // offsets are [a, b]. The new offsets will be [0, a, 0, b]. In theory they
+  // should be [a/3, a%3, b/7, b%7], but these offsets ultimately get collapsed
+  // anyway, so it doesn't matter that we don't split them.
   SmallVector<OpFoldResult> newOffsets(resultShape.size());
+  // Initialize all offsets to 0:
   for (int i = 0; i < resultShape.size(); i++) {
     newOffsets[i] = getAsIndexOpFoldResult(ctx, 0);
   }
+  // Populate the inner-most dimensions with the original offsets:
   for (auto reassociation : llvm::enumerate(reassociationIndices)) {
-    auto index = reassociation.index();
+    uint64_t index = reassociation.index();
     auto dims = reassociation.value();
     newOffsets[dims.back()] = offsets[index];
   }
@@ -433,6 +478,16 @@ LogicalResult updateFromSubView(memref::SubViewOp subviewOp,
       sizes[insertionIndex] = sizes[extractionIndex];
       strides[insertionIndex] = strides[extractionIndex];
       insertionIndex++;
+    } else {
+      // TODO(newling) add a test of this path.
+      // If the offset is non-zero, we shouldn't just be dropping it. For now,
+      // just bail.
+      OpFoldResult offset = offsets[extractionIndex];
+      if (isa<Value>(offset) || getConstantIntValue(offset).value() != 0) {
+        return subviewOp->emitOpError(
+            "cannot update a non-zero offset in a dimension that is being "
+            "dropped.");
+      }
     }
   }
   offsets.resize(insertionIndex);
@@ -589,19 +644,11 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *packOrUnackOp,
   auto dst = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
       rewriter.getUnknownLoc(), LogicalObjectFifoType::get(dstType), dstVal);
 
-  if (failed(mlir::verify(src)) || failed(mlir::verify(dst))) {
-    return failure();
-  }
-
   rewriter.setInsertionPoint(packOrUnackOp);
   rewriter.create<AMDAIE::DmaCpyNdOp>(packOrUnackOp->getLoc(), dst, dstOffsets,
                                       dstShape, dstBaseStrides, src,
                                       srcOffsets, srcShape, srcBaseStrides);
 
-  if (failed(mlir::verify(packOrUnackOp))) {
-    return failure();
-  }
-
   rewriter.eraseOp(packOrUnackOp);
   return success();
 }
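The expand-shape stride rule from updateFromExpandShape can likewise be checked with plain integers. This sketch walks each reassociation group from innermost to outermost, reproducing the 6x35 to 2x3x5x7 example from the comments, where the initial strides are [50, 1].

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Expand 6x35 (strides [50, 1]) to 2x3x5x7; expected strides [150, 50, 7, 1].
  // Stride 50 (not 35) models a source view that is itself strided.
  std::vector<int64_t> strides = {50, 1};
  std::vector<int64_t> resultShape = {2, 3, 5, 7};
  std::vector<std::vector<int64_t>> groups = {{0, 1}, {2, 3}};

  std::vector<int64_t> newStrides(resultShape.size());
  for (size_t index = 0; index < groups.size(); ++index) {
    const auto &dims = groups[index];
    // Walk the group from innermost to outermost, accumulating the stride.
    int64_t cum = strides[index];
    for (size_t i = 0; i < dims.size(); ++i) {
      size_t d = dims[dims.size() - i - 1];
      newStrides[d] = cum;
      cum *= resultShape[d];
    }
  }
  // Prints 150 50 7 1.
  for (int64_t s : newStrides) std::cout << s << ' ';
  std::cout << '\n';
}
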
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/convert_to_dma.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/convert_to_dma.mlir
index 487c5714a..65c04b0cc 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/convert_to_dma.mlir
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/convert_to_dma.mlir
@@ -359,13 +359,13 @@ func.func @unitdim_unpack_expand() {
 
 // -----
 
-// CHECK-LABEL: multidim_with_expand
+// CHECK-LABEL: multidim_pack_with_expand
 // CHECK: amdaie.dma_cpy_nd
 // dst of dma cpy:
 // CHECK-SAME: [0, 0, 0, 0] [20, 5, 10, 10] [500, 100, 10, 1]
 // src of dma cpy:
 // CHECK-SAME: [0, 0, 0, 0] [20, 5, 10, 10] [500, 10, 50, 1]
-func.func @multidim_with_expand() {
+func.func @multidim_pack_with_expand() {
   %src = memref.alloc() : memref<200x50xi32, 1>
   %dst = memref.alloc() : memref<100x100xi32, 2>
   %dst_e = memref.expand_shape %dst [[0, 1], [2, 3]] output_shape [20, 5, 10, 10]
@@ -378,14 +378,14 @@
 // This test is included to illustrate that the dma copy is the same without the
-// expand operation (compare to multidim_with_expand above).
-// CHECK-LABEL: multidim_without_expand
+// expand operation (compare to multidim_pack_with_expand above).
+// CHECK-LABEL: multidim_pack_without_expand
 // CHECK: amdaie.dma_cpy_nd
 // dst of dma cpy:
 // CHECK-SAME: [0, 0, 0, 0] [20, 5, 10, 10] [500, 100, 10, 1]
 // src of dma cpy:
 // CHECK-SAME: [0, 0, 0, 0] [20, 5, 10, 10] [500, 10, 50, 1]
-func.func @multidim_without_expand() {
+func.func @multidim_pack_without_expand() {
   %src = memref.alloc() : memref<200x50xi32, 1>
   %dst = memref.alloc() : memref<20x5x10x10xi32, 2>
   iree_linalg_ext.pack %src inner_dims_pos = [0, 1] inner_tiles = [10, 10]
@@ -395,14 +395,14 @@
 
 // -----
 
-// CHECK-LABEL: @subview_then_collapse(%arg0: index)
+// CHECK-LABEL: @pack_subview_then_collapse(%arg0: index)
 // CHECK: %[[ALLOC0:.*]] = memref.alloc() : memref<20x10xf32>
 // CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK: %[[MULI:.*]] = arith.muli %arg0, %[[C10]] : index
 // CHECK: amdaie.dma_cpy_nd
 // CHECK-SAME: [0, 0] [5, 20] [20, 1]
 // CHECK-SAME: [0, %[[MULI]]] [5, 20] [20, 1]
-func.func @subview_then_collapse(%arg0 : index) {
+func.func @pack_subview_then_collapse(%arg0 : index) {
   %src = memref.alloc() : memref<20x10xf32>
   %subview = memref.subview %src[%arg0, 0] [10, 10] [1, 1] :
       memref<20x10xf32> to memref<10x10xf32, strided<[10, 1], offset: ?>>
 
 // -----
 
@@ -416,12 +416,12 @@
-// CHECK-LABEL: @subview_then_expand
+// CHECK-LABEL: @pack_subview_then_expand
 // CHECK: amdaie.dma_cpy_nd
 // CHECK-SAME: [0, 0, %arg0, 0, 0] [2, 3, 6, 6, 1] [300, 100, 10, 1, 1]
 // CHECK-SAME: [0, 0, 0, 0, 0] [2, 3, 6, 6, 1] [108, 6, 1, 18, 6]
 module {
-  func.func @subview_then_expand(%arg0: index) {
+  func.func @pack_subview_then_expand(%arg0: index) {
     %alloc = memref.alloc() : memref<10x10x10xf32>
     %subview = memref.subview %alloc[0, %arg0, 0] [6, 6, 6] [1, 1, 1] :
         memref<10x10x10xf32> to memref<6x6x6xf32, strided<[100, 10, 1], offset: ?>>
 
 // -----
 
@@ -437,12 +437,12 @@
-// CHECK-LABEL: @subview_then_subview(%arg0: index, %arg1: index)
+// CHECK-LABEL: @unpack_subview_then_subview(%arg0: index, %arg1: index)
 // CHECK: %[[SUM:.*]] = arith.addi %arg0, %arg1 : index
 // CHECK: amdaie.dma_cpy_nd
 // CHECK-SAME: [0] [100] [1],
 // CHECK-SAME: [%[[SUM]], 5] [10, 10] [20, 1]
-func.func @subview_then_subview(%arg0 : index, %arg1 : index){
+func.func @unpack_subview_then_subview(%arg0 : index, %arg1 : index){
   %src = memref.alloc() : memref<20x20xf32>
   %subview0 = memref.subview %src[%arg0, 2] [15, 15] [1, 1] :
       memref<20x20xf32> to memref<15x15xf32, strided<[20, 1], offset: ?>>
 
 // -----
 
@@ -456,7 +456,7 @@
-// CHECK-LABEL: @subview_then_expand_1(%arg0: index)
+// CHECK-LABEL: @unpack_subview_then_expand_1(%arg0: index)
 // CHECK: amdaie.dma_cpy_nd
 // CHECK-SAME: [0, 0] [25, 4] [4, 1]
 // We might want to change the offsets to be
 // in the future, but as the offsets ultimately get collapsed into a single
 // global cumulative offset, this would just be undone.
 // CHECK-SAME: [0, 0, %arg0, 2] [5, 5, 2, 2] [40, 2, 20, 1]
-func.func @subview_then_expand_1(%arg0 : index){
+func.func @unpack_subview_then_expand_1(%arg0 : index){
   %src = memref.alloc() : memref<20x20xf32>
   %subview = memref.subview %src[%arg0, 2] [10, 10] [1, 1] :
       memref<20x20xf32> to memref<10x10xf32, strided<[20, 1], offset: ?>>