Skip to content

[mlir][IntRangeAnalysis] Handle unstructured loop arguments correctly #119459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
public:
using Lattice::Lattice;

/// Override the join logic so that arguments to non-entry blocks
/// whose arguments come from later in the program get set to
/// a maximal value so that we don't prematurely declare code to be
/// dead.
ChangeResult join(const AbstractSparseLattice &rhs) override;

/// Join a concrete `range` into this lattice element by forwarding to the
/// base-class implementation. Declared here so the value-taking overload
/// stays callable alongside the `AbstractSparseLattice` override above.
ChangeResult join(const IntegerValueRange &range) { return Lattice::join(range); }

/// If the range can be narrowed to an integer constant, update the constant
/// value of the SSA value.
void onUpdate(DataFlowSolver *solver) const override;
Expand Down
46 changes: 46 additions & 0 deletions mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,50 @@
using namespace mlir;
using namespace mlir::dataflow;

/// Heuristic for spotting loop headers in unstructured control flow.
///
/// Returns true when `block` is a non-entry block that has at least two
/// distinct predecessors and at least one of them is not listed before
/// `block` in the parent region's block list (i.e. it is `block` itself or a
/// block that appears later). Returns false for a null block, the entry block,
/// or a block without a parent region.
static bool isLoopLikeBlock(Block *block) {
  if (!block || block->isEntryBlock())
    return false;
  Region *region = block->getParent();
  if (!region)
    return false;

  // Collect the distinct predecessors; a block reached from a single place
  // is never treated as loop-like.
  SmallPtrSet<Block *, 4> laterPreds;
  for (Block *pred : block->getPredecessors())
    laterPreds.insert(pred);
  if (laterPreds.size() < 2)
    return false;

  // Discard every predecessor that precedes `block` in region order. Whatever
  // survives is either `block` itself or a block from further in the program.
  for (auto it = region->begin(), end = region->end();
       it != end && &*it != block; ++it)
    laterPreds.erase(&*it);

  return !laterPreds.empty();
}

ChangeResult IntegerValueRangeLattice::join(const AbstractSparseLattice &rhs) {
  Value anchor = getAnchor();

  // The widening below only applies to block arguments with a non-zero
  // storage bitwidth that live in a loop-like block; everything else takes
  // the ordinary lattice join.
  bool hasIntegerStorage =
      ConstantIntRanges::getStorageBitwidth(anchor.getType()) > 0;
  bool isLoopCarried = hasIntegerStorage && isa<BlockArgument>(anchor) &&
                       isLoopLikeBlock(anchor.getParentBlock());
  if (!isLoopCarried)
    return Lattice::join(rhs);

  // Unstructured control flow: one of the incoming edges for this block
  // argument comes from a block at or after the argument's own block, so the
  // value may still change. Conservatively assume it could be anything.
  LLVM_DEBUG({
    llvm::dbgs() << "Found loop-varying block argument " << anchor << " from "
                 << rhs.getAnchor() << "\n";
    llvm::dbgs() << "Inferring maximum range\n";
  });
  return join(IntegerValueRange::getMaxRange(anchor));
}

void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
Lattice::onUpdate(solver);

Expand Down Expand Up @@ -206,6 +250,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
if (max.sge(min)) {
IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
auto ivRange = ConstantIntRanges::fromSigned(min, max);
LLVM_DEBUG(llvm::dbgs()
<< "Inferred loop bound range: " << ivRange << "\n");
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
}
return;
Expand Down
60 changes: 60 additions & 0 deletions mlir/test/Dialect/Arith/int-range-opts.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,63 @@ func.func @wraps() -> i8 {
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}

// -----

// Regression test: two nested loops expressed as unstructured control flow
// (cf.br / cf.cond_br). ^bb1 heads the outer loop and ^bb3 heads the inner
// loop; each receives its counter as a block argument with one incoming edge
// from a block that appears later in the function. The directives below
// require that no `arith.constant true` appears after the function label.
// Note: I wish I had a simpler example than this, but getting rid of a
// bunch of the arithmetic made the issue go away.
// CHECK-LABEL: @blocks_prematurely_declared_dead_bug
// CHECK-NOT: arith.constant true
func.func @blocks_prematurely_declared_dead_bug(%mem: memref<?xf16>) {
%cst = arith.constant dense<false> : vector<1xi1>
%c1 = arith.constant 1 : index
%cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16>
%cst_1 = arith.constant 0.000000e+00 : f16
%c16 = arith.constant 16 : index
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%thread_id_x = gpu.thread_id x upper_bound 64
// %6 is annotated with signed and unsigned bounds of [16, 112].
%6 = test.with_bounds { smin = 16 : index, smax = 112 : index, umin = 16 : index, umax = 112 : index } : index
// %9 = (%6 / 16) * 16; it is the bound compared against the outer counter.
%8 = arith.divui %6, %c16 : index
%9 = arith.muli %8, %c16 : index
cf.br ^bb1(%c0 : index)
// Outer loop header: %12 is 0 on entry and %12 + 64 on the edge from ^bb7.
^bb1(%12: index): // 2 preds: ^bb0, ^bb7
%13 = arith.cmpi slt, %12, %9 : index
cf.cond_br %13, ^bb2, ^bb8
^bb2: // pred: ^bb1
%14 = arith.subi %9, %12 : index
%15 = arith.minsi %14, %c64 : index
%16 = arith.subi %15, %thread_id_x : index
%17 = vector.constant_mask [1] : vector<1xi1>
%18 = arith.cmpi sgt, %16, %c0 : index
%19 = arith.select %18, %17, %cst : vector<1xi1>
%20 = vector.extract %19[0] : i1 from vector<1xi1>
%21 = vector.insert %20, %cst [0] : i1 into vector<1xi1>
%22 = arith.addi %12, %thread_id_x : index
cf.br ^bb3(%c0, %cst_0 : index, vector<1xf16>)
// Inner loop header: %23 is 0 on entry from ^bb2 and %23 + 1 on the edge
// from ^bb6.
^bb3(%23: index, %24: vector<1xf16>): // 2 preds: ^bb2, ^bb6
%25 = arith.cmpi slt, %23, %c1 : index
cf.cond_br %25, ^bb4, ^bb7
^bb4: // pred: ^bb3
%26 = vector.extractelement %21[%23 : index] : vector<1xi1>
cf.cond_br %26, ^bb5, ^bb6(%24 : vector<1xf16>)
^bb5: // pred: ^bb4
%27 = arith.addi %22, %23 : index
%28 = memref.load %mem[%27] : memref<?xf16>
%29 = vector.insertelement %28, %24[%23 : index] : vector<1xf16>
cf.br ^bb6(%29 : vector<1xf16>)
// Inner loop latch: advance %23 by 1 and branch back to ^bb3.
^bb6(%30: vector<1xf16>): // 2 preds: ^bb4, ^bb5
%31 = arith.addi %23, %c1 : index
cf.br ^bb3(%31, %30 : index, vector<1xf16>)
// Outer loop latch: advance %12 by 64 and branch back to ^bb1.
^bb7: // pred: ^bb3
%37 = arith.addi %12, %c64 : index
cf.br ^bb1(%37 : index)
// Reached when the outer loop condition %13 is false.
^bb8: // pred: ^bb1
%70 = arith.cmpi eq, %thread_id_x, %c0 : index
cf.cond_br %70, ^bb9, ^bb10
^bb9: // pred: ^bb8
memref.store %cst_1, %mem[%c0] : memref<?xf16>
cf.br ^bb10
^bb10: // 2 preds: ^bb8, ^bb9
return
}
Loading