Skip to content

[mlir][IntRangeAnalysis] Handle unstructured loop arguments correctly #119459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

krzysz00
Copy link
Contributor

The integer range analysis currently has a bug where, because of how it interacts with dead code analysis, it will sometimes declare code dead that isn't dead, becaues it hasn't seen the edge that loops an incremented value back to itself yet.

This commit fixes the issue by overriding the join method on lattice values in order to detect these back-edges on non-entry blocks and then snapping the passed-around value to its maximum possible range, just like we do for loop-varying values in region control flow.

Fixes #119045

The integer range analysis currently has a bug where, because of how
it interacts with dead code analysis, it will sometimes declare code
dead that isn't dead, becaues it hasn't seen the edge that loops an
incremented value back to itself yet.

This commit fixes the issue by overriding the join method on lattice
values in order to detect these back-edges on non-entry blocks and
then snapping the passed-around value to its maximum possible range,
just like we do for loop-varying values in region control flow.

Fixes llvm#119045
@llvmbot
Copy link
Member

llvmbot commented Dec 10, 2024

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

The integer range analysis currently has a bug where, because of how it interacts with dead code analysis, it will sometimes declare code dead that isn't dead, becaues it hasn't seen the edge that loops an incremented value back to itself yet.

This commit fixes the issue by overriding the join method on lattice values in order to detect these back-edges on non-entry blocks and then snapping the passed-around value to its maximum possible range, just like we do for loop-varying values in region control flow.

Fixes #119045


Full diff: https://github.com/llvm/llvm-project/pull/119459.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h (+10)
  • (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+46)
  • (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+60)
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index f99eae379596b6..464a47355b4207 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -31,6 +31,16 @@ class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
 public:
   using Lattice::Lattice;
 
+  /// Override the join logic so that arguments to non-entry blocks
+  /// whose arguments come from later in the program get set to
+  /// a maximal value so that we don't prematurely declare code to be
+  /// deade.
+  ChangeResult join(const AbstractSparseLattice &rhs) override;
+
+  ChangeResult join(const IntegerValueRange &range) {
+    return Lattice::join(range);
+  }
+
   /// If the range can be narrowed to an integer constant, update the constant
   /// value of the SSA value.
   void onUpdate(DataFlowSolver *solver) const override;
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index a97e43708d9a37..a45fcee345e91d 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -37,6 +37,50 @@
 using namespace mlir;
 using namespace mlir::dataflow;
 
+/// Return true if `block` is a non-entry block with a predecessor that's
+/// defined after the block. This allows us to detect loop-varying values
+/// in unstructured control flow.
+static bool isLoopLikeBlock(Block *block) {
+  if (!block || block->isEntryBlock())
+    return false;
+  Region *parent = block->getParent();
+  if (!parent)
+    return false;
+
+  SmallPtrSet<Block *, 4> preds;
+  for (Block *pred : block->getPredecessors())
+    preds.insert(pred);
+  if (preds.size() <= 1)
+    return false;
+
+  for (Block &regionBlock : parent->getBlocks()) {
+    if (&regionBlock == block)
+      break;
+    preds.erase(&regionBlock);
+  }
+
+  // The block loops back on itself or has an edge from further in the program.
+  return !preds.empty();
+}
+
+ChangeResult IntegerValueRangeLattice::join(const AbstractSparseLattice &rhs) {
+  Value lhsAnchor = getAnchor();
+  Block *lhsBlock = lhsAnchor.getParentBlock();
+  unsigned width = ConstantIntRanges::getStorageBitwidth(lhsAnchor.getType());
+  /// Special-case: we're in unstructured control flow and one of the
+  /// predecessors of this block argument is defined in a block that comes after
+  /// the argument. So we conservatively conclude that the value could be
+  /// anything.
+  if (width > 0 && isa<BlockArgument>(lhsAnchor) && isLoopLikeBlock(lhsBlock)) {
+    LLVM_DEBUG(llvm::dbgs() << "Found loop-varying block argument " << lhsAnchor
+                            << " from " << rhs.getAnchor() << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "Inferring maximum range\n");
+    IntegerValueRange maxRange = IntegerValueRange::getMaxRange(lhsAnchor);
+    return join(maxRange);
+  }
+  return Lattice::join(rhs);
+}
+
 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
   Lattice::onUpdate(solver);
 
@@ -206,6 +250,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
     if (max.sge(min)) {
       IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
       auto ivRange = ConstantIntRanges::fromSigned(min, max);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Inferred loop bound range: " << ivRange << "\n");
       propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
     }
     return;
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index ea5969a1002580..e312cf175f5b56 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -132,3 +132,63 @@ func.func @wraps() -> i8 {
   %mod = arith.remsi %val, %c64 : i8
   return %mod : i8
 }
+
+// -----
+
+// Note: I wish I had a simpler example than this, but getting rid of a
+// bunch of the arithmetic made the issue go away.
+// CHECK-LABEL: @blocks_prematurely_declared_dead_bug
+// CHECK-NOT: arith.constant true
+func.func @blocks_prematurely_declared_dead_bug(%mem: memref<?xf16>) {
+  %cst = arith.constant dense<false> : vector<1xi1>
+  %c1 = arith.constant 1 : index
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16>
+  %cst_1 = arith.constant 0.000000e+00 : f16
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %thread_id_x = gpu.thread_id  x upper_bound 64
+  %6 = test.with_bounds { smin = 16 : index, smax = 112 : index, umin = 16 : index, umax = 112 : index } : index
+  %8 = arith.divui %6, %c16 : index
+  %9 = arith.muli %8, %c16 : index
+  cf.br ^bb1(%c0 : index)
+^bb1(%12: index):  // 2 preds: ^bb0, ^bb7
+  %13 = arith.cmpi slt, %12, %9 : index
+  cf.cond_br %13, ^bb2, ^bb8
+^bb2:  // pred: ^bb1
+  %14 = arith.subi %9, %12 : index
+  %15 = arith.minsi %14, %c64 : index
+  %16 = arith.subi %15, %thread_id_x : index
+  %17 = vector.constant_mask [1] : vector<1xi1>
+  %18 = arith.cmpi sgt, %16, %c0 : index
+  %19 = arith.select %18, %17, %cst : vector<1xi1>
+  %20 = vector.extract %19[0] : i1 from vector<1xi1>
+  %21 = vector.insert %20, %cst [0] : i1 into vector<1xi1>
+  %22 = arith.addi %12, %thread_id_x : index
+  cf.br ^bb3(%c0, %cst_0 : index, vector<1xf16>)
+^bb3(%23: index, %24: vector<1xf16>):  // 2 preds: ^bb2, ^bb6
+  %25 = arith.cmpi slt, %23, %c1 : index
+  cf.cond_br %25, ^bb4, ^bb7
+^bb4:  // pred: ^bb3
+  %26 = vector.extractelement %21[%23 : index] : vector<1xi1>
+  cf.cond_br %26, ^bb5, ^bb6(%24 : vector<1xf16>)
+^bb5:  // pred: ^bb4
+  %27 = arith.addi %22, %23 : index
+  %28 = memref.load %mem[%27] : memref<?xf16>
+  %29 = vector.insertelement %28, %24[%23 : index] : vector<1xf16>
+  cf.br ^bb6(%29 : vector<1xf16>)
+^bb6(%30: vector<1xf16>):  // 2 preds: ^bb4, ^bb5
+  %31 = arith.addi %23, %c1 : index
+  cf.br ^bb3(%31, %30 : index, vector<1xf16>)
+^bb7:  // pred: ^bb3
+  %37 = arith.addi %12, %c64 : index
+  cf.br ^bb1(%37 : index)
+^bb8:  // pred: ^bb1
+  %70 = arith.cmpi eq, %thread_id_x, %c0 : index
+  cf.cond_br %70, ^bb9, ^bb10
+^bb9:  // pred: ^bb8
+  memref.store %cst_1, %mem[%c0] : memref<?xf16>
+  cf.br ^bb10
+^bb10:  // 2 preds: ^bb8, ^bb9
+  return
+}

@Hardcode84
Copy link
Contributor

Checking for specific CFG structures in what supposed to be a generic detaflow analysis doesn't look very robust, I need some more time to process this PR.

@krzysz00
Copy link
Contributor Author

I'm open to better suggestions if you've got them

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][Dataflow] Blocks getting prematurely declared dead during integer range optimization
3 participants