[mlir] Fix block merging #97697

Merged
merged 7 commits into from
Jul 17, 2024

giuseros (Contributor) commented Jul 4, 2024

With this PR I am trying to address: #63230.

What changed:

  • While merging identical blocks, don't add a block argument if it is "identical" to another block argument, i.e., if the two block arguments refer to the same Value. The operations' operands in the block will point to the argument we already inserted. This needs to happen for all the arguments we pass to the different successors of the parent block.
  • After merging the blocks, get rid of "unnecessary" arguments, i.e., if all the predecessors pass the same value for a given block argument, there is no need to pass it as an argument.
  • This last simplification clashed with BufferDeallocationSimplification. The reason, I think, is that BufferDeallocationSimplification contains an analysis based on the block structure: if we simplify the block structure (by merging blocks and/or dropping block arguments), the analysis becomes invalid. The solution I found is to do a more prudent simplification when running that pass.

Note: this is a rework of #96871. I ran all the integration tests (-DMLIR_INCLUDE_INTEGRATION_TESTS=ON) and they passed.
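
The first bullet can be illustrated without MLIR. The following is a minimal sketch, not the code in this patch: it assumes each SSA value is modelled by an integer id (`ValueId`, a hypothetical stand-in for `mlir::Value`) and that every block in the merge cluster contributes one operand list of the same length. The names `ArgLists`, `findDuplicates`, and `prune` are made up for the illustration. A position is pruned only if it repeats an earlier position in every list.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical stand-in for mlir::Value: one integer id per SSA value.
using ValueId = int;
// One operand list per block in the merge cluster (list 0 is the leader).
using ArgLists = std::vector<std::vector<ValueId>>;

// Maps every prunable position j to the earliest position k < j that carries
// the same value as j in *every* list.
std::map<std::size_t, std::size_t> findDuplicates(const ArgLists &lists) {
  std::map<std::size_t, std::size_t> toReplace;
  if (lists.empty())
    return toReplace;
  for (std::size_t j = 0; j < lists[0].size(); ++j) {
    for (std::size_t k = 0; k < j; ++k) {
      bool sameEverywhere = true;
      for (const auto &list : lists) {
        assert(list.size() == lists[0].size() && "lists must be parallel");
        sameEverywhere = sameEverywhere && list[k] == list[j];
      }
      if (sameEverywhere) {
        toReplace[j] = k;
        break; // The earliest match is never itself pruned.
      }
    }
  }
  return toReplace;
}

// Returns the lists with every duplicate position removed.
ArgLists prune(const ArgLists &lists) {
  const auto toReplace = findDuplicates(lists);
  ArgLists pruned(lists.size());
  for (std::size_t i = 0; i < lists.size(); ++i)
    for (std::size_t j = 0; j < lists[i].size(); ++j)
      if (!toReplace.count(j))
        pruned[i].push_back(lists[i][j]);
  return pruned;
}
```

For example, pruning the lists `[1, 2, 3, 1]` and `[4, 5, 6, 4]` drops the last position from both, whereas `[1, 2, 1]` and `[4, 5, 7]` are left untouched because the repetition only occurs in the first list.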

llvmbot added the mlir:core, mlir:linalg, mlir, and mlir:bufferization labels Jul 4, 2024
giuseros requested review from joker-eph and Mogball July 4, 2024 08:59
llvmbot (Member) commented Jul 4, 2024

@llvm/pr-subscribers-mlir-bufferization
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Giuseppe Rossini (giuseros)


Patch is 33.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97697.diff

12 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+7-2)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+180-8)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir (+12-8)
  • (modified) mlir/test/Dialect/Linalg/detensorize_entry_block.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/detensorize_if.mlir (+29-38)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir (+2-2)
  • (modified) mlir/test/Transforms/canonicalize-block-merge.mlir (+3-3)
  • (modified) mlir/test/Transforms/canonicalize-dce.mlir (+4-4)
  • (modified) mlir/test/Transforms/make-isolated-from-above.mlir (+9-9)
  • (added) mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir (+162)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485cfede3d..5227b22653eef 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
+    // We don't want the block structure to change, which would invalidate the
+    // `BufferOriginAnalysis`, so we apply the rewrites with a `Normal` level
+    // of region simplification.
+    GreedyRewriteConfig config;
+    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
-    if (failed(
-            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                            config)))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15bafbaba..8b742bada67b9 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 
 #include <deque>
+#include <iterator>
 
 using namespace mlir;
 
@@ -674,6 +679,64 @@ static bool ableToUpdatePredOperands(Block *block) {
   return true;
 }
 
+/// Prunes redundant arguments. E.g., if we are passing an argument list like
+/// [x, y, z, x], this would return [x, y, z] and update the `block` (to which
+/// the arguments are passed) accordingly.
+static void
+pruneRedundantArguments(SmallVector<SmallVector<Value, 8>, 2> &newArguments,
+                        RewriterBase &rewriter, Block *block) {
+  SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
+      newArguments.size(), SmallVector<Value, 8>());
+
+  if (!newArguments.empty()) {
+    llvm::DenseMap<unsigned, unsigned> toReplace;
+    // Go through the first list of arguments (list 0)
+    for (unsigned j = 0; j < newArguments[0].size(); j++) {
+      bool shouldReplaceJ = false;
+      unsigned replacement = 0;
+      // Look back to see if there are possible redundancies in
+      // list 0
+      for (unsigned k = 0; k < j; k++) {
+        if (newArguments[0][k] == newArguments[0][j]) {
+          shouldReplaceJ = true;
+          replacement = k;
+          // If a possible redundancy is found, then scan the other lists: we
+          // can prune the arguments if and only if they are redundant in every
+          // list
+          for (unsigned i = 1; i < newArguments.size(); i++)
+            shouldReplaceJ =
+                shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
+        }
+      }
+      // Save the replacement
+      if (shouldReplaceJ)
+        toReplace[j] = replacement;
+    }
+
+    // Populate the pruned argument list
+    for (unsigned i = 0; i < newArguments.size(); i++)
+      for (unsigned j = 0; j < newArguments[i].size(); j++)
+        if (!toReplace.contains(j))
+          newArgumentsPruned[i].push_back(newArguments[i][j]);
+
+    // Replace the block's redundant arguments
+    SmallVector<unsigned> toErase;
+    for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
+      if (toReplace.contains(idx)) {
+        Value oldArg = block->getArgument(idx);
+        Value newArg = block->getArgument(toReplace[idx]);
+        rewriter.replaceAllUsesWith(oldArg, newArg);
+        toErase.push_back(idx);
+      }
+    }
+
+    // Erase the block's redundant arguments
+    for (auto idxToErase : llvm::reverse(toErase))
+      block->eraseArgument(idxToErase);
+    newArguments = newArgumentsPruned;
+  }
+}
+
 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
   // Don't consider clusters that don't have blocks to merge.
   if (blocksToMerge.empty())
@@ -699,9 +762,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
       blockIterators.push_back(mergeBlock->begin());
 
     // Update each of the predecessor terminators with the new arguments.
-    SmallVector<SmallVector<Value, 8>, 2> newArguments(
-        1 + blocksToMerge.size(),
-        SmallVector<Value, 8>(operandsToMerge.size()));
+    SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
+                                                       SmallVector<Value, 8>());
     unsigned curOpIndex = 0;
     for (const auto &it : llvm::enumerate(operandsToMerge)) {
       unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,16 +774,19 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         Block::iterator &blockIter = blockIterators[i];
         std::advance(blockIter, nextOpOffset);
         auto &operand = blockIter->getOpOperand(it.value().second);
-        newArguments[i][it.index()] = operand.get();
-
+        Value operandVal = operand.get();
+        newArguments[i].push_back(operandVal);
         // Update the operand and insert an argument if this is the leader.
         if (i == 0) {
-          Value operandVal = operand.get();
           operand.set(leaderBlock->addArgument(operandVal.getType(),
                                                operandVal.getLoc()));
         }
       }
     }
+
+    // Prune redundant arguments and update the leader block argument list
+    pruneRedundantArguments(newArguments, rewriter, leaderBlock);
+
     // Update the predecessors for each of the blocks.
     auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
       for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -818,6 +883,109 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
   return success(anyChanged);
 }
 
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            Block &block) {
+  SmallVector<size_t> argsToErase;
+
+  // Go through the arguments of the block
+  for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+    bool sameArg = true;
+    Value commonValue;
+
+    // Go through the block predecessor and flag if they pass to the block
+    // different values for the same argument
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+      if (!branch) {
+        sameArg = false;
+        break;
+      }
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      auto operands = succOperands.getForwardedOperands();
+      if (!commonValue) {
+        commonValue = operands[argIdx];
+      } else {
+        if (operands[argIdx] != commonValue) {
+          sameArg = false;
+          break;
+        }
+      }
+    }
+
+    // If they are passing the same value, drop the argument
+    if (commonValue && sameArg) {
+      argsToErase.push_back(argIdx);
+
+      // Remove the argument from the block
+      Value argVal = block.getArgument(argIdx);
+      rewriter.replaceAllUsesWith(argVal, commonValue);
+    }
+  }
+
+  // Remove the arguments
+  for (auto argIdx : llvm::reverse(argsToErase)) {
+    block.eraseArgument(argIdx);
+
+    // Remove the argument from the branch ops
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      succOperands.erase(argIdx);
+    }
+  }
+  return success(!argsToErase.empty());
+}
+
+/// This optimization drops redundant arguments to blocks. I.e., if a given
+/// argument to a block receives the same value from each of the block's
+/// predecessors, we can remove the argument from the block and use the
+/// original value directly. This is a simple example:
+///
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
+/// : i64)
+///
+/// ^bb1(%arg0 : i64, %arg1 : i64):
+///    llvm.call @foo(%arg0, %arg1)
+///
+/// The previous IR can be rewritten as:
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+///
+/// ^bb1(%arg0 : i64):
+///    llvm.call @foo(%val0, %arg0)
+///
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            MutableArrayRef<Region> regions) {
+  llvm::SmallSetVector<Region *, 1> worklist;
+  for (auto &region : regions)
+    worklist.insert(&region);
+  bool anyChanged = false;
+  while (!worklist.empty()) {
+    Region *region = worklist.pop_back_val();
+
+    // Simplify each block and add any nested regions to the worklist.
+    for (Block &block : *region) {
+      anyChanged |= succeeded(dropRedundantArguments(rewriter, block));
+
+      for (auto &op : block)
+        for (auto &nestedRegion : op.getRegions())
+          worklist.insert(&nestedRegion);
+    }
+  }
+  return success(anyChanged);
+}
+
 //===----------------------------------------------------------------------===//
 // Region Simplification
 //===----------------------------------------------------------------------===//
@@ -832,8 +1000,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
   bool mergedIdenticalBlocks = false;
-  if (mergeBlocks)
+  bool droppedRedundantArguments = false;
+  if (mergeBlocks) {
     mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+    droppedRedundantArguments =
+        succeeded(dropRedundantArguments(rewriter, regions));
+  }
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
-                 mergedIdenticalBlocks);
+                 mergedIdenticalBlocks || droppedRedundantArguments);
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 5e8104f83cc4d..8e14990502143 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: ^bb1
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
+//       CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
 //       CHECK: ^bb2([[IDX:%.*]]:{{.*}})
 //       CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
 //  CHECK-NEXT: test.buffer_based
@@ -186,20 +186,24 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
+//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
 //  CHECK-NEXT: ^bb3:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
-//  CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb4:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
-//  CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+//   CHECK-NOT: bufferization.dealloc
+//   CHECK-NOT: bufferization.clone
+//       CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
+//  CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
 //  CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
 //  CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-//       CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
-//  CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
+//  CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
 //       CHECK: test.copy
 //       CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
 //  CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a89226fdb58..50a2d6bf532aa 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @main
 // CHECK-SAME:       (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
 // CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
 // CHECK:   return %[[ELEMENTS]] : tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 8d17763c04b6c..c728ad21d2209 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,18 +42,15 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br
+// CHECK-NEXT:   ^[[bb1:.*]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT:   ^[[bb2]]
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
 
 // -----
@@ -106,20 +103,17 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     cf.br ^[[bb4:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb4]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br ^[[bb1:.*]]
+// CHECK-NEXT:   ^[[bb1:.*]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT:   ^[[bb2]]:
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]:
+// CHECK-NEXT:     cf.br ^[[bb4:.*]]
+// CHECK-NEXT:   ^[[bb4]]:
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
 
 // -----
@@ -171,16 +165,13 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<10>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br ^[[bb1:.*]]
+// CHECK-NEXT:   ^[[bb1]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
+// CHECK-NEXT:   ^[[bb2]]
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index aa30900f76a33..580a97d3a851b 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-ALL:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb1]](%{{.*}}: i32)
 // DET-ALL:         arith.cmpi slt, {{.*}}
-// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
+// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL:       ^[[bb2]]
 // DET-ALL:         arith.addi {{.*}}
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
+// DET-ALL:       ^[[bb3]]:
 // DET-ALL:         tensor.from_elements {{.*}}
 // DET-ALL:         return %{{.*}} : tensor<i32>
 
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-CF:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-CF:       ^[[bb1]](%{{.*}}: i32)
 // DET-CF:         arith.cmpi slt, {{.*}}
-// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-CF:       ^[[bb2]](%{{.*}}: i32)
+// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-CF:       ^[[bb2]]:
 // DET-CF:         arith.addi {{.*}}
 // DET-CF:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-CF:       ^[[bb3]](%{{.*}}: i32)
+// DET-CF:       ^[[bb3]]:
 // DET-CF:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-CF:         return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 955c7be5ef4c8..414d9b94cbf53 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -74,8 +74,8 @@ func.func @main(%farg0...
[truncated]
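
The per-block step of `dropRedundantArguments` in the patch above can also be modelled in plain C++. This is a hedged sketch under simplifying assumptions, not the MLIR implementation: values are integer ids (`ValueId`), a block is reduced to the operand lists its predecessor edges forward (`BlockModel`), and every edge is assumed to forward one operand per block argument. The names `dropRedundant` and `dropRedundantInAll` are hypothetical. The second function ORs the per-block results together so that an unchanged block visited later does not hide an earlier change.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for mlir::Value: one integer id per SSA value.
using ValueId = int;

// Simplified block: forwarded[p][a] is the value that predecessor edge p
// passes for block argument a.
struct BlockModel {
  std::vector<std::vector<ValueId>> forwarded;
};

// Drops every argument for which all predecessor edges forward the same
// value. Returns true if at least one argument was dropped.
bool dropRedundant(BlockModel &block) {
  if (block.forwarded.empty())
    return false;
  const std::size_t numArgs = block.forwarded.front().size();
  std::vector<std::size_t> argsToErase;
  for (std::size_t arg = 0; arg < numArgs; ++arg) {
    const ValueId common = block.forwarded.front()[arg];
    bool sameArg = true;
    for (const auto &operands : block.forwarded) {
      assert(operands.size() == numArgs && "edges must forward every argument");
      sameArg = sameArg && operands[arg] == common;
    }
    if (sameArg)
      argsToErase.push_back(arg);
  }
  // Erase from the highest index down so earlier indices stay valid.
  for (auto it = argsToErase.rbegin(); it != argsToErase.rend(); ++it)
    for (auto &operands : block.forwarded)
      operands.erase(operands.begin() + static_cast<std::ptrdiff_t>(*it));
  return !argsToErase.empty();
}

// Applies the step to every block and accumulates the "changed" flag.
bool dropRedundantInAll(std::vector<BlockModel> &blocks) {
  bool anyChanged = false;
  for (BlockModel &block : blocks)
    anyChanged = dropRedundant(block) || anyChanged;
  return anyChanged;
}
```

With two edges forwarding `(10, 20)` and `(10, 30)`, the first argument is dropped and the edges are left forwarding `(20)` and `(30)`. In the real pass the uses of the dropped argument are also replaced by the common value, which this model does not track.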

llvmbot (Member) commented Jul 4, 2024

@llvm/pr-subscribers-mlir-core

Author: Giuseppe Rossini (giuseros)


Patch is 33.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97697.diff

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485cfede3da..5227b22653eefc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
+    // We don't want the block structure to change, which would invalidate the
+    // `BufferOriginAnalysis`, so we apply the rewrites with a `Normal` level
+    // of region simplification.
+    GreedyRewriteConfig config;
+    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
-    if (failed(
-            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                            config)))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15bafbaba3..8b742bada67b97 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 
 #include <deque>
+#include <iterator>
 
 using namespace mlir;
 
@@ -674,6 +679,64 @@ static bool ableToUpdatePredOperands(Block *block) {
   return true;
 }
 
+/// Prunes redundant arguments. E.g., if we are passing an argument list like
+/// [x, y, z, x], this would return [x, y, z] and update the `block` (to which
+/// the arguments are passed) accordingly.
+static void
+pruneRedundantArguments(SmallVector<SmallVector<Value, 8>, 2> &newArguments,
+                        RewriterBase &rewriter, Block *block) {
+  SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
+      newArguments.size(), SmallVector<Value, 8>());
+
+  if (!newArguments.empty()) {
+    llvm::DenseMap<unsigned, unsigned> toReplace;
+    // Go through the first list of arguments (list 0)
+    for (unsigned j = 0; j < newArguments[0].size(); j++) {
+      bool shouldReplaceJ = false;
+      unsigned replacement = 0;
+      // Look back to see if there are possible redundancies in
+      // list 0
+      for (unsigned k = 0; k < j; k++) {
+        if (newArguments[0][k] == newArguments[0][j]) {
+          shouldReplaceJ = true;
+          replacement = k;
+          // If a possible redundancy is found, then scan the other lists: we
+          // can prune the arguments if and only if they are redundant in every
+          // list
+          for (unsigned i = 1; i < newArguments.size(); i++)
+            shouldReplaceJ =
+                shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
+        }
+      }
+      // Save the replacement
+      if (shouldReplaceJ)
+        toReplace[j] = replacement;
+    }
+
+    // Populate the pruned argument list
+    for (unsigned i = 0; i < newArguments.size(); i++)
+      for (unsigned j = 0; j < newArguments[i].size(); j++)
+        if (!toReplace.contains(j))
+          newArgumentsPruned[i].push_back(newArguments[i][j]);
+
+    // Replace the block's redundant arguments
+    SmallVector<unsigned> toErase;
+    for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
+      if (toReplace.contains(idx)) {
+        Value oldArg = block->getArgument(idx);
+        Value newArg = block->getArgument(toReplace[idx]);
+        rewriter.replaceAllUsesWith(oldArg, newArg);
+        toErase.push_back(idx);
+      }
+    }
+
+    // Erase the block's redundant arguments
+    for (auto idxToErase : llvm::reverse(toErase))
+      block->eraseArgument(idxToErase);
+    newArguments = newArgumentsPruned;
+  }
+}
+
 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
   // Don't consider clusters that don't have blocks to merge.
   if (blocksToMerge.empty())
@@ -699,9 +762,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
       blockIterators.push_back(mergeBlock->begin());
 
     // Update each of the predecessor terminators with the new arguments.
-    SmallVector<SmallVector<Value, 8>, 2> newArguments(
-        1 + blocksToMerge.size(),
-        SmallVector<Value, 8>(operandsToMerge.size()));
+    SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
+                                                       SmallVector<Value, 8>());
     unsigned curOpIndex = 0;
     for (const auto &it : llvm::enumerate(operandsToMerge)) {
       unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,16 +774,19 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         Block::iterator &blockIter = blockIterators[i];
         std::advance(blockIter, nextOpOffset);
         auto &operand = blockIter->getOpOperand(it.value().second);
-        newArguments[i][it.index()] = operand.get();
-
+        Value operandVal = operand.get();
+        newArguments[i].push_back(operandVal);
         // Update the operand and insert an argument if this is the leader.
         if (i == 0) {
-          Value operandVal = operand.get();
           operand.set(leaderBlock->addArgument(operandVal.getType(),
                                                operandVal.getLoc()));
         }
       }
     }
+
+    // Prune redundant arguments and update the leader block argument list
+    pruneRedundantArguments(newArguments, rewriter, leaderBlock);
+
     // Update the predecessors for each of the blocks.
     auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
       for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -818,6 +883,109 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
   return success(anyChanged);
 }
 
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            Block &block) {
+  SmallVector<size_t> argsToErase;
+
+  // Go through the arguments of the block
+  for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+    bool sameArg = true;
+    Value commonValue;
+
+    // Go through the block's predecessors and flag the argument if they
+    // pass different values for it
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+      if (!branch) {
+        sameArg = false;
+        break;
+      }
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      auto operands = succOperands.getForwardedOperands();
+      if (!commonValue) {
+        commonValue = operands[argIdx];
+      } else {
+        if (operands[argIdx] != commonValue) {
+          sameArg = false;
+          break;
+        }
+      }
+    }
+
+    // If they are passing the same value, drop the argument
+    if (commonValue && sameArg) {
+      argsToErase.push_back(argIdx);
+
+      // Replace all uses of the argument with the common value
+      Value argVal = block.getArgument(argIdx);
+      rewriter.replaceAllUsesWith(argVal, commonValue);
+    }
+  }
+
+  // Remove the arguments
+  for (auto argIdx : llvm::reverse(argsToErase)) {
+    block.eraseArgument(argIdx);
+
+    // Remove the argument from the branch ops
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      succOperands.erase(argIdx);
+    }
+  }
+  return success(!argsToErase.empty());
+}
+
+/// This optimization drops redundant arguments to blocks. I.e., if a given
+/// argument to a block receives the same value from each of the block's
+/// predecessors, we can remove the argument from the block and use the
+/// original value directly. This is a simple example:
+///
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
+/// : i64)
+///
+/// ^bb1(%arg0 : i64, %arg1 : i64):
+///    llvm.call @foo(%arg0, %arg1)
+///
+/// The previous IR can be rewritten as:
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+///
+/// ^bb1(%arg0 : i64):
+///    llvm.call @foo(%val0, %arg0)
+///
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            MutableArrayRef<Region> regions) {
+  llvm::SmallSetVector<Region *, 1> worklist;
+  for (auto &region : regions)
+    worklist.insert(&region);
+  bool anyChanged = false;
+  while (!worklist.empty()) {
+    Region *region = worklist.pop_back_val();
+
+    for (Block &block : *region) {
+      // Accumulate so that a later unchanged block does not overwrite an
+      // earlier success.
+      anyChanged |= succeeded(dropRedundantArguments(rewriter, block));
+
+      // Add any nested regions to the worklist.
+      for (auto &op : block)
+        for (auto &nestedRegion : op.getRegions())
+          worklist.insert(&nestedRegion);
+    }
+  }
+  return success(anyChanged);
+}
+
 //===----------------------------------------------------------------------===//
 // Region Simplification
 //===----------------------------------------------------------------------===//
@@ -832,8 +1000,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
   bool mergedIdenticalBlocks = false;
-  if (mergeBlocks)
+  bool droppedRedundantArguments = false;
+  if (mergeBlocks) {
     mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+    droppedRedundantArguments =
+        succeeded(dropRedundantArguments(rewriter, regions));
+  }
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
-                 mergedIdenticalBlocks);
+                 mergedIdenticalBlocks || droppedRedundantArguments);
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 5e8104f83cc4d4..8e14990502143e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: ^bb1
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
+//       CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
 //       CHECK: ^bb2([[IDX:%.*]]:{{.*}})
 //       CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
 //  CHECK-NEXT: test.buffer_based
@@ -186,20 +186,24 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
+//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
 //  CHECK-NEXT: ^bb3:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
-//  CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb4:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
-//  CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+//   CHECK-NOT: bufferization.dealloc
+//   CHECK-NOT: bufferization.clone
+//       CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
+//  CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
 //  CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
 //  CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-//       CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
-//  CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
+//  CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
 //       CHECK: test.copy
 //       CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
 //  CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a89226fdb58f..50a2d6bf532aa3 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @main
 // CHECK-SAME:       (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
 // CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
 // CHECK:   return %[[ELEMENTS]] : tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 8d17763c04b6c4..c728ad21d2209b 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,18 +42,15 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br
+// CHECK-NEXT:   ^[[bb1:.*]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT:   ^[[bb2]]
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
 
 // -----
@@ -106,20 +103,17 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     cf.br ^[[bb4:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb4]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br ^[[bb1:.*]]
+// CHECK-NEXT:   ^[[bb1:.*]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT:   ^[[bb2]]:
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]:
+// CHECK-NEXT:     cf.br ^[[bb4:.*]]
+// CHECK-NEXT:   ^[[bb4]]:
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
 
 // -----
@@ -171,16 +165,13 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<10>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br ^[[bb1:.*]]
+// CHECK-NEXT:   ^[[bb1]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
+// CHECK-NEXT:   ^[[bb2]]
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index aa30900f76a334..580a97d3a851ba 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-ALL:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb1]](%{{.*}}: i32)
 // DET-ALL:         arith.cmpi slt, {{.*}}
-// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
+// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL:       ^[[bb2]]
 // DET-ALL:         arith.addi {{.*}}
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
+// DET-ALL:       ^[[bb3]]:
 // DET-ALL:         tensor.from_elements {{.*}}
 // DET-ALL:         return %{{.*}} : tensor<i32>
 
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-CF:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-CF:       ^[[bb1]](%{{.*}}: i32)
 // DET-CF:         arith.cmpi slt, {{.*}}
-// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-CF:       ^[[bb2]](%{{.*}}: i32)
+// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-CF:       ^[[bb2]]:
 // DET-CF:         arith.addi {{.*}}
 // DET-CF:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-CF:       ^[[bb3]](%{{.*}}: i32)
+// DET-CF:       ^[[bb3]]:
 // DET-CF:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-CF:         return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 955c7be5ef4c89..414d9b94cbf530 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -74,8 +74,8 @@ func.fun...
[truncated]
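
The `dropRedundantArguments` step in the diff above can be modeled outside MLIR. The following is a minimal sketch, not the PR's code: `Value` is a plain string standing in for an MLIR value, and `findRedundantArgs` and `eraseArgs` are hypothetical helper names introduced only for illustration. It assumes every predecessor forwards exactly one operand per block argument.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR Value: just a name.
using Value = std::string;
// One inner vector per predecessor: the operands it forwards to the block.
using ForwardedOperands = std::vector<std::vector<Value>>;

// Returns the indices of block arguments that receive the same value from
// every predecessor. A block without predecessors yields no candidates.
std::vector<std::size_t> findRedundantArgs(const ForwardedOperands &preds,
                                           std::size_t numArgs) {
  std::vector<std::size_t> redundant;
  if (preds.empty())
    return redundant;
  for (std::size_t argIdx = 0; argIdx < numArgs; ++argIdx) {
    bool same = true;
    for (const auto &operands : preds) {
      assert(operands.size() == numArgs && "operand count must match");
      if (operands[argIdx] != preds.front()[argIdx]) {
        same = false;
        break;
      }
    }
    if (same)
      redundant.push_back(argIdx);
  }
  return redundant;
}

// Erases the given positions from every predecessor's operand list. Indices
// are visited in reverse so that earlier erasures do not shift later ones.
void eraseArgs(ForwardedOperands &preds,
               const std::vector<std::size_t> &toErase) {
  for (auto it = toErase.rbegin(); it != toErase.rend(); ++it)
    for (auto &operands : preds)
      operands.erase(operands.begin() + static_cast<std::ptrdiff_t>(*it));
}
```

With two predecessors forwarding `(%val0, %val1)` and `(%val0, %val2)`, only position 0 is reported as redundant. The reverse iteration in `eraseArgs` mirrors the `llvm::reverse(argsToErase)` loop in the diff.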

@giuseros
Contributor Author

giuseros commented Jul 4, 2024

Hi @joker-eph , @Mogball ,
I reworked the code and fixed the bug. The problem showed up in a test like the following:

  llvm.cond_br %cond, ^bb1, ^bb2
^bb1:
  llvm.call @foo(%0) : (i64) -> ()
  llvm.call @foo(%0) : (i64) -> ()
  llvm.br ^bb3
^bb2:
  llvm.call @foo(%1) : (i64) -> ()
  llvm.call @foo(%2) : (i64) -> ()
  llvm.br ^bb3
^bb3:
  llvm.return

We were forming argument lists of different lengths, (%0) for bb1 and (%1, %2) for bb2, and this produced malformed IR. Now I do the argument pruning as a separate, later step and only prune an argument if it is redundant for every successor. I added this case (and a few more) as unit tests.

I compiled with -DMLIR_INCLUDE_INTEGRATION_TESTS=ON and everything passed. Is there any more testing I can enable?

Thanks!
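
The pruning rule described in this comment can be sketched in plain C++. This is an illustrative model under stated assumptions, not the code from the PR: values are plain strings, `findPrunable` and `prune` are hypothetical names, and, unlike the loop in the diff, the scan stops at the first earlier position that is redundant in every list.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR Value.
using Value = std::string;
// One argument list per merged block (list 0 belongs to the leader).
using ArgLists = std::vector<std::vector<Value>>;

// Maps each prunable position j to an earlier position k such that
// lists[i][k] == lists[i][j] holds for every list i.
std::map<std::size_t, std::size_t> findPrunable(const ArgLists &lists) {
  std::map<std::size_t, std::size_t> toReplace;
  if (lists.empty())
    return toReplace;
  const std::size_t n = lists.front().size();
  for (const auto &list : lists) {
    assert(list.size() == n && "all lists must have the same length");
    (void)list;
  }
  for (std::size_t j = 0; j < n; ++j) {
    for (std::size_t k = 0; k < j; ++k) {
      bool redundantEverywhere = true;
      for (const auto &list : lists) {
        if (list[k] != list[j]) {
          redundantEverywhere = false;
          break;
        }
      }
      if (redundantEverywhere) {
        toReplace[j] = k;
        break;
      }
    }
  }
  return toReplace;
}

// Builds the pruned lists by skipping every position marked for replacement.
ArgLists prune(const ArgLists &lists,
               const std::map<std::size_t, std::size_t> &toReplace) {
  ArgLists pruned(lists.size());
  for (std::size_t i = 0; i < lists.size(); ++i)
    for (std::size_t j = 0; j < lists[i].size(); ++j)
      if (toReplace.count(j) == 0)
        pruned[i].push_back(lists[i][j]);
  return pruned;
}
```

For the test case above, the lists are `(%0, %0)` and `(%1, %2)`. Position 1 duplicates position 0 only in the first list, so nothing is pruned and both lists keep the same length.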

@giuseros giuseros force-pushed the improve_block_merging_2 branch from 1f153c3 to c694caa Compare July 4, 2024 09:10
@giuseros giuseros force-pushed the improve_block_merging_2 branch from c694caa to 249a387 Compare July 4, 2024 09:11
Member

@kuhar kuhar left a comment

Just some comments on the coding style; I haven't reviewed the logic.

@kuhar changed the title from "Fix block merging" to "[mlir] Fix block merging" on Jul 8, 2024
@giuseros
Contributor Author

giuseros commented Jul 9, 2024

Thanks for the review @kuhar !

Member

@kuhar kuhar left a comment

Looks fine to me but I haven't had cycles to review the logic in full. You may need to find another reviewer for that.

@giuseros
Contributor Author

Hi @krzysz00 , would you be able to have another look? Thanks!

Contributor

@krzysz00 krzysz00 left a comment

I think this is fine now

@giuseros
Contributor Author

Thanks @krzysz00!

@giuseros giuseros merged commit c63125d into llvm:main Jul 17, 2024
7 checks passed
ScottTodd pushed a commit to iree-org/iree that referenced this pull request Jul 19, 2024
)

Bumps llvm-project to:
llvm/llvm-project@15495b8
- Updated `refine_usage.mlir` and `flatten_tuples_in_cfg.mlir` tests
likely due to llvm/llvm-project#97697

Still carrying revert: 97c0dbe1ad6dacbcca84e63e9d726b85b65af4fe

(TODO: bump torch-mlir and update
to bumped submodule)

---------

Signed-off-by: aviator19941 <avinash.sharma@amd.com>
@Dinistro
Contributor

Dinistro commented Jul 24, 2024

This seems to contain a bug. After this patch, the following IR example explodes when running canonicalizations with aggressive region simplifications:

  llvm.func @reproducer() {
    %0 = llvm.mlir.zero : !llvm.ptr
    %1 = llvm.mlir.constant(false) : i1
    %2 = llvm.mlir.constant(true) : i1
    llvm.cond_br %1, ^bb7(%0 : !llvm.ptr), ^bb1(%0 : !llvm.ptr)
  ^bb1(%3: !llvm.ptr):  // pred: ^bb0
    llvm.store %3, %0 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
    llvm.cond_br %1, ^bb2(%2 : i1), ^bb4(%2 : i1)
  ^bb2(%4: i1):  // 2 preds: ^bb1, ^bb7
    llvm.cond_br %1, ^bb3(%4 : i1), ^bb4(%4 : i1)
  ^bb3(%5: i1):  // pred: ^bb2
    llvm.br ^bb4(%5 : i1)
  ^bb4(%6: i1):  // 4 preds: ^bb1, ^bb2, ^bb3, ^bb7
    llvm.cond_br %6, ^bb6, ^bb5
  ^bb5:  // pred: ^bb4
    llvm.br ^bb6
  ^bb6:  // 2 preds: ^bb4, ^bb5
    llvm.br ^bb8
  ^bb7(%7: !llvm.ptr):  // pred: ^bb0
    llvm.store %7, %0 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
    llvm.cond_br %1, ^bb2(%1 : i1), ^bb4(%1 : i1)
  ^bb8:  // pred: ^bb6
    llvm.return
  }

It seems that there is a type mismatch between branch operands and block arguments.

Dinistro added a commit that referenced this pull request Jul 25, 2024
@Dinistro
Contributor

Opened a revert PR, as this breaks downstream projects and a forward fix does not seem obvious to me.

Dinistro added a commit that referenced this pull request Jul 25, 2024
Reverts #97697

This commit introduced non-trivial bugs related to type consistency.
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary:
With this PR I am trying to address:
#63230.

What changed:
- While merging identical blocks, don't add a block argument if it is
"identical" to another block argument, i.e., if the two block arguments
refer to the same `Value`. The operands of the operations in the block
will point to the argument we already inserted. This needs to happen for
all the arguments we pass to the different successors of the parent
block.
- After merging the blocks, get rid of "unnecessary" arguments. I.e., if
all the predecessors pass the same value for a block argument, there is
no need to pass it as an argument.
- This last simplification clashed with
`BufferDeallocationSimplification`. The reason, I think, is that
`BufferDeallocationSimplification` contains an analysis based on the
block structure. If we simplify the block structure (by merging and/or
dropping block arguments), the analysis becomes invalid. The solution I
found is to do a more prudent simplification when running that pass.

**Note**: this is a rework of #96871. I ran all the integration tests
(`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed.

Differential Revision: https://phabricator.intern.facebook.com/D60250916
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary:
Reverts #97697

This commit introduced non-trivial bugs related to type consistency.

Differential Revision: https://phabricator.intern.facebook.com/D60250716
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
…e-org#17946)

Bumps llvm-project to:
llvm/llvm-project@15495b8
- Updated `refine_usage.mlir` and `flatten_tuples_in_cfg.mlir` tests
likely due to llvm/llvm-project#97697

Still carrying revert: 97c0dbe1ad6dacbcca84e63e9d726b85b65af4fe

(TODO: bump torch-mlir and update
to bumped submodule)

---------

Signed-off-by: aviator19941 <avinash.sharma@amd.com>
Signed-off-by: Lubo Litchev <lubol@google.com>
@giuseros
Contributor Author

giuseros commented Aug 5, 2024

Hi @Dinistro , I re-created the PR (+the fix) here: #102038

giuseros added a commit that referenced this pull request Aug 7, 2024
With this PR I am trying to address:
#63230.

What changed:
- While merging identical blocks, don't add a block argument if it is
"identical" to another block argument, i.e., if the two block arguments
refer to the same `Value`. The operands of the operations in the block
will point to the argument we already inserted. This needs to happen for
all the arguments we pass to the different successors of the parent
block.
- After merging the blocks, get rid of "unnecessary" arguments. I.e., if
all the predecessors pass the same value for a block argument, there is
no need to pass it as an argument.
- This last simplification clashed with
`BufferDeallocationSimplification`. The reason, I think, is that
`BufferDeallocationSimplification` contains an analysis based on the
block structure. If we simplify the block structure (by merging and/or
dropping block arguments), the analysis becomes invalid. The solution I
found is to do a more prudent simplification when running that pass.

**Note-1**: I ran all the integration tests
(`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed.
**Note-2**: I fixed a bug found by @Dinistro in #97697. The issue was
that, when looking for redundant arguments, I was not considering that
the block might already have some arguments. So the index (in the block
args list) of the i-th `newArgument` is `i+numOfOldArguments`.
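
The index offset described in Note-2 can be illustrated with a small standalone model. This is a sketch under assumptions, not the fix itself: block arguments are represented only by their type names, and `blockIndexOfNewArg` and `replacementIsTypeConsistent` are hypothetical helpers. The argument types are chosen to echo the reproducer (`!llvm.ptr` and `i1`), but the scenario is illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical model of a block: only the argument types matter here.
struct Block {
  std::vector<std::string> argTypes;
};

// Position, in the block's argument list, of the j-th newly appended
// argument when the block already had numOldArgs arguments.
std::size_t blockIndexOfNewArg(std::size_t numOldArgs, std::size_t j) {
  return numOldArgs + j;
}

// Replacing one block argument with another is only sound if the two
// arguments have the same type.
bool replacementIsTypeConsistent(const Block &block, std::size_t from,
                                 std::size_t to) {
  assert(from < block.argTypes.size() && to < block.argTypes.size());
  return block.argTypes[from] == block.argTypes[to];
}
```

Consider a block with one pre-existing `!llvm.ptr` argument and two appended `i1` arguments, where new argument 1 should be replaced by new argument 0. Using the raw positions 1 and 0 would replace an `i1` with the pointer argument, which is the kind of type mismatch reported above. Using `blockIndexOfNewArg(1, 1)` and `blockIndexOfNewArg(1, 0)` selects the two `i1` arguments instead.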
Labels
mlir:bufferization Bufferization infrastructure mlir:core MLIR Core Infrastructure mlir:linalg mlir