Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Sep 27, 2024
1 parent 085d517 commit 86ed927
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ LogicalResult updateFromPack(IREE::LinalgExt::PackOp packOp,
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();

assert(offsets.size() == sizes.size() && sizes.size() == strides.size() &&
"offsets, sizes, and strides must have the same size,");
"offsets, sizes, and strides must have the same size.");
for (int64_t dim : innerDimsPos) {
assert(dim < sizes.size() && "innerDimsPos must be within sizes.");
}
Expand All @@ -53,7 +53,9 @@ LogicalResult updateFromPack(IREE::LinalgExt::PackOp packOp,
innerSizes.push_back(getAsIndexOpFoldResult(ctx, innerTiles[i]));
std::optional<int64_t> maybeSize =
getConstantIntValue(sizes[innerDimsPos[i]]);
assert(maybeSize.has_value() && "size expected to be constant here.");
if (!maybeSize.has_value()) {
packOp->emitOpError("requires all constant sizes.");
}
int64_t size = maybeSize.value();
if (size % innerTiles[i] != 0) {
auto message = llvm::formatv(
Expand All @@ -67,12 +69,13 @@ LogicalResult updateFromPack(IREE::LinalgExt::PackOp packOp,
// The tiled dim inherits the stride from the corresponding outer dim and
// the outer dims stride gets multiplied by the size of the tile.
innerStrides.push_back(strides[innerDimsPos[i]]);
std::optional<int64_t> stride =
std::optional<int64_t> maybeStride =
getConstantIntValue(strides[innerDimsPos[i]]);
if (!stride.has_value())
if (!maybeStride.has_value())
packOp->emitOpError("requires a constant stride here.");
int64_t stride = maybeStride.value();
strides[innerDimsPos[i]] =
getAsIndexOpFoldResult(ctx, stride.value() * innerTiles[i]);
getAsIndexOpFoldResult(ctx, stride * innerTiles[i]);

// The tiled dim inherits the offset from the corresponding outer dim and
// the outer dim offset is set to zero.
Expand Down Expand Up @@ -164,7 +167,7 @@ LogicalResult setFromAlloc(Operation *op, SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
assert(isAllocation(op) &&
"expected memref.alloc or hal.interface.binding.subspan ");
"expected memref.alloc or hal.interface.binding.subspan.");

MemRefType memRefType = cast<MemRefType>(op->getResult(0).getType());
MLIRContext *ctx = memRefType.getContext();
Expand All @@ -187,12 +190,15 @@ LogicalResult setFromAlloc(Operation *op, SmallVector<OpFoldResult> &offsets,
return success();
}

/// Return a + b.
/// Return `lhs` + `rhs`, where `lhs` and `rhs` are OpFoldResults of integers.
///
/// The implementation considers the 4 cases of
/// (`lhs`, `rhs`) in {Attribute, Value} x {Attribute, Value}.
SmallVector<OpFoldResult> getIndexOpFoldResultSum(OpBuilder &builder,
Location loc,
ArrayRef<OpFoldResult> lhs,
ArrayRef<OpFoldResult> rhs) {
assert(lhs.size() == rhs.size() && "a and b not same size");
assert(lhs.size() == rhs.size() && "lhs and rhs not same size.");
SmallVector<OpFoldResult> sum;
sum.reserve(lhs.size());

Expand All @@ -212,13 +218,13 @@ SmallVector<OpFoldResult> getIndexOpFoldResultSum(OpBuilder &builder,
IntegerAttr aAttr;
if (auto aAttr_ = dyn_cast<Attribute>(lhs[i])) {
aAttr = dyn_cast<IntegerAttr>(aAttr_);
assert(aAttr && "Expected an IntegerAttr");
assert(aAttr && "Expected IntegerAttr.");
}

IntegerAttr bAttr;
if (auto bAttr_ = dyn_cast<Attribute>(rhs[i])) {
bAttr = dyn_cast<IntegerAttr>(bAttr_);
assert(bAttr && "Expected an IntegerAttr");
assert(bAttr && "Expected IntegerAttr.");
}

if (aAttr && bAttr) {
Expand All @@ -241,7 +247,11 @@ SmallVector<OpFoldResult> getIndexOpFoldResultSum(OpBuilder &builder,
return sum;
}

/// Return sum_{i} values[i] * coeffs[i].
/// Return sum_{i} values[i] * coeffs[i], where
///
/// - values are OpFoldResults (i.e. each element in `values` is
/// either an mlir::Value or mlir::Attribute)
/// - coeffs are integers.
OpFoldResult getLinearCombination(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> values,
ArrayRef<int64_t> coeffs) {
Expand All @@ -255,49 +265,57 @@ OpFoldResult getLinearCombination(OpBuilder &builder, Location loc,
};

// Initialize the linear combination to 0.
OpFoldResult combination = builder.getIndexAttr(0);
OpFoldResult lc = builder.getIndexAttr(0);

// For eacho of the (value, coeff) pairs, add the product to the linear
// combination, updating `lc` in each iteration. The implementation
// here is careful not to create constant zero values.
for (uint64_t dim = 0; dim < coeffs.size(); ++dim) {
// Four cases considered:
// 1) `values[dim]` is an attribute (constant) and `lc` is also
// an attribute (constant)
// 2) `values[dim]` is an attribute (constant) and `lc` is a Value
// (non-constant)
// 3) `values[dim]` is a Value (non-constant) and `lc` is an
// attribute (constant)
// 4) `values[dim]` is a Value (non-constant) and `lc` is also a
// Value (non-constant)
if (auto valueAttr = dyn_cast<Attribute>(values[dim])) {
int64_t term = coeffs[dim] * cast<IntegerAttr>(valueAttr).getInt();

// Case where both `combination` and `value` are constant.
if (auto combinationAttr = dyn_cast<Attribute>(combination)) {
combination = getAsIndexOpFoldResult(
ctx, term + cast<IntegerAttr>(combinationAttr).getInt());
// Case 1.
if (auto lcAttr = dyn_cast<Attribute>(lc)) {
lc = getAsIndexOpFoldResult(ctx,
term + cast<IntegerAttr>(lcAttr).getInt());
}

// Case where `combination` is not constant, `value` is constant.
// Case 2.
else if (term != 0) {
combination = builder
.create<arith::AddIOp>(loc, cast<Value>(combination),
getConstant(term))
.getResult();
lc = builder
.create<arith::AddIOp>(loc, cast<Value>(lc), getConstant(term))
.getResult();
}
} else {
Value term = builder.create<arith::MulIOp>(loc, cast<Value>(values[dim]),
getConstant(coeffs[dim]));
// Case where `combination` is constant, `value` is not constant.
if (auto combinationAttr = dyn_cast<Attribute>(combination)) {
int64_t c = cast<IntegerAttr>(combinationAttr).getInt();
// Case 3.
if (auto lcAttr = dyn_cast<Attribute>(lc)) {
int64_t c = cast<IntegerAttr>(lcAttr).getInt();
if (c != 0) {
combination = builder.create<arith::AddIOp>(loc, getConstant(c), term)
.getResult();
lc = builder.create<arith::AddIOp>(loc, getConstant(c), term)
.getResult();
} else {
combination = term;
lc = term;
}
} else {
// Case where neither `combination` nor `value` is constant.
Value combinationVal = cast<Value>(combination);
combination = builder.create<arith::AddIOp>(loc, combinationVal, term)
.getResult();
}
// Case 4.
else {
lc = builder.create<arith::AddIOp>(loc, cast<Value>(lc), term)
.getResult();
}
}
}
return combination;
return lc;
}


/// Update the offsets, sizes, and strides from a collapse shape operation.
LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
SmallVector<OpFoldResult> &offsets,
Expand All @@ -310,6 +328,11 @@ LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
MLIRContext *ctx = collapseOp.getContext();

// Set strides to inner-most stride in each reassocation group.
//
// Example: Consider a 2x3x5x7 tensor, with strides [70,35,7,1]. If this
// is collapsed to shape 6x35, the srides are [35,1]. The reassociation
// groups are [0,1] and [2,3], and so we've just taken the inner-most
// strides in each group.
for (auto reassociation : llvm::enumerate(reassociationIndices)) {
uint64_t index = reassociation.index();
uint64_t dim = reassociation.value().back();
Expand All @@ -327,15 +350,21 @@ LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
sizes.push_back(getAsIndexOpFoldResult(ctx, dim));
}

// Offsets - merge reassocation groups.
// Offsets - merge reassocation groups by taking linear combinations of the
// offsets with local strides. Using the example of the shape of 2x3x5x7
// being collapsed to 6x35, if the initial offsets are [a,b,c,d], the
// collapsed offsets are [a*3 + b, c*7 + d].
SmallVector<OpFoldResult> collapsedOffsets;
for (auto reassociation : llvm::enumerate(reassociationIndices)) {
auto dims = reassociation.value();

// The strides within the group:
SmallVector<int64_t> localStrides(dims.size(), 1);
for (uint64_t i = 1; i < dims.size(); ++i) {
uint64_t dim = dims.size() - i - 1;
localStrides[dim] = localStrides[dim + 1] * inputShape[dims[dim + 1]];
}

OpBuilder builder(ctx);
builder.setInsertionPoint(collapseOp);
OpFoldResult combination = getLinearCombination(
Expand All @@ -360,35 +389,51 @@ LogicalResult updateFromExpandShape(memref::ExpandShapeOp expandShapeOp,
auto reassociationIndices = expandShapeOp.getReassociationIndices();
ArrayRef<int64_t> resultShape = expandShapeOp.getType().getShape();

// Sizes.
// Set the sizes to the output shape, and check that all dims are static.
SmallVector<OpFoldResult> newSizes(resultShape.size());
for (int i = 0; i < resultShape.size(); i++) {
if (resultShape[i] == ShapedType::kDynamic) {
return expandShapeOp.emitOpError(
"has a dynamic shape which is currently unsupported.");
}
newSizes[i] = getAsIndexOpFoldResult(ctx, resultShape[i]);
}

// Strides.
// Strides. Using the example expanding from a shape of 6x35 to 2x3x5x7, where
// the initial strides are [50, 1], the new strides will be [150, 50, 7, 1].
SmallVector<OpFoldResult> newStrides(resultShape.size());
for (auto reassociation : llvm::enumerate(reassociationIndices)) {
auto index = reassociation.index();
uint64_t index = reassociation.index();
auto dims = reassociation.value();
int64_t cum = getConstantIntValue(strides[index]).value();
OpFoldResult stride = strides[index];
if (!isa<Attribute>(stride)) {
return expandShapeOp.emitOpError("cannot operate on a dynamic stride.");
}
int64_t cum = getConstantIntValue(stride).value();
for (uint64_t i = 0; i < dims.size(); i++) {
auto d = dims[dims.size() - i - 1];
uint64_t d = dims[dims.size() - i - 1];
newStrides[d] = getAsIndexOpFoldResult(ctx, cum);
cum *= resultShape[d];
}
}

// Offsets. For now we don't do any arithmetic to split the offset across
// dimensions, in theory we need to split the offset amongst the reassociation
// indices, but for now I'm just putting the offset on the inner most
// indices, but for now we're just putting the offset on the inner most
// dimension.
//
// Example: suppose we're expanding from 6x35 to 2x3x5x7, and the initial
// offsets are [a, b]. The new offsets will be [0, a, 0, b]. In theory they
// should be [a/3, a%3, b/7 b%7] but these offsets ultimately get collapsed
// anyway so it doesn't matter if we don't.
SmallVector<OpFoldResult> newOffsets(resultShape.size());
// Initialize all ofsets to 0:
for (int i = 0; i < resultShape.size(); i++) {
newOffsets[i] = getAsIndexOpFoldResult(ctx, 0);
}
// Populate the inner-most dimensions with the original offsets:
for (auto reassociation : llvm::enumerate(reassociationIndices)) {
auto index = reassociation.index();
uint64_t index = reassociation.index();
auto dims = reassociation.value();
newOffsets[dims.back()] = offsets[index];
}
Expand Down Expand Up @@ -433,6 +478,16 @@ LogicalResult updateFromSubView(memref::SubViewOp subviewOp,
sizes[insertionIndex] = sizes[extractionIndex];
strides[insertionIndex] = strides[extractionIndex];
insertionIndex++;
} else {
// TODO(newling) add a test of this path.
// If the offset is non-zero, we shouldn't just be dropping it. For now,
// just bail.
OpFoldResult offset = offsets[extractionIndex];
if (isa<Value>(offset) || getConstantIntValue(offset).value() != 0) {
return subviewOp->emitOpError(
"cannot update a non-zero offset in a dimension that is being "
"dropped.");
}
}
}
offsets.resize(insertionIndex);
Expand Down Expand Up @@ -589,19 +644,11 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *packOrUnackOp,
auto dst = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>(
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(dstType), dstVal);

if (failed(mlir::verify(src)) || failed(mlir::verify(dst))) {
return failure();
}

rewriter.setInsertionPoint(packOrUnackOp);
rewriter.create<AMDAIE::DmaCpyNdOp>(packOrUnackOp->getLoc(), dst, dstOffsets,
dstShape, dstBaseStrides, src, srcOffsets,
srcShape, srcBaseStrides);

if (failed(mlir::verify(packOrUnackOp))) {
return failure();
}

rewriter.eraseOp(packOrUnackOp);
return success();
}
Expand Down
Loading

0 comments on commit 86ed927

Please sign in to comment.