Skip to content

[mlir][scf] Extend option to yield replacement for multiple results case #93144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
/// where `%0` had other uses as well. If not reconstructed from within the loop
/// body, uses of `%0` could not be replaced, making it still live and the
/// fusion immaterial.
///
/// The @param `yieldResultNumber` decides which result would be yield. If not
/// given, yield all `opResult` of fused producer.
LogicalResult yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<LoopLikeOpInterface> loops);
MutableArrayRef<LoopLikeOpInterface> loops,
ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});

/// Transformation information returned after tile and fuse.
struct SCFTileAndFuseResult {
Expand Down
38 changes: 37 additions & 1 deletion mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
For an operation to be "tiled and fused" with its (already tiled) consumer,
an operation has to implement the following additional method (see
description below):
- `generateResultTileValue
- `generateResultTileValue`
- `getIterationDomainTileFromResultTile`

For an operation to be "tiled and fused" with its (already tiled) producer,
an operation has to implement the following additional methods (see
Expand Down Expand Up @@ -302,6 +303,41 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Method to return the tile of the iteration domain based
on the given tile of the certain result.

This method is required to allow operations to be "tiled and fused"
with an (already tiled) consumer. Given a tile of an result,
returns the tile of the iteration space that uses this tile.
- `resultNumber` is the result of the producer used by the consumer.
- `offsets` is the offset of the slice of the producer result used by
the tiled implementation of the consumer.
- `sizes` is the size of the slice of the producer result used by the
consumer.
If fusion of the producer with the consumer is not legal for the
result, or if this mapping cannot be computed, the implementation
should return a failure.

For most cases `generateResultTileValue` could be a implemented using
`getIterationDomainTileFromResultTile` + `getTiledImplementation`
methods.
}],
/*retType=*/"::mlir::LogicalResult",
/*methodName=*/"getIterationDomainTileFromResultTile",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$resultNumber,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
"SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
"SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Generates the scalar implementation of the operation.
Expand Down
25 changes: 19 additions & 6 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,11 @@ struct LinalgOpTilingInterface
return success();
}

FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
LogicalResult getIterationDomainTileFromResultTile(
Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);

// Check that the indexing map used for the output is a projected
Expand All @@ -232,9 +233,21 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;

getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
mappedOffsets, mappedSizes);
iterDomainOffsets, iterDomainSizes);
return success();
}

FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
if (failed(getIterationDomainTileFromResultTile(
op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
return failure();
}
auto tilingInterfaceOp = cast<TilingInterface>(op);
FailureOr<TilingResult> tilingResult =
tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
Expand Down
157 changes: 119 additions & 38 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,49 +940,122 @@ mlir::scf::tileAndFuseProducerOfSlice(
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<LoopLikeOpInterface> loops) {
MutableArrayRef<LoopLikeOpInterface> loops,
ArrayRef<unsigned> yieldResultNumber) {
if (loops.empty())
return success();

OpResult fusableProducer = fusedProducerInfo.origProducer;
Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
FailureOr<Value> initValue = tensor::getOrCreateDestination(
rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
if (succeeded(initValue)) {

YieldTiledValuesFn newYieldValuesFn =
[&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
-> LogicalResult {
OpBuilder::InsertionGuard g(innerRewriter);
if (auto tiledDestStyleOp =
tiledAndFusedProducer
.getDefiningOp<DestinationStyleOpInterface>()) {
rewriter.setInsertionPoint(tiledDestStyleOp);
Value newRegionArg = newRegionIterArgs.back();
Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
*tiledOwner = fusedProducerInfo.tiledOps[0];

Location loc = originalOwner->getLoc();
// a. collect all init Value to be appended
SmallVector<unsigned> initNumberList =
yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
0, originalOwner->getNumResults()))
: llvm::to_vector(yieldResultNumber);
SmallVector<Value> initValueList;
for (const auto &resultNumber : initNumberList) {
FailureOr<Value> initValue = tensor::getOrCreateDestination(
rewriter, loc, originalOwner->getResult(resultNumber));
if (succeeded(initValue)) {
initValueList.push_back(initValue.value());
} else {
return failure();
}
}

YieldTiledValuesFn newYieldValuesFn =
[&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
OpBuilder::InsertionGuard g(innerRewriter);

// get sliceOp tile information
SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
sliceSizes = sliceOp.getMixedSizes();

// expect all strides of sliceOp being 1
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
return !isConstantIntValue(ofr, 1);
}))
return failure();

unsigned sliceResultNumber =
fusedProducerInfo.origProducer.getResultNumber();

auto tilableOp = cast<TilingInterface>(originalOwner);
// b. get iterDomain Offset and Sizes based on sliceOp tile
SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
// skip tensor.pack/unpack/pad, which expects single opResult
if (tilableOp->getNumResults() > 1 &&
failed(tilableOp.getIterationDomainTileFromResultTile(
rewriter, sliceResultNumber, sliceOffset, sliceSizes,
iterDomainOffset, iterDomainSizes))) {
// In theory, it is unnecessary to raise an error here. Actually although
// it fails to reconstruct the result tensor, it should not broke current
// fusion anyway. The reason why we must return failure currently is that
// the callback function `newYieldValuesFn` will be called after new init
// operand(s) has already been appended. It will take more refactoring to
// make sure the init operands are added consistently in the future. For
// more details, please refer to:
// https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
return failure();
}

// c. calculate offsets and sizes info of all OpResults respectively based
// on iteration Domain Tile
SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
for (const auto &resultNumber : initNumberList) {
if (resultNumber == sliceResultNumber) {
offsetList.push_back(sliceOffset);
sizesList.push_back(sliceSizes);
} else {
assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
// infer result tile according to the iteration domain tile
SmallVector<OpFoldResult> offset, sizes;
if (failed(tilableOp.getResultTilePosition(
rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
offset, sizes))) {
return failure();
}
offsetList.push_back(offset);
sizesList.push_back(sizes);
}
}

// d. create `extract_slice` for `iter_args` for DPS operation if necessary
if (auto tiledDestStyleOp =
dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
rewriter.setInsertionPoint(tiledDestStyleOp);
for (const auto &&[index, newRegionArg] :
llvm::enumerate(newRegionIterArgs)) {
auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
unsigned resultNumber = fusableProducer.getResultNumber();
loc, newRegionArg, offsetList[index], sizesList[index],
SmallVector<OpFoldResult>(offsetList[index].size(),
rewriter.getIndexAttr(1)));
unsigned resultNumber = initNumberList[index];
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
});
}
Block *block = rewriter.getInsertionPoint()->getBlock();
rewriter.setInsertionPoint(block->getTerminator());
tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
tiledOffset.emplace_back(sliceOp.getMixedOffsets());
tiledSizes.emplace_back(sliceOp.getMixedSizes());
return success();
};
}

return addInitOperandsToLoopNest(rewriter, loops,
SmallVector<Value>{initValue.value()},
newYieldValuesFn);
}
return success();
// e. prepare tiled offset and sizes for later `insert_slice` creation by
// caller
Block *block = rewriter.getInsertionPoint()->getBlock();
rewriter.setInsertionPoint(block->getTerminator());
for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
tiledResult.push_back(tiledOwner->getResult(resultNumber));
tiledOffset.emplace_back(offsetList[index]);
tiledSizes.emplace_back(sizesList[index]);
}
return success();
};

return addInitOperandsToLoopNest(rewriter, loops, initValueList,
newYieldValuesFn);
}

/// Implementation of tile consumer and fuse producer greedily.
Expand Down Expand Up @@ -1072,14 +1145,22 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
continue;

if (yieldReplacement) {
// Reconstruct and yield all opResult of fusableProducerOp by default. The
// caller can specific which one to yield by designating optional argument
// named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
Operation *fusableProducerOp = fusableProducer.getOwner();
if (failed(yieldReplacementForFusedProducer(
rewriter, candidateSliceOp, fusedResult.value(), loops))) {
return rewriter.notifyMatchFailure(
fusableProducer.getOwner(), "failed to replacement value for this "
"oepration from within the tiled loop");
fusableProducerOp, "failed to replacement value for this "
"operation from within the tiled loop");
}
for (auto [index, result] :
llvm::enumerate(fusableProducerOp->getResults())) {
origValToResultNumber[result] = loops.front()->getNumResults() -
fusableProducerOp->getNumResults() +
index;
}
origValToResultNumber[fusableProducer] =
loops.front()->getNumResults() - 1;
}

if (Operation *tiledAndFusedOp =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]]
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0

// -----

func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
%rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>,
%rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>)
-> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
%out0, %out1 = linalg.generic {
indexing_maps = [affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (j, i)>],
iterator_types = ["parallel", "parallel"]
}
ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
%4 = arith.mulf %0, %1 : f32
%5 = arith.addf %0, %1 : f32
linalg.yield %4, %5: f32, f32
} -> (tensor<32x32xf32>, tensor<32x32xf32>)

%out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>

return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%add = transform.structured.match ops{["linalg.add"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%a, %b = transform.test.fuse_and_yield %add [16]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// CHECK: func.func @multiple_outputs_fusion_yield_all(
// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME: %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
// CHECK: %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
// CHECK: %[[GENERIC_TILE:.+]]:2 = linalg.generic
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
// CHECK-DAG: %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
// CHECK: %[[ADD_TILE:.+]] = linalg.add
// CHECK-SAME: ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
// CHECK-SAME: outs(%[[INIT2_TILE]] :
// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0
Loading