Skip to content

Commit

Permalink
tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Sep 25, 2024
1 parent d019580 commit 762fa08
Showing 1 changed file with 47 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ LogicalResult updateFromPack(IREE::LinalgExt::PackOp packOp,
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();

assert(offsets.size() == sizes.size() && sizes.size() == strides.size() &&
"offsets, sizes, and strides must have the same size");
"offsets, sizes, and strides must have the same size,");
for (int64_t dim : innerDimsPos) {
assert(dim < sizes.size() && "innerDimsPos must be within sizes");
assert(dim < sizes.size() && "innerDimsPos must be within sizes.");
}

SmallVector<OpFoldResult> innerSizes;
Expand All @@ -53,8 +53,7 @@ LogicalResult updateFromPack(IREE::LinalgExt::PackOp packOp,
innerSizes.push_back(getAsIndexOpFoldResult(ctx, innerTiles[i]));
std::optional<int64_t> maybeSize =
getConstantIntValue(sizes[innerDimsPos[i]]);
if (!maybeSize.has_value())
packOp->emitOpError("requires a constant size here.");
assert(maybeSize.has_value() && "size expected to be constant here.");
int64_t size = maybeSize.value();
if (size % innerTiles[i] != 0) {
auto message = llvm::formatv(
Expand Down Expand Up @@ -155,44 +154,47 @@ LogicalResult updateFromUnPack(IREE::LinalgExt::UnPackOp unPackOp,
return success();
}

/// Return true if `op` is an operation that this pass treats as an
/// allocation source for DMA purposes: either a `memref.alloc` or a
/// `hal.interface.binding.subspan`. A null `op` is not an allocation.
static bool isAllocation(Operation *op) {
  // Variadic isa<> checks all listed types in one call (LLVM casting idiom).
  return op && isa<memref::AllocOp, IREE::HAL::InterfaceBindingSubspanOp>(op);
}

/// Initialize `offsets`, `sizes`, and `strides` from the allocation-like
/// operation `op` (a memref.alloc or hal.interface.binding.subspan).
///
/// \param op      The allocation operation; must satisfy `isAllocation`.
/// \param offsets Out-param: set to all-zero offsets, one per dimension.
/// \param sizes   Out-param: set to the (static) shape of the memref.
/// \param strides Out-param: set to the memref's strides.
///
/// Fails with an op error if the memref has a non-zero base offset or any
/// dynamic dimension, neither of which is currently supported.
LogicalResult setFromAlloc(Operation *op, SmallVector<OpFoldResult> &offsets,
                           SmallVector<OpFoldResult> &sizes,
                           SmallVector<OpFoldResult> &strides) {
  assert(isAllocation(op) &&
         "expected memref.alloc or hal.interface.binding.subspan");

  MemRefType memRefType = cast<MemRefType>(op->getResult(0).getType());
  MLIRContext *ctx = memRefType.getContext();
  auto [stridesI64, baseOffset] = getStridesAndOffset(memRefType);

  if (baseOffset != 0) {
    auto message = llvm::formatv(
        "has non-zero offset {0} which is currently unsupported.", baseOffset);
    return op->emitOpError(message);
  }
  strides = getAsIndexOpFoldResult(ctx, stridesI64);
  // An allocation has no offsets of its own: one zero per dimension.
  offsets.resize(strides.size(), getAsIndexOpFoldResult(ctx, 0));

  ArrayRef<int64_t> sizesI64 = memRefType.getShape();
  if (llvm::any_of(sizesI64,
                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
    return op->emitOpError("has dynamic size, which is unsupported in DMA.");
  }
  sizes = getAsIndexOpFoldResult(ctx, sizesI64);
  return success();
}

/// Return a + b.
SmallVector<OpFoldResult> getSum(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> a,
ArrayRef<OpFoldResult> b) {
assert(a.size() == b.size() && "a and b not same size");
SmallVector<OpFoldResult> getIndexOpFoldResultSum(OpBuilder &builder,
Location loc,
ArrayRef<OpFoldResult> lhs,
ArrayRef<OpFoldResult> rhs) {
assert(lhs.size() == rhs.size() && "a and b not same size");
SmallVector<OpFoldResult> sum;
sum.reserve(a.size());
sum.reserve(lhs.size());

auto getConstant = [&](int64_t v) -> Value {
return builder.create<arith::ConstantOp>(
Expand All @@ -206,15 +208,15 @@ SmallVector<OpFoldResult> getSum(OpBuilder &builder, Location loc,
.getResult();
};

for (uint64_t i = 0; i < a.size(); ++i) {
for (uint64_t i = 0; i < lhs.size(); ++i) {
IntegerAttr aAttr;
if (auto aAttr_ = dyn_cast<Attribute>(a[i])) {
if (auto aAttr_ = dyn_cast<Attribute>(lhs[i])) {
aAttr = dyn_cast<IntegerAttr>(aAttr_);
assert(aAttr && "Expected an IntegerAttr");
}

IntegerAttr bAttr;
if (auto bAttr_ = dyn_cast<Attribute>(b[i])) {
if (auto bAttr_ = dyn_cast<Attribute>(rhs[i])) {
bAttr = dyn_cast<IntegerAttr>(bAttr_);
assert(bAttr && "Expected an IntegerAttr");
}
Expand All @@ -223,20 +225,19 @@ SmallVector<OpFoldResult> getSum(OpBuilder &builder, Location loc,
sum.push_back(getAsIndexOpFoldResult(builder.getContext(),
aAttr.getInt() + bAttr.getInt()));
} else if (!aAttr && !bAttr) {
sum.push_back(
builder
.create<arith::AddIOp>(loc, cast<Value>(a[i]), cast<Value>(b[i]))
.getResult());
sum.push_back(builder
.create<arith::AddIOp>(loc, cast<Value>(lhs[i]),
cast<Value>(rhs[i]))
.getResult());
} else if (!aAttr && bAttr) {
sum.push_back(add(cast<Value>(a[i]), bAttr));
sum.push_back(add(cast<Value>(lhs[i]), bAttr));
} else if (aAttr && !bAttr) {
sum.push_back(add(cast<Value>(b[i]), aAttr));
sum.push_back(add(cast<Value>(rhs[i]), aAttr));
} else {
assert(false && "unreachable");
}
}

assert(sum.size() == a.size() && "sum and a not same size");
return sum;
}

Expand Down Expand Up @@ -296,13 +297,15 @@ OpFoldResult getLinearCombination(OpBuilder &builder, Location loc,
return combination;
}


/// Update the offsets, sizes, and strides from a collapse shape operation.
LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
auto reassociationIndices = collapseOp.getReassociationIndices();
ArrayRef<int64_t> resultShape = collapseOp.getType().getShape();
ArrayRef<int64_t> inputShape = collapseOp.getSrcType().getShape();
ArrayRef<int64_t> resultShape = collapseOp.getType().getShape();
uint64_t resultRank = resultShape.size();
MLIRContext *ctx = collapseOp.getContext();

Expand All @@ -314,14 +317,14 @@ LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
}
strides.resize(resultRank);

// Set sizes to output shape, ensuring that all dims are static.
// Set sizes to output shape, and check that all dims are static.
sizes.clear();
for (int64_t dim : resultShape) {
if (dim == ShapedType::kDynamic) {
return collapseOp.emitOpError(
"has a dynamic shape which is currently unsupported.");
}
sizes.push_back(getAsIndexOpFoldResult(collapseOp.getContext(), dim));
sizes.push_back(getAsIndexOpFoldResult(ctx, dim));
}

// Offsets - merge reassocation groups.
Expand All @@ -348,6 +351,7 @@ LogicalResult updateFromCollapseShape(memref::CollapseShapeOp collapseOp,
return success();
}

/// Update the offsets, sizes, and strides from an expand shape operation.
LogicalResult updateFromExpandShape(memref::ExpandShapeOp expandShapeOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
Expand Down Expand Up @@ -375,9 +379,10 @@ LogicalResult updateFromExpandShape(memref::ExpandShapeOp expandShapeOp,
}
}

// Offsets. For now we don't do any modular arithmetic, in theory we need to
// split the offset amongst the reassociation indices, but for now I'm just
// putting the offset on the inner most dimension.
// Offsets. For now we don't do any arithmetic to split the offset across
// dimensions, in theory we need to split the offset amongst the reassociation
// indices, but for now I'm just putting the offset on the inner most
// dimension.
SmallVector<OpFoldResult> newOffsets(resultShape.size());
for (int i = 0; i < resultShape.size(); i++) {
newOffsets[i] = getAsIndexOpFoldResult(ctx, 0);
Expand All @@ -394,6 +399,8 @@ LogicalResult updateFromExpandShape(memref::ExpandShapeOp expandShapeOp,
return success();
}


/// Update the offsets, sizes, and strides from a subview operation.
LogicalResult updateFromSubView(memref::SubViewOp subviewOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
Expand All @@ -402,8 +409,8 @@ LogicalResult updateFromSubView(memref::SubViewOp subviewOp,

OpBuilder builder(subviewOp.getContext());
builder.setInsertionPoint(subviewOp);
offsets =
getSum(builder, subviewOp.getLoc(), offsets, subviewOp.getMixedOffsets());
offsets = getIndexOpFoldResultSum(builder, subviewOp.getLoc(), offsets,
subviewOp.getMixedOffsets());

sizes = subviewOp.getMixedSizes();
if (llvm::any_of(sizes, [](OpFoldResult size) {
Expand Down Expand Up @@ -435,7 +442,7 @@ LogicalResult updateFromSubView(memref::SubViewOp subviewOp,
}

/// Provide the offsets, sizes and strides of the inputs to `operandOp`.
/// This function update `operandOp`, setting it to the allocation operation
/// This function updates `operandOp`, setting it to the allocation operation
/// that it originates from.
LogicalResult setDmaInputs(Operation *&operandOp,
SmallVector<OpFoldResult> &offsets,
Expand All @@ -446,11 +453,6 @@ LogicalResult setDmaInputs(Operation *&operandOp,

if (!operandOp) assert(false && "operandOp must be non-null");

auto isAllocation = [](Operation *op) {
return op && (isa<memref::AllocOp>(op) ||
isa<IREE::HAL::InterfaceBindingSubspanOp>(op));
};

// Get the sequence of memref operations going from an allocation to
// `operandOp`
SmallVector<Operation *> chain;
Expand Down

0 comments on commit 762fa08

Please sign in to comment.