Skip to content

[mlir][spirv] Fix FuncOpVectorUnroll to process placeholder values in all blocks #142339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

fairywreath
Copy link
Contributor

FuncOpVectorUnroll contains logic that replaces function arguments by placeholders values. These replacements also involve changing all instructions in the function that use the arguments to use these placeholders. These placeholder values will later be changed back to use the function arguments (either new or original if already legal).

The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect.

Closes #132158.

TODO: add test

@fairywreath fairywreath marked this pull request as draft June 2, 2025 07:09
@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2025

@llvm/pr-subscribers-mlir

Author: Darren Wihandi (fairywreath)

Changes

FuncOpVectorUnroll contains logic that replaces function arguments by placeholders values. These replacements also involve changing all instructions in the function that use the arguments to use these placeholders. These placeholder values will later be changed back to use the function arguments (either new or original if already legal).

The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect.

Closes #132158.

TODO: add test


Full diff: https://github.com/llvm/llvm-project/pull/142339.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+14-12)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62a24646d0662..84796fdeda03a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,35 +1020,37 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
     SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
     entryBlock.addArguments(convertedTypes, locs);
 
-    // Replace the placeholder values with the new arguments. We assume there is
-    // only one block for now.
+    // Replace the placeholder values with the new arguments.
     size_t unrolledInputIdx = 0;
-    for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+    newFuncOp.walk([&](Operation *op) {
       // We first look for operands that are placeholders for initially legal
       // arguments.
-      Operation &curOp = op;
-      for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+      for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
         Operation *operandOp = operandVal.getDefiningOp();
         if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
           size_t idx = operandIdx;
-          rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
-            curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+          rewriter.modifyOpInPlace(op, [&] {
+            op->setOperand(idx, newFuncOp.getArgument(it->second));
           });
         }
       }
+
       // Since all newly created operations are in the beginning, reaching the
       // end of them means that any later `vector.insert_strided_slice` should
       // not be touched.
-      if (count >= newOpCount)
-        continue;
+      if (op->getBlock() == &entryBlock &&
+          static_cast<size_t>(std::distance(entryBlock.begin(),
+                                            op->getIterator())) >= newOpCount)
+        return;
+
       if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
         size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
-        rewriter.modifyOpInPlace(&curOp, [&] {
-          curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+        rewriter.modifyOpInPlace(op, [&] {
+          op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
         });
         ++unrolledInputIdx;
       }
-    }
+    });
 
     // Erase the original funcOp. The `tmpOps` do not need to be erased since
     // they have no uses and will be handled by dead-code elimination.

@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2025

@llvm/pr-subscribers-mlir-spirv

Author: Darren Wihandi (fairywreath)

Changes

FuncOpVectorUnroll contains logic that replaces function arguments by placeholders values. These replacements also involve changing all instructions in the function that use the arguments to use these placeholders. These placeholder values will later be changed back to use the function arguments (either new or original if already legal).

The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect.

Closes #132158.

TODO: add test


Full diff: https://github.com/llvm/llvm-project/pull/142339.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+14-12)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62a24646d0662..84796fdeda03a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,35 +1020,37 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
     SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
     entryBlock.addArguments(convertedTypes, locs);
 
-    // Replace the placeholder values with the new arguments. We assume there is
-    // only one block for now.
+    // Replace the placeholder values with the new arguments.
     size_t unrolledInputIdx = 0;
-    for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+    newFuncOp.walk([&](Operation *op) {
       // We first look for operands that are placeholders for initially legal
       // arguments.
-      Operation &curOp = op;
-      for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+      for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
         Operation *operandOp = operandVal.getDefiningOp();
         if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
           size_t idx = operandIdx;
-          rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
-            curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+          rewriter.modifyOpInPlace(op, [&] {
+            op->setOperand(idx, newFuncOp.getArgument(it->second));
           });
         }
       }
+
       // Since all newly created operations are in the beginning, reaching the
       // end of them means that any later `vector.insert_strided_slice` should
       // not be touched.
-      if (count >= newOpCount)
-        continue;
+      if (op->getBlock() == &entryBlock &&
+          static_cast<size_t>(std::distance(entryBlock.begin(),
+                                            op->getIterator())) >= newOpCount)
+        return;
+
       if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
         size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
-        rewriter.modifyOpInPlace(&curOp, [&] {
-          curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+        rewriter.modifyOpInPlace(op, [&] {
+          op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
         });
         ++unrolledInputIdx;
       }
-    }
+    });
 
     // Erase the original funcOp. The `tmpOps` do not need to be erased since
     // they have no uses and will be handled by dead-code elimination.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR] -test-convert-to-spirv triggers Assertion `succeeded(result) && "expected ConstantLike op to be foldable"' failed.
2 participants