-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[mlir] Fix block merging #97697
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Fix block merging #97697
Conversation
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir Author: Giuseppe Rossini (giuseros) ChangesWith this PR I am trying to address: #63230. What changed:
Note: this a rework of #96871 . I ran all the integration tests ( Patch is 33.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97697.diff 12 Files Affected:
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485cfede3d..5227b22653eef 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
+ // We don't want that the block structure changes invalidating the
+ // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
+ // region simplification
+ GreedyRewriteConfig config;
+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
- if (failed(
- applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config)))
signalPassFailure();
}
};
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15bafbaba..8b742bada67b9 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include <deque>
+#include <iterator>
using namespace mlir;
@@ -674,6 +679,64 @@ static bool ableToUpdatePredOperands(Block *block) {
return true;
}
+/// Prunes the redundant list of arguments. E.g., if we are passing an argument
+/// list like [x, y, z, x] this would return [x, y, z] and it would update the
+/// `block` (to whom the argument are passed to) accordingly
+static void
+pruneRedundantArguments(SmallVector<SmallVector<Value, 8>, 2> &newArguments,
+ RewriterBase &rewriter, Block *block) {
+ SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
+ newArguments.size(), SmallVector<Value, 8>());
+
+ if (!newArguments.empty()) {
+ llvm::DenseMap<unsigned, unsigned> toReplace;
+ // Go through the first list of arguments (list 0)
+ for (unsigned j = 0; j < newArguments[0].size(); j++) {
+ bool shouldReplaceJ = false;
+ unsigned replacement = 0;
+ // Look back to see if there are possible redundancies in
+ // list 0
+ for (unsigned k = 0; k < j; k++) {
+ if (newArguments[0][k] == newArguments[0][j]) {
+ shouldReplaceJ = true;
+ replacement = k;
+ // If a possible redundancy is found, then scan the other lists: we
+ // can prune the arguments if and only if they are redundant in every
+ // list
+ for (unsigned i = 1; i < newArguments.size(); i++)
+ shouldReplaceJ =
+ shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
+ }
+ }
+ // Save the replacement
+ if (shouldReplaceJ)
+ toReplace[j] = replacement;
+ }
+
+ // Populate the pruned argument list
+ for (unsigned i = 0; i < newArguments.size(); i++)
+ for (unsigned j = 0; j < newArguments[i].size(); j++)
+ if (!toReplace.contains(j))
+ newArgumentsPruned[i].push_back(newArguments[i][j]);
+
+ // Replace the block's redundant arguments
+ SmallVector<unsigned> toErase;
+ for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
+ if (toReplace.contains(idx)) {
+ Value oldArg = block->getArgument(idx);
+ Value newArg = block->getArgument(toReplace[idx]);
+ rewriter.replaceAllUsesWith(oldArg, newArg);
+ toErase.push_back(idx);
+ }
+ }
+
+ // Erase the block's redundant arguments
+ for (auto idxToErase : llvm::reverse(toErase))
+ block->eraseArgument(idxToErase);
+ newArguments = newArgumentsPruned;
+ }
+}
+
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
// Don't consider clusters that don't have blocks to merge.
if (blocksToMerge.empty())
@@ -699,9 +762,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
blockIterators.push_back(mergeBlock->begin());
// Update each of the predecessor terminators with the new arguments.
- SmallVector<SmallVector<Value, 8>, 2> newArguments(
- 1 + blocksToMerge.size(),
- SmallVector<Value, 8>(operandsToMerge.size()));
+ SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
+ SmallVector<Value, 8>());
unsigned curOpIndex = 0;
for (const auto &it : llvm::enumerate(operandsToMerge)) {
unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,16 +774,19 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
Block::iterator &blockIter = blockIterators[i];
std::advance(blockIter, nextOpOffset);
auto &operand = blockIter->getOpOperand(it.value().second);
- newArguments[i][it.index()] = operand.get();
-
+ Value operandVal = operand.get();
+ newArguments[i].push_back(operandVal);
// Update the operand and insert an argument if this is the leader.
if (i == 0) {
- Value operandVal = operand.get();
operand.set(leaderBlock->addArgument(operandVal.getType(),
operandVal.getLoc()));
}
}
}
+
+ // Prune redundant arguments and update the leader block argument list
+ pruneRedundantArguments(newArguments, rewriter, leaderBlock);
+
// Update the predecessors for each of the blocks.
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -818,6 +883,109 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
return success(anyChanged);
}
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+ Block &block) {
+ SmallVector<size_t> argsToErase;
+
+ // Go through the arguments of the block
+ for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+ bool sameArg = true;
+ Value commonValue;
+
+ // Go through the block predecessor and flag if they pass to the block
+ // different values for the same argument
+ for (auto predIt = block.pred_begin(), predE = block.pred_end();
+ predIt != predE; ++predIt) {
+ auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+ if (!branch) {
+ sameArg = false;
+ break;
+ }
+ unsigned succIndex = predIt.getSuccessorIndex();
+ SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+ auto operands = succOperands.getForwardedOperands();
+ if (!commonValue) {
+ commonValue = operands[argIdx];
+ } else {
+ if (operands[argIdx] != commonValue) {
+ sameArg = false;
+ break;
+ }
+ }
+ }
+
+ // If they are passing the same value, drop the argument
+ if (commonValue && sameArg) {
+ argsToErase.push_back(argIdx);
+
+ // Remove the argument from the block
+ Value argVal = block.getArgument(argIdx);
+ rewriter.replaceAllUsesWith(argVal, commonValue);
+ }
+ }
+
+ // Remove the arguments
+ for (auto argIdx : llvm::reverse(argsToErase)) {
+ block.eraseArgument(argIdx);
+
+ // Remove the argument from the branch ops
+ for (auto predIt = block.pred_begin(), predE = block.pred_end();
+ predIt != predE; ++predIt) {
+ auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+ unsigned succIndex = predIt.getSuccessorIndex();
+ SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+ succOperands.erase(argIdx);
+ }
+ }
+ return success(!argsToErase.empty());
+}
+
+/// This optimization drops redundant argument to blocks. I.e., if a given
+/// argument to a block receives the same value from each of the block
+/// predecessors, we can remove the argument from the block and use directly the
+/// original value. This is a simple example:
+///
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
+/// : i64)
+///
+/// ^bb1(%arg0 : i64, %arg1 : i64):
+/// llvm.call @foo(%arg0, %arg1)
+///
+/// The previous IR can be rewritten as:
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+///
+/// ^bb1(%arg0 : i64):
+/// llvm.call @foo(%val0, %arg0)
+///
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+ MutableArrayRef<Region> regions) {
+ llvm::SmallSetVector<Region *, 1> worklist;
+ for (auto ®ion : regions)
+ worklist.insert(®ion);
+ bool anyChanged = false;
+ while (!worklist.empty()) {
+ Region *region = worklist.pop_back_val();
+
+ // Add any nested regions to the worklist.
+ for (Block &block : *region) {
+ anyChanged = succeeded(dropRedundantArguments(rewriter, block));
+
+ for (auto &op : block)
+ for (auto &nestedRegion : op.getRegions())
+ worklist.insert(&nestedRegion);
+ }
+ }
+ return success(anyChanged);
+}
+
//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//
@@ -832,8 +1000,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
bool mergedIdenticalBlocks = false;
- if (mergeBlocks)
+ bool droppedRedundantArguments = false;
+ if (mergeBlocks) {
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+ droppedRedundantArguments =
+ succeeded(dropRedundantArguments(rewriter, regions));
+ }
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
- mergedIdenticalBlocks);
+ mergedIdenticalBlocks || droppedRedundantArguments);
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 5e8104f83cc4d..8e14990502143 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
// CHECK-NEXT: ^bb1
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
+// CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
// CHECK: ^bb2([[IDX:%.*]]:{{.*}})
// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
// CHECK-NEXT: test.buffer_based
@@ -186,20 +186,24 @@ func.func @condBranchDynamicTypeNested(
// CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
+// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
// CHECK-NEXT: ^bb3:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
-// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+// CHECK-NEXT: ^bb4:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
-// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+// CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
+// CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
// CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
-// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+// CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
+// CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
// CHECK: test.copy
// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a89226fdb58..50a2d6bf532aa 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @main
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
// CHECK: return %[[ELEMENTS]] : tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 8d17763c04b6c..c728ad21d2209 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,18 +42,15 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br
+// CHECK-NEXT: ^[[bb1:.*]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT: ^[[bb2]]
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
// -----
@@ -106,20 +103,17 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: cf.br ^[[bb4:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb4]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br ^[[bb1:.*]]
+// CHECK-NEXT: ^[[bb1:.*]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT: ^[[bb2]]:
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: cf.br ^[[bb4:.*]]
+// CHECK-NEXT: ^[[bb4]]:
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
// -----
@@ -171,16 +165,13 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<10>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br ^[[bb1:.*]]
+// CHECK-NEXT: ^[[bb1]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
+// CHECK-NEXT: ^[[bb2]]
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index aa30900f76a33..580a97d3a851b 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
// DET-ALL: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// DET-ALL: ^[[bb1]](%{{.*}}: i32)
// DET-ALL: arith.cmpi slt, {{.*}}
-// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL: ^[[bb2]](%{{.*}}: i32)
+// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL: ^[[bb2]]
// DET-ALL: arith.addi {{.*}}
// DET-ALL: cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL: ^[[bb3]](%{{.*}}: i32)
+// DET-ALL: ^[[bb3]]:
// DET-ALL: tensor.from_elements {{.*}}
// DET-ALL: return %{{.*}} : tensor<i32>
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
// DET-CF: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// DET-CF: ^[[bb1]](%{{.*}}: i32)
// DET-CF: arith.cmpi slt, {{.*}}
-// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-CF: ^[[bb2]](%{{.*}}: i32)
+// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-CF: ^[[bb2]]:
// DET-CF: arith.addi {{.*}}
// DET-CF: cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-CF: ^[[bb3]](%{{.*}}: i32)
+// DET-CF: ^[[bb3]]:
// DET-CF: tensor.from_elements %{{.*}} : tensor<i32>
// DET-CF: return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 955c7be5ef4c8..414d9b94cbf53 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -74,8 +74,8 @@ func.func @main(%farg0...
[truncated]
|
@llvm/pr-subscribers-mlir-core Author: Giuseppe Rossini (giuseros) ChangesWith this PR I am trying to address: #63230. What changed:
Note: this a rework of #96871 . I ran all the integration tests ( Patch is 33.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97697.diff 12 Files Affected:
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485cfede3da..5227b22653eefc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
+ // We don't want that the block structure changes invalidating the
+ // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
+ // region simplification
+ GreedyRewriteConfig config;
+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
- if (failed(
- applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config)))
signalPassFailure();
}
};
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15bafbaba3..8b742bada67b97 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include <deque>
+#include <iterator>
using namespace mlir;
@@ -674,6 +679,64 @@ static bool ableToUpdatePredOperands(Block *block) {
return true;
}
+/// Prunes the redundant list of arguments. E.g., if we are passing an argument
+/// list like [x, y, z, x] this would return [x, y, z] and it would update the
+/// `block` (to whom the argument are passed to) accordingly
+static void
+pruneRedundantArguments(SmallVector<SmallVector<Value, 8>, 2> &newArguments,
+ RewriterBase &rewriter, Block *block) {
+ SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
+ newArguments.size(), SmallVector<Value, 8>());
+
+ if (!newArguments.empty()) {
+ llvm::DenseMap<unsigned, unsigned> toReplace;
+ // Go through the first list of arguments (list 0)
+ for (unsigned j = 0; j < newArguments[0].size(); j++) {
+ bool shouldReplaceJ = false;
+ unsigned replacement = 0;
+ // Look back to see if there are possible redundancies in
+ // list 0
+ for (unsigned k = 0; k < j; k++) {
+ if (newArguments[0][k] == newArguments[0][j]) {
+ shouldReplaceJ = true;
+ replacement = k;
+ // If a possible redundancy is found, then scan the other lists: we
+ // can prune the arguments if and only if they are redundant in every
+ // list
+ for (unsigned i = 1; i < newArguments.size(); i++)
+ shouldReplaceJ =
+ shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
+ }
+ }
+ // Save the replacement
+ if (shouldReplaceJ)
+ toReplace[j] = replacement;
+ }
+
+ // Populate the pruned argument list
+ for (unsigned i = 0; i < newArguments.size(); i++)
+ for (unsigned j = 0; j < newArguments[i].size(); j++)
+ if (!toReplace.contains(j))
+ newArgumentsPruned[i].push_back(newArguments[i][j]);
+
+ // Replace the block's redundant arguments
+ SmallVector<unsigned> toErase;
+ for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
+ if (toReplace.contains(idx)) {
+ Value oldArg = block->getArgument(idx);
+ Value newArg = block->getArgument(toReplace[idx]);
+ rewriter.replaceAllUsesWith(oldArg, newArg);
+ toErase.push_back(idx);
+ }
+ }
+
+ // Erase the block's redundant arguments
+ for (auto idxToErase : llvm::reverse(toErase))
+ block->eraseArgument(idxToErase);
+ newArguments = newArgumentsPruned;
+ }
+}
+
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
// Don't consider clusters that don't have blocks to merge.
if (blocksToMerge.empty())
@@ -699,9 +762,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
blockIterators.push_back(mergeBlock->begin());
// Update each of the predecessor terminators with the new arguments.
- SmallVector<SmallVector<Value, 8>, 2> newArguments(
- 1 + blocksToMerge.size(),
- SmallVector<Value, 8>(operandsToMerge.size()));
+ SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
+ SmallVector<Value, 8>());
unsigned curOpIndex = 0;
for (const auto &it : llvm::enumerate(operandsToMerge)) {
unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,16 +774,19 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
Block::iterator &blockIter = blockIterators[i];
std::advance(blockIter, nextOpOffset);
auto &operand = blockIter->getOpOperand(it.value().second);
- newArguments[i][it.index()] = operand.get();
-
+ Value operandVal = operand.get();
+ newArguments[i].push_back(operandVal);
// Update the operand and insert an argument if this is the leader.
if (i == 0) {
- Value operandVal = operand.get();
operand.set(leaderBlock->addArgument(operandVal.getType(),
operandVal.getLoc()));
}
}
}
+
+ // Prune redundant arguments and update the leader block argument list
+ pruneRedundantArguments(newArguments, rewriter, leaderBlock);
+
// Update the predecessors for each of the blocks.
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -818,6 +883,109 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
return success(anyChanged);
}
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+ Block &block) {
+ SmallVector<size_t> argsToErase;
+
+ // Go through the arguments of the block
+ for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+ bool sameArg = true;
+ Value commonValue;
+
+ // Go through the block predecessor and flag if they pass to the block
+ // different values for the same argument
+ for (auto predIt = block.pred_begin(), predE = block.pred_end();
+ predIt != predE; ++predIt) {
+ auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+ if (!branch) {
+ sameArg = false;
+ break;
+ }
+ unsigned succIndex = predIt.getSuccessorIndex();
+ SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+ auto operands = succOperands.getForwardedOperands();
+ if (!commonValue) {
+ commonValue = operands[argIdx];
+ } else {
+ if (operands[argIdx] != commonValue) {
+ sameArg = false;
+ break;
+ }
+ }
+ }
+
+ // If they are passing the same value, drop the argument
+ if (commonValue && sameArg) {
+ argsToErase.push_back(argIdx);
+
+ // Remove the argument from the block
+ Value argVal = block.getArgument(argIdx);
+ rewriter.replaceAllUsesWith(argVal, commonValue);
+ }
+ }
+
+ // Remove the arguments
+ for (auto argIdx : llvm::reverse(argsToErase)) {
+ block.eraseArgument(argIdx);
+
+ // Remove the argument from the branch ops
+ for (auto predIt = block.pred_begin(), predE = block.pred_end();
+ predIt != predE; ++predIt) {
+ auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+ unsigned succIndex = predIt.getSuccessorIndex();
+ SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+ succOperands.erase(argIdx);
+ }
+ }
+ return success(!argsToErase.empty());
+}
+
+/// This optimization drops redundant argument to blocks. I.e., if a given
+/// argument to a block receives the same value from each of the block
+/// predecessors, we can remove the argument from the block and use directly the
+/// original value. This is a simple example:
+///
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
+/// : i64)
+///
+/// ^bb1(%arg0 : i64, %arg1 : i64):
+/// llvm.call @foo(%arg0, %arg1)
+///
+/// The previous IR can be rewritten as:
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+///
+/// ^bb1(%arg0 : i64):
+/// llvm.call @foo(%val0, %arg0)
+///
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+ MutableArrayRef<Region> regions) {
+ llvm::SmallSetVector<Region *, 1> worklist;
+ for (auto ®ion : regions)
+ worklist.insert(®ion);
+ bool anyChanged = false;
+ while (!worklist.empty()) {
+ Region *region = worklist.pop_back_val();
+
+ // Add any nested regions to the worklist.
+ for (Block &block : *region) {
+ anyChanged = succeeded(dropRedundantArguments(rewriter, block));
+
+ for (auto &op : block)
+ for (auto &nestedRegion : op.getRegions())
+ worklist.insert(&nestedRegion);
+ }
+ }
+ return success(anyChanged);
+}
+
//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//
@@ -832,8 +1000,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
bool mergedIdenticalBlocks = false;
- if (mergeBlocks)
+ bool droppedRedundantArguments = false;
+ if (mergeBlocks) {
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+ droppedRedundantArguments =
+ succeeded(dropRedundantArguments(rewriter, regions));
+ }
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
- mergedIdenticalBlocks);
+ mergedIdenticalBlocks || droppedRedundantArguments);
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 5e8104f83cc4d4..8e14990502143e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
// CHECK-NEXT: ^bb1
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
+// CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
// CHECK: ^bb2([[IDX:%.*]]:{{.*}})
// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
// CHECK-NEXT: test.buffer_based
@@ -186,20 +186,24 @@ func.func @condBranchDynamicTypeNested(
// CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
+// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
// CHECK-NEXT: ^bb3:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
-// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+// CHECK-NEXT: ^bb4:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
-// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+// CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
+// CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
// CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
-// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+// CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
+// CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
// CHECK: test.copy
// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a89226fdb58f..50a2d6bf532aa3 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @main
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
// CHECK: return %[[ELEMENTS]] : tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 8d17763c04b6c4..c728ad21d2209b 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,18 +42,15 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br
+// CHECK-NEXT: ^[[bb1:.*]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT: ^[[bb2]]
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
// -----
@@ -106,20 +103,17 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: cf.br ^[[bb4:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb4]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br ^[[bb1:.*]]
+// CHECK-NEXT: ^[[bb1:.*]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT: ^[[bb2]]:
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: cf.br ^[[bb4:.*]]
+// CHECK-NEXT: ^[[bb4]]:
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
// -----
@@ -171,16 +165,13 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<10>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br ^[[bb1:.*]]
+// CHECK-NEXT: ^[[bb1]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
+// CHECK-NEXT: ^[[bb2]]
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index aa30900f76a334..580a97d3a851ba 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
// DET-ALL: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// DET-ALL: ^[[bb1]](%{{.*}}: i32)
// DET-ALL: arith.cmpi slt, {{.*}}
-// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL: ^[[bb2]](%{{.*}}: i32)
+// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL: ^[[bb2]]
// DET-ALL: arith.addi {{.*}}
// DET-ALL: cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL: ^[[bb3]](%{{.*}}: i32)
+// DET-ALL: ^[[bb3]]:
// DET-ALL: tensor.from_elements {{.*}}
// DET-ALL: return %{{.*}} : tensor<i32>
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
// DET-CF: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// DET-CF: ^[[bb1]](%{{.*}}: i32)
// DET-CF: arith.cmpi slt, {{.*}}
-// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-CF: ^[[bb2]](%{{.*}}: i32)
+// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-CF: ^[[bb2]]:
// DET-CF: arith.addi {{.*}}
// DET-CF: cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-CF: ^[[bb3]](%{{.*}}: i32)
+// DET-CF: ^[[bb3]]:
// DET-CF: tensor.from_elements %{{.*}} : tensor<i32>
// DET-CF: return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 955c7be5ef4c89..414d9b94cbf530 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -74,8 +74,8 @@ func.fun...
[truncated]
|
Hi @joker-eph , @Mogball ,
We were forming argument lists of different lengths I compiled with Thanks! |
1f153c3
to
c694caa
Compare
c694caa
to
249a387
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some comments on the coding style, I haven't reviewed the logic
Thanks for the review @kuhar ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks fine to me but I haven't had cycles to review the logic in full. You may need to find another reviewer for that.
Hi @krzysz00 , would you be able to have another look? Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is fine now
Thanks @krzysz00! |
) Bumps llvm-project to: llvm/llvm-project@15495b8 - Updated `refine_usage.mlir` and `flatten_tuples_in_cfg.mlir` tests likely due to llvm/llvm-project#97697 Still carrying revert: 97c0dbe1ad6dacbcca84e63e9d726b85b65af4fe (TODO: bump torch-mlir and update to bumped submodule) --------- Signed-off-by: aviator19941 <avinash.sharma@amd.com>
This seems to contain a bug. After this patch, the following IR example explodes when running canonicalizations with aggressive region simplifications:
It seems that there is a type mismatch between branch operands and block arguments. |
Opened a revert PR, as this breaks downstream projects and a forward fix does not seem obvious to me. |
Reverts #97697 This commit introduced non-trivial bugs related to type consistency.
Summary: With this PR I am trying to address: #63230. What changed: - While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same `Value`. The operations operands in the block will point to the argument we already inserted. This needs to happen to all the arguments we pass to the different successors of the parent block - After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument. - This last simplification clashed with `BufferDeallocationSimplification`. The reason, I think, is that the two simplifications are clashing. I.e., `BufferDeallocationSimplification` contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass. **Note**: this a rework of #96871 . I ran all the integration tests (`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250916
Summary: Reverts #97697 This commit introduced non-trivial bugs related to type consistency. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250716
…e-org#17946) Bumps llvm-project to: llvm/llvm-project@15495b8 - Updated `refine_usage.mlir` and `flatten_tuples_in_cfg.mlir` tests likely due to llvm/llvm-project#97697 Still carrying revert: 97c0dbe1ad6dacbcca84e63e9d726b85b65af4fe (TODO: bump torch-mlir and update to bumped submodule) --------- Signed-off-by: aviator19941 <avinash.sharma@amd.com> Signed-off-by: Lubo Litchev <lubol@google.com>
With this PR I am trying to address: #63230. What changed: - While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same `Value`. The operations operands in the block will point to the argument we already inserted. This needs to happen to all the arguments we pass to the different successors of the parent block - After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument. - This last simplification clashed with `BufferDeallocationSimplification`. The reason, I think, is that the two simplifications are clashing. I.e., `BufferDeallocationSimplification` contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass. **Note-1**: I ran all the integration tests (`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed. **Note-2**: I fixed a bug found by @Dinistro in #97697 . The issue was that, when looking for redundant arguments, I was not considering that the block might have already some arguments. So the index (in the block args list) of the i-th `newArgument` is `i+numOfOldArguments`.
With this PR I am trying to address: #63230.
What changed:
Value
. The operations operands in the block will point to the argument we already inserted. This needs to happen to all the arguments we pass to the different successors of the parent blockBufferDeallocationSimplification
. The reason, I think, is that the two simplifications are clashing. I.e.,BufferDeallocationSimplification
contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass.Note: this a rework of #96871 . I ran all the integration tests (
-DMLIR_INCLUDE_INTEGRATION_TESTS=ON
) and they passed.