-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][spirv] Fix FuncOpVectorUnroll to process placeholder values in all blocks #142339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir Author: Darren Wihandi (fairywreath) Changes
The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect. Closes #132158. TODO: add test Full diff: https://github.com/llvm/llvm-project/pull/142339.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62a24646d0662..84796fdeda03a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,35 +1020,37 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
entryBlock.addArguments(convertedTypes, locs);
- // Replace the placeholder values with the new arguments. We assume there is
- // only one block for now.
+ // Replace the placeholder values with the new arguments.
size_t unrolledInputIdx = 0;
- for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+ newFuncOp.walk([&](Operation *op) {
// We first look for operands that are placeholders for initially legal
// arguments.
- Operation &curOp = op;
- for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+ for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
Operation *operandOp = operandVal.getDefiningOp();
if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
size_t idx = operandIdx;
- rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
- curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+ rewriter.modifyOpInPlace(op, [&] {
+ op->setOperand(idx, newFuncOp.getArgument(it->second));
});
}
}
+
// Since all newly created operations are in the beginning, reaching the
// end of them means that any later `vector.insert_strided_slice` should
// not be touched.
- if (count >= newOpCount)
- continue;
+ if (op->getBlock() == &entryBlock &&
+ static_cast<size_t>(std::distance(entryBlock.begin(),
+ op->getIterator())) >= newOpCount)
+ return;
+
if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
- rewriter.modifyOpInPlace(&curOp, [&] {
- curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+ rewriter.modifyOpInPlace(op, [&] {
+ op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
});
++unrolledInputIdx;
}
- }
+ });
// Erase the original funcOp. The `tmpOps` do not need to be erased since
// they have no uses and will be handled by dead-code elimination.
|
@llvm/pr-subscribers-mlir-spirv Author: Darren Wihandi (fairywreath) Changes
The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect. Closes #132158. TODO: add test Full diff: https://github.com/llvm/llvm-project/pull/142339.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62a24646d0662..84796fdeda03a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,35 +1020,37 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
entryBlock.addArguments(convertedTypes, locs);
- // Replace the placeholder values with the new arguments. We assume there is
- // only one block for now.
+ // Replace the placeholder values with the new arguments.
size_t unrolledInputIdx = 0;
- for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+ newFuncOp.walk([&](Operation *op) {
// We first look for operands that are placeholders for initially legal
// arguments.
- Operation &curOp = op;
- for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+ for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
Operation *operandOp = operandVal.getDefiningOp();
if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
size_t idx = operandIdx;
- rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
- curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+ rewriter.modifyOpInPlace(op, [&] {
+ op->setOperand(idx, newFuncOp.getArgument(it->second));
});
}
}
+
// Since all newly created operations are in the beginning, reaching the
// end of them means that any later `vector.insert_strided_slice` should
// not be touched.
- if (count >= newOpCount)
- continue;
+ if (op->getBlock() == &entryBlock &&
+ static_cast<size_t>(std::distance(entryBlock.begin(),
+ op->getIterator())) >= newOpCount)
+ return;
+
if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
- rewriter.modifyOpInPlace(&curOp, [&] {
- curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+ rewriter.modifyOpInPlace(op, [&] {
+ op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
});
++unrolledInputIdx;
}
- }
+ });
// Erase the original funcOp. The `tmpOps` do not need to be erased since
// they have no uses and will be handled by dead-code elimination.
|
FuncOpVectorUnroll
contains logic that replaces function arguments by placeholders values. These replacements also involve changing all instructions in the function that use the arguments to use these placeholders. These placeholder values will later be changed back to use the function arguments (either new or original if already legal).The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect.
Closes #132158.
TODO: add test