Revert "[mlir] Fix block merging" #100510

Merged Jul 25, 2024 (1 commit)

Dinistro
Contributor

Reverts #97697

This commit introduced non-trivial bugs related to type consistency.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:linalg mlir mlir:bufferization Bufferization infrastructure labels Jul 25, 2024
@Dinistro Dinistro requested review from giuseros and removed request for matthias-springer July 25, 2024 05:43
@llvmbot
Member

llvmbot commented Jul 25, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Christian Ulmann (Dinistro)

Changes

Reverts llvm/llvm-project#97697

This commit introduced non-trivial bugs related to type consistency.


Patch is 33.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/100510.diff

12 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+2-7)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+2-202)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir (+8-12)
  • (modified) mlir/test/Dialect/Linalg/detensorize_entry_block.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/detensorize_if.mlir (+38-29)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir (+2-2)
  • (modified) mlir/test/Transforms/canonicalize-block-merge.mlir (+3-3)
  • (modified) mlir/test/Transforms/canonicalize-dce.mlir (+4-4)
  • (modified) mlir/test/Transforms/make-isolated-from-above.mlir (+9-9)
  • (removed) mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir (-162)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 5227b22653eef..954485cfede3d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,15 +463,10 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
-    // We don't want that the block structure changes invalidating the
-    // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
-    // region simplification
-    GreedyRewriteConfig config;
-    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
-                                            config)))
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 946d65cef4186..4c0f15bafbaba 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -17,15 +16,11 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallSet.h"
 
 #include <deque>
-#include <iterator>
 
 using namespace mlir;
 
@@ -679,91 +674,6 @@ static bool ableToUpdatePredOperands(Block *block) {
   return true;
 }
 
-/// Prunes the redundant list of arguments. E.g., if we are passing an argument
-/// list like [x, y, z, x] this would return [x, y, z] and it would update the
-/// `block` (to whom the argument are passed to) accordingly.
-static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
-    const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
-    RewriterBase &rewriter, Block *block) {
-
-  SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
-      newArguments.size(), SmallVector<Value, 8>());
-
-  if (newArguments.empty())
-    return newArguments;
-
-  // `newArguments` is a 2D array of size `numLists` x `numArgs`
-  unsigned numLists = newArguments.size();
-  unsigned numArgs = newArguments[0].size();
-
-  // Map that for each arg index contains the index that we can use in place of
-  // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
-  // idxToReplacement[3] = 0
-  llvm::DenseMap<unsigned, unsigned> idxToReplacement;
-
-  // This is a useful data structure to track the first appearance of a Value
-  // on a given list of arguments
-  DenseMap<Value, unsigned> firstValueToIdx;
-  for (unsigned j = 0; j < numArgs; ++j) {
-    Value newArg = newArguments[0][j];
-    if (!firstValueToIdx.contains(newArg))
-      firstValueToIdx[newArg] = j;
-  }
-
-  // Go through the first list of arguments (list 0).
-  for (unsigned j = 0; j < numArgs; ++j) {
-    bool shouldReplaceJ = false;
-    unsigned replacement = 0;
-    // Look back to see if there are possible redundancies in list 0. Please
-    // note that we are using a map to annotate when an argument was seen first
-    // to avoid a O(N^2) algorithm. This has the drawback that if we have two
-    // lists like:
-    // list0: [%a, %a, %a]
-    // list1: [%c, %b, %b]
-    // We cannot simplify it, because firstVlaueToIdx[%a] = 0, but we cannot
-    // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c).  However, since
-    // the number of arguments can be potentially unbounded we cannot afford a
-    // O(N^2) algorithm (to search to all the possible pairs) and we need to
-    // accept the trade-off.
-    unsigned k = firstValueToIdx[newArguments[0][j]];
-    if (k != j) {
-      shouldReplaceJ = true;
-      replacement = k;
-      // If a possible redundancy is found, then scan the other lists: we
-      // can prune the arguments if and only if they are redundant in every
-      // list.
-      for (unsigned i = 1; i < numLists; ++i)
-        shouldReplaceJ =
-            shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
-    }
-    // Save the replacement.
-    if (shouldReplaceJ)
-      idxToReplacement[j] = replacement;
-  }
-
-  // Populate the pruned argument list.
-  for (unsigned i = 0; i < numLists; ++i)
-    for (unsigned j = 0; j < numArgs; ++j)
-      if (!idxToReplacement.contains(j))
-        newArgumentsPruned[i].push_back(newArguments[i][j]);
-
-  // Replace the block's redundant arguments.
-  SmallVector<unsigned> toErase;
-  for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
-    if (idxToReplacement.contains(idx)) {
-      Value oldArg = block->getArgument(idx);
-      Value newArg = block->getArgument(idxToReplacement[idx]);
-      rewriter.replaceAllUsesWith(oldArg, newArg);
-      toErase.push_back(idx);
-    }
-  }
-
-  // Erase the block's redundant arguments.
-  for (unsigned idxToErase : llvm::reverse(toErase))
-    block->eraseArgument(idxToErase);
-  return newArgumentsPruned;
-}
-
 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
   // Don't consider clusters that don't have blocks to merge.
   if (blocksToMerge.empty())
@@ -812,10 +722,6 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         }
       }
     }
-
-    // Prune redundant arguments and update the leader block argument list
-    newArguments = pruneRedundantArguments(newArguments, rewriter, leaderBlock);
-
     // Update the predecessors for each of the blocks.
     auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
       for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -912,108 +818,6 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
   return success(anyChanged);
 }
 
-static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
-                                            Block &block) {
-  SmallVector<size_t> argsToErase;
-
-  // Go through the arguments of the block.
-  for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
-    bool sameArg = true;
-    Value commonValue;
-
-    // Go through the block predecessor and flag if they pass to the block
-    // different values for the same argument.
-    for (auto predIt = block.pred_begin(), predE = block.pred_end();
-         predIt != predE; ++predIt) {
-      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
-      if (!branch) {
-        sameArg = false;
-        break;
-      }
-      unsigned succIndex = predIt.getSuccessorIndex();
-      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
-      auto branchOperands = succOperands.getForwardedOperands();
-      if (!commonValue) {
-        commonValue = branchOperands[argIdx];
-      } else {
-        if (branchOperands[argIdx] != commonValue) {
-          sameArg = false;
-          break;
-        }
-      }
-    }
-
-    // If they are passing the same value, drop the argument.
-    if (commonValue && sameArg) {
-      argsToErase.push_back(argIdx);
-
-      // Remove the argument from the block.
-      rewriter.replaceAllUsesWith(blockOperand, commonValue);
-    }
-  }
-
-  // Remove the arguments.
-  for (auto argIdx : llvm::reverse(argsToErase)) {
-    block.eraseArgument(argIdx);
-
-    // Remove the argument from the branch ops.
-    for (auto predIt = block.pred_begin(), predE = block.pred_end();
-         predIt != predE; ++predIt) {
-      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
-      unsigned succIndex = predIt.getSuccessorIndex();
-      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
-      succOperands.erase(argIdx);
-    }
-  }
-  return success(!argsToErase.empty());
-}
-
-/// This optimization drops redundant argument to blocks. I.e., if a given
-/// argument to a block receives the same value from each of the block
-/// predecessors, we can remove the argument from the block and use directly the
-/// original value. This is a simple example:
-///
-/// %cond = llvm.call @rand() : () -> i1
-/// %val0 = llvm.mlir.constant(1 : i64) : i64
-/// %val1 = llvm.mlir.constant(2 : i64) : i64
-/// %val2 = llvm.mlir.constant(3 : i64) : i64
-/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
-/// : i64)
-///
-/// ^bb1(%arg0 : i64, %arg1 : i64):
-///    llvm.call @foo(%arg0, %arg1)
-///
-/// The previous IR can be rewritten as:
-/// %cond = llvm.call @rand() : () -> i1
-/// %val0 = llvm.mlir.constant(1 : i64) : i64
-/// %val1 = llvm.mlir.constant(2 : i64) : i64
-/// %val2 = llvm.mlir.constant(3 : i64) : i64
-/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
-///
-/// ^bb1(%arg0 : i64):
-///    llvm.call @foo(%val0, %arg0)
-///
-static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
-                                            MutableArrayRef<Region> regions) {
-  llvm::SmallSetVector<Region *, 1> worklist;
-  for (Region &region : regions)
-    worklist.insert(&region);
-  bool anyChanged = false;
-  while (!worklist.empty()) {
-    Region *region = worklist.pop_back_val();
-
-    // Add any nested regions to the worklist.
-    for (Block &block : *region) {
-      anyChanged = succeeded(dropRedundantArguments(rewriter, block));
-
-      for (Operation &op : block)
-        for (Region &nestedRegion : op.getRegions())
-          worklist.insert(&nestedRegion);
-    }
-  }
-  return success(anyChanged);
-}
-
 //===----------------------------------------------------------------------===//
 // Region Simplification
 //===----------------------------------------------------------------------===//
@@ -1028,12 +832,8 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
   bool mergedIdenticalBlocks = false;
-  bool droppedRedundantArguments = false;
-  if (mergeBlocks) {
+  if (mergeBlocks)
     mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
-    droppedRedundantArguments =
-        succeeded(dropRedundantArguments(rewriter, regions));
-  }
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
-                 mergedIdenticalBlocks || droppedRedundantArguments);
+                 mergedIdenticalBlocks);
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 8e14990502143..5e8104f83cc4d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: ^bb1
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
+//       CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
 //       CHECK: ^bb2([[IDX:%.*]]:{{.*}})
 //       CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
 //  CHECK-NEXT: test.buffer_based
@@ -186,24 +186,20 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
+//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
 //  CHECK-NEXT: ^bb3:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
-//  CHECK-NEXT: ^bb4:
+//       CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
-//  CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
-//   CHECK-NOT: bufferization.dealloc
-//   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
-//  CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
+//  CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
 //  CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
 //  CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-//       CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
-//  CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
+//  CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
 //       CHECK: test.copy
 //       CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
 //  CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index 50a2d6bf532aa..d1a89226fdb58 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @main
 // CHECK-SAME:       (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
 // CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}
-// CHECK: ^{{.*}}:
-// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
+// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
+// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
 // CHECK:   return %[[ELEMENTS]] : tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index c728ad21d2209..8d17763c04b6c 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,15 +42,18 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
-// CHECK-DAG:     arith.constant true
-// CHECK:         cf.br
-// CHECK-NEXT:   ^[[bb1:.*]]:
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
-// CHECK-NEXT:   ^[[bb2]]
-// CHECK-NEXT:     cf.br ^[[bb3:.*]]
-// CHECK-NEXT:   ^[[bb3]]
-// CHECK-NEXT:     return %[[cst]]
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
+// CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
 
 // -----
@@ -103,17 +106,20 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
-// CHECK-DAG:     arith.constant true
-// CHECK:         cf.br ^[[bb1:.*]]
-// CHECK-NEXT:   ^[[bb1:.*]]:
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
-// CHECK-NEXT:   ^[[bb2]]:
-// CHECK-NEXT:     cf.br ^[[bb3:.*]]
-// CHECK-NEXT:   ^[[bb3]]:
-// CHECK-NEXT:     cf.br ^[[bb4:.*]]
-// CHECK-NEXT:   ^[[bb4]]:
-// CHECK-NEXT:     return %[[cst]]
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     cf.br ^[[bb4:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb4]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
+// CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
 
 // -----
@@ -165,13 +171,16 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<10>
-// CHECK-DAG:     arith.constant true
-// CHECK:         cf.br ^[[bb1:.*]]
-// CHECK-NEXT:   ^[[bb1]]:
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
-// CHECK-NEXT:   ^[[bb2]]
-// CHECK-NEXT:     cf.br ^[[bb3:.*]]
-// CHECK-NEXT:   ^[[bb3]]
-// CHECK-NEXT:     return %[[cst]]
+// CHECK-DAG:     arith.constant 0
+// CHECK-DAG:     arith.constant 10
+// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
+// CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index 580a97d3a851b..aa30900f76a33 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-ALL:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb1]](%{{.*}}: i32)
 // DET-ALL:         arith.cmpi slt, {{.*}}
-// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
-// DET-ALL:       ^[[bb2]]
+// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
 // DET-ALL:         arith.addi {{.*}}
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb3]]:
+// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
 // DET-ALL:         tensor.from_elements {{.*}}
 // DET-ALL:         return %{{.*}} : tensor<i32>
 
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-CF:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-CF:       ^[[bb1]](%{{.*}}: i32)
 // DET-CF:         arith.cmpi slt, {{.*}}
-// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
-// DET-CF:       ^[[bb2]]:
+// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-CF:       ^[[bb2]](%{{.*}}: i32)
 // DET-CF:         arith.addi {{.*}}
 // DET-CF:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-CF:       ^[[bb3]]:
+// DET-CF:       ^[[bb3]](%{{.*}}: i32)
 // DET-CF:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-CF:         return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 414d9b94cbf53..955c7be5ef4c8 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -74,8 +74,8 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
 // DET-ALL:         } -> tensor<i32>
 // DET-ALL:    ...
[truncated]
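
For readers without the MLIR tree at hand, the column-pruning idea behind the removed `pruneRedundantArguments` helper can be sketched in standalone C++. This is only an illustration under stated assumptions: values are modeled as plain strings instead of `mlir::Value`, the name `pruneRedundantColumns` and the `ArgList` alias are hypothetical, all lists are assumed to have the same length, and the block-argument rewriting done through the rewriter is left out. It models values only, not their types, so it does not reproduce the type mismatch that motivated this revert.

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for one list of forwarded operands.
using ArgList = std::vector<std::string>;

// Given one argument list per merged block (all assumed to be the same
// length), drop every column j that duplicates an earlier column k in
// every list. Column k is the first position in list 0 that holds the
// same value as position j.
std::vector<ArgList> pruneRedundantColumns(const std::vector<ArgList> &lists) {
  if (lists.empty())
    return lists;

  const std::size_t numArgs = lists[0].size();

  // First index at which each value appears in list 0. `emplace` does not
  // overwrite an existing key, so the earliest index is kept.
  std::unordered_map<std::string, std::size_t> firstIdx;
  for (std::size_t j = 0; j < numArgs; ++j)
    firstIdx.emplace(lists[0][j], j);

  // A column is redundant only if it matches its candidate in all lists.
  std::vector<bool> redundant(numArgs, false);
  for (std::size_t j = 0; j < numArgs; ++j) {
    const std::size_t k = firstIdx.at(lists[0][j]);
    if (k == j)
      continue;
    bool sameEverywhere = true;
    for (std::size_t i = 1; i < lists.size(); ++i)
      sameEverywhere = sameEverywhere && (lists[i][k] == lists[i][j]);
    redundant[j] = sameEverywhere;
  }

  // Copy over the surviving columns.
  std::vector<ArgList> pruned(lists.size());
  for (std::size_t i = 0; i < lists.size(); ++i)
    for (std::size_t j = 0; j < numArgs; ++j)
      if (!redundant[j])
        pruned[i].push_back(lists[i][j]);
  return pruned;
}
```

With a single list `[x, y, z, x]` the sketch returns `[x, y, z]`, matching the example in the removed doc comment. With `list0 = [a, a, a]` and `list1 = [c, b, b]` nothing is pruned, because only the first occurrence in list 0 is tried as a replacement candidate; this is the trade-off the removed code comment accepts to avoid a quadratic search.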

Contributor

@gysit gysit left a comment

LGTM

I can reproduce as well using the following command line:

./bin/mlir-opt test.mlir -pass-pipeline="any(any(canonicalize{region-simplify=aggressive}))" 
test.mlir:10:5: error: type mismatch for bb argument #0 of successor #1
    llvm.cond_br %1, ^bb3(%4 : i1), ^bb4(%4 : i1)

@Dinistro Dinistro merged commit 6a5a64c into main Jul 25, 2024
12 checks passed
@Dinistro Dinistro deleted the revert-97697-improve_block_merging_2 branch July 25, 2024 08:42
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary:
Reverts #97697

This commit introduced non-trivial bugs related to type consistency.

Differential Revision: https://phabricator.intern.facebook.com/D60250716
@giuseros
Contributor

giuseros commented Aug 5, 2024

Hi @Dinistro, sorry, I was on holiday for the last 2 weeks. I will investigate why this is happening and I will come back with a new PR. Thanks!

@Dinistro
Contributor Author

Dinistro commented Aug 5, 2024

> Hi @Dinistro, sorry, I was on holiday for the last 2 weeks. I will investigate why this is happening and I will come back with a new PR. Thanks!

No worries, this can happen for such complicated changes. Good luck finding the issue ;)
