-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][IntRangeAnalysis] Handle unstructured loop arguments correctly #119459
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
The integer range analysis currently has a bug where, because of how it interacts with dead code analysis, it will sometimes declare code dead that isn't dead, becaues it hasn't seen the edge that loops an incremented value back to itself yet. This commit fixes the issue by overriding the join method on lattice values in order to detect these back-edges on non-entry blocks and then snapping the passed-around value to its maximum possible range, just like we do for loop-varying values in region control flow. Fixes llvm#119045
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) ChangesThe integer range analysis currently has a bug where, because of how it interacts with dead code analysis, it will sometimes declare code dead that isn't dead, becaues it hasn't seen the edge that loops an incremented value back to itself yet. This commit fixes the issue by overriding the join method on lattice values in order to detect these back-edges on non-entry blocks and then snapping the passed-around value to its maximum possible range, just like we do for loop-varying values in region control flow. Fixes #119045 Full diff: https://github.com/llvm/llvm-project/pull/119459.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index f99eae379596b6..464a47355b4207 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -31,6 +31,16 @@ class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
public:
using Lattice::Lattice;
+ /// Override the join logic so that arguments to non-entry blocks
+ /// whose arguments come from later in the program get set to
+ /// a maximal value so that we don't prematurely declare code to be
+ /// deade.
+ ChangeResult join(const AbstractSparseLattice &rhs) override;
+
+ ChangeResult join(const IntegerValueRange &range) {
+ return Lattice::join(range);
+ }
+
/// If the range can be narrowed to an integer constant, update the constant
/// value of the SSA value.
void onUpdate(DataFlowSolver *solver) const override;
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index a97e43708d9a37..a45fcee345e91d 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -37,6 +37,50 @@
using namespace mlir;
using namespace mlir::dataflow;
+/// Return true if `block` is a non-entry block with a predecessor that's
+/// defined after the block. This allows us to detect loop-varying values
+/// in unstructured control flow.
+static bool isLoopLikeBlock(Block *block) {
+ if (!block || block->isEntryBlock())
+ return false;
+ Region *parent = block->getParent();
+ if (!parent)
+ return false;
+
+ SmallPtrSet<Block *, 4> preds;
+ for (Block *pred : block->getPredecessors())
+ preds.insert(pred);
+ if (preds.size() <= 1)
+ return false;
+
+ for (Block ®ionBlock : parent->getBlocks()) {
+ if (®ionBlock == block)
+ break;
+ preds.erase(®ionBlock);
+ }
+
+ // The block loops back on itself or has an edge from further in the program.
+ return !preds.empty();
+}
+
+ChangeResult IntegerValueRangeLattice::join(const AbstractSparseLattice &rhs) {
+ Value lhsAnchor = getAnchor();
+ Block *lhsBlock = lhsAnchor.getParentBlock();
+ unsigned width = ConstantIntRanges::getStorageBitwidth(lhsAnchor.getType());
+ /// Special-case: we're in unstructured control flow and one of the
+ /// predecessors of this block argument is defined in a block that comes after
+ /// the argument. So we conservatively conclude that the value could be
+ /// anything.
+ if (width > 0 && isa<BlockArgument>(lhsAnchor) && isLoopLikeBlock(lhsBlock)) {
+ LLVM_DEBUG(llvm::dbgs() << "Found loop-varying block argument " << lhsAnchor
+ << " from " << rhs.getAnchor() << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "Inferring maximum range\n");
+ IntegerValueRange maxRange = IntegerValueRange::getMaxRange(lhsAnchor);
+ return join(maxRange);
+ }
+ return Lattice::join(rhs);
+}
+
void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
Lattice::onUpdate(solver);
@@ -206,6 +250,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
if (max.sge(min)) {
IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
auto ivRange = ConstantIntRanges::fromSigned(min, max);
+ LLVM_DEBUG(llvm::dbgs()
+ << "Inferred loop bound range: " << ivRange << "\n");
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
}
return;
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index ea5969a1002580..e312cf175f5b56 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -132,3 +132,63 @@ func.func @wraps() -> i8 {
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}
+
+// -----
+
+// Note: I wish I had a simpler example than this, but getting rid of a
+// bunch of the arithmetic made the issue go away.
+// CHECK-LABEL: @blocks_prematurely_declared_dead_bug
+// CHECK-NOT: arith.constant true
+func.func @blocks_prematurely_declared_dead_bug(%mem: memref<?xf16>) {
+ %cst = arith.constant dense<false> : vector<1xi1>
+ %c1 = arith.constant 1 : index
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16>
+ %cst_1 = arith.constant 0.000000e+00 : f16
+ %c16 = arith.constant 16 : index
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %thread_id_x = gpu.thread_id x upper_bound 64
+ %6 = test.with_bounds { smin = 16 : index, smax = 112 : index, umin = 16 : index, umax = 112 : index } : index
+ %8 = arith.divui %6, %c16 : index
+ %9 = arith.muli %8, %c16 : index
+ cf.br ^bb1(%c0 : index)
+^bb1(%12: index): // 2 preds: ^bb0, ^bb7
+ %13 = arith.cmpi slt, %12, %9 : index
+ cf.cond_br %13, ^bb2, ^bb8
+^bb2: // pred: ^bb1
+ %14 = arith.subi %9, %12 : index
+ %15 = arith.minsi %14, %c64 : index
+ %16 = arith.subi %15, %thread_id_x : index
+ %17 = vector.constant_mask [1] : vector<1xi1>
+ %18 = arith.cmpi sgt, %16, %c0 : index
+ %19 = arith.select %18, %17, %cst : vector<1xi1>
+ %20 = vector.extract %19[0] : i1 from vector<1xi1>
+ %21 = vector.insert %20, %cst [0] : i1 into vector<1xi1>
+ %22 = arith.addi %12, %thread_id_x : index
+ cf.br ^bb3(%c0, %cst_0 : index, vector<1xf16>)
+^bb3(%23: index, %24: vector<1xf16>): // 2 preds: ^bb2, ^bb6
+ %25 = arith.cmpi slt, %23, %c1 : index
+ cf.cond_br %25, ^bb4, ^bb7
+^bb4: // pred: ^bb3
+ %26 = vector.extractelement %21[%23 : index] : vector<1xi1>
+ cf.cond_br %26, ^bb5, ^bb6(%24 : vector<1xf16>)
+^bb5: // pred: ^bb4
+ %27 = arith.addi %22, %23 : index
+ %28 = memref.load %mem[%27] : memref<?xf16>
+ %29 = vector.insertelement %28, %24[%23 : index] : vector<1xf16>
+ cf.br ^bb6(%29 : vector<1xf16>)
+^bb6(%30: vector<1xf16>): // 2 preds: ^bb4, ^bb5
+ %31 = arith.addi %23, %c1 : index
+ cf.br ^bb3(%31, %30 : index, vector<1xf16>)
+^bb7: // pred: ^bb3
+ %37 = arith.addi %12, %c64 : index
+ cf.br ^bb1(%37 : index)
+^bb8: // pred: ^bb1
+ %70 = arith.cmpi eq, %thread_id_x, %c0 : index
+ cf.cond_br %70, ^bb9, ^bb10
+^bb9: // pred: ^bb8
+ memref.store %cst_1, %mem[%c0] : memref<?xf16>
+ cf.br ^bb10
+^bb10: // 2 preds: ^bb8, ^bb9
+ return
+}
|
Checking for specific CFG structures in what supposed to be a generic detaflow analysis doesn't look very robust, I need some more time to process this PR. |
I'm open to better suggestions if you've got them |
The integer range analysis currently has a bug where, because of how it interacts with dead code analysis, it will sometimes declare code dead that isn't dead, becaues it hasn't seen the edge that loops an incremented value back to itself yet.
This commit fixes the issue by overriding the join method on lattice values in order to detect these back-edges on non-entry blocks and then snapping the passed-around value to its maximum possible range, just like we do for loop-varying values in region control flow.
Fixes #119045