Skip to content

[mlir][linalg] Enable fuse consumer #85528

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 61 additions & 6 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
The method returns the operation that is the tiled
implementation.
}],
/*retType=*/"FailureOr<TilingResult>",
/*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementation",
/*args=*/(ins
"OpBuilder &":$b,
Expand All @@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
by the tiled implementation. Expects the same `offsets` and `sizes` as
used to obtain the tiled implementation of the operation.
}],
/*retType=*/"LogicalResult",
/*retType=*/"::mlir::LogicalResult",
/*methodName=*/"getResultTilePosition",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$resultNumber,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
"SmallVector<OpFoldResult> &":$resultOffsets,
"SmallVector<OpFoldResult> &":$resultSizes),
"SmallVectorImpl<OpFoldResult> &":$resultOffsets,
"SmallVectorImpl<OpFoldResult> &":$resultSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Method to return the position of iteration domain tile computed by the
tiled operation.
}],
/*retType=*/"::mlir::LogicalResult",
/*methodName=*/"getIterationDomainTileFromOperandTile",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$operandNumber,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
"SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
"SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
Expand Down Expand Up @@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
iteration space).
- `sizes` provides the size of the tile.
}],
/*retType=*/"FailureOr<TilingResult>",
/*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"generateResultTileValue",
/*args=*/(ins
"OpBuilder &":$b,
Expand All @@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Method to generate the tiled implementation of an operation from
operand tile position.

Generates the IR that computes the tiled implementation of an
operation from operand tile. The `offsets` and `sizes`
describe the tile of the operand required. This is different from
`getTiledImplementation` which generates the tiled
implementation of the operation given a tile of the
iteration space. This method generates a tiled
implementation of the operation based on the tile of the
operand required. This method enables consumer fusion by using
tile and fuse. The method returns failure if the operation
can't be tiled to generate the operand tile. In practical terms
this implies it cannot be tiled and fused with its producers.

- `offsets` provides the offset of the tile in the coordinate system
of the original iteration space, i.e., if an iteration space
dimension had non-zero offset, it must be included in the offset
provided here (as opposed to zero-based offset "relative" to the
iteration space).
- `sizes` provides the size of the tile.
}],
/*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementationFromOperandTile",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$operandNumber,
"ArrayRef<OpFoldResult>":$offsets,
"ArrayRef<OpFoldResult>":$sizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Generates the scalar implementation of the operation.
Expand All @@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformations are done, this method can be used to lower to scalar
code that can then be lowered to LLVM or SPIR-V dialects.
}],
/*retType=*/"LogicalResult",
/*retType=*/"::mlir::LogicalResult",
/*methodName=*/"generateScalarImplementation",
/*args=*/(ins
"OpBuilder &":$b,
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,

LogicalResult SoftmaxOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) {
ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) {
if (resultNumber == 0) {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
Expand Down
106 changes: 80 additions & 26 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
}));
}

// Instantiate the tiled implementation of the operation.
/// Instantiate the tiled implementation of the operation.
FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
Expand All @@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}

// Return the details of the output tile generated by the tiled
// implementation.
void
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
SmallVectorImpl<OpFoldResult> &mappedOffsets,
SmallVectorImpl<OpFoldResult> &mappedSizes) const {
unsigned numLoops = linalgOp.getNumLoops();
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
mappedOffsets.resize(numLoops);
mappedSizes.resize(numLoops);
if (!indexingMap.isPermutation()) {
SmallVector<Range> iterationDomain =
tilingInterfaceOp.getIterationDomain(b);
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
mappedOffsets[index] = value.offset;
mappedSizes[index] = value.size;
}
}
for (const auto &&[index, value] :
llvm::enumerate(indexingMap.getResults())) {
unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
mappedOffsets[dimPosition] = offsets[index];
mappedSizes[dimPosition] = sizes[index];
}
}

/// Return the details of the output tile generated by the tiled
/// implementation.
LogicalResult getIterationDomainTileFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);

// Check that the indexing map used for the operand is a projected
// permutation. This could be relaxed with a more general approach that can
// map the offsets and sizes from the operand to iteration space tiles
// (filling in full extent for dimensions not used to access the result).
AffineMap indexingMap =
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
if (!indexingMap.isProjectedPermutation()) {
return emitError(op->getLoc(),
"unhandled get iter domain position when operand is not "
"accessed using a permuted projection");
}

getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
iterDomainOffsets, iterDomainSizes);
return success();
}

/// Return the details of the output tile generated by the tiled
/// implementation.
LogicalResult
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) const {
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);

Expand All @@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
return success();
}

FailureOr<TilingResult> getTiledImplementationFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
auto tilingInterfaceOp = cast<TilingInterface>(op);
if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
return emitError(
op->getLoc(),
"unable to obtain the iter domain position of the operation.");
}
return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
mappedSizes);
}

FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
Expand All @@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}

auto numLoops = linalgOp.getNumLoops();
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
mappedOffsets, mappedSizes);
auto tilingInterfaceOp = cast<TilingInterface>(op);
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
iterationTileSizes(numLoops);
if (!indexingMap.isPermutation()) {
SmallVector<Range> iterationDomain =
tilingInterfaceOp.getIterationDomain(b);
for (const auto &range : llvm::enumerate(iterationDomain)) {
iterationTileOffsets[range.index()] = range.value().offset;
iterationTileSizes[range.index()] = range.value().size;
}
}
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
unsigned dimPosition =
cast<AffineDimExpr>(resultExpr.value()).getPosition();
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
}

FailureOr<TilingResult> tilingResult =
tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
iterationTileSizes);
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);

if (failed(tilingResult))
return failure();

if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");

Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) const {
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
return success();
Expand Down Expand Up @@ -199,8 +199,8 @@ struct PackOpTiling
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) const {
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
// The iteration domain is over outer dimensions of packed layout. In this
// context, the outer dimensions of `resultOffsets` are `offsets`. The
// inner dimensions of `resultOffsets` are zeros because tiling is not
Expand Down Expand Up @@ -452,8 +452,8 @@ struct UnPackOpTiling
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) const {
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
resultOffsets = llvm::to_vector(offsets);
resultSizes = llvm::to_vector(sizes);
return success();
Expand Down