Skip to content

[mlir][sparse] partially support lowering sparse coiteration loops to scf.while/for. #105565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,32 @@ class I64BitSet {
return *this;
}

bool isSubSetOf(const I64BitSet p) const {
I64BitSet tmp = *this;
tmp |= p;
return tmp == p;
}

// Needed by `llvm::const_set_bits_iterator_impl`.
int find_first() const { return min(); }
// Returns the index of the first set bit strictly after `prev`, or -1 when
// there is none.
// NOTE(review): this span renders pre-patch and post-patch diff lines
// interleaved; the duplicated statements below are the old/new variants of
// the same line and cannot both compile — restore one side before building.
int find_next(unsigned prev) const {
if (prev >= max()) // pre-patch bound check.
if (prev >= max() - 1) // post-patch: `max() - 1` is the highest set-bit index.
return -1;

uint64_t b = storage >> (prev + 1); // pre-patch shift.
if (b == 0) // pre-patch: no bits remaining past `prev`.
return -1;
uint64_t b = storage >> (prev + static_cast<int64_t>(1)); // post-patch shift.
assert(b != 0); // post-patch: guaranteed nonzero by the tighter bound check.

return llvm::countr_zero(b) + prev + 1; // pre-patch.
return llvm::countr_zero(b) + prev + static_cast<int64_t>(1); // post-patch.
}

// Tests whether bit `i` is set.
bool operator[](unsigned i) const {
assert(i < 64);
return (storage & (1 << i)) != 0; // pre-patch: 32-bit `1` — wrong/UB for i >= 31.
return (storage & (static_cast<int64_t>(1) << i)) != 0; // post-patch: 64-bit shift.
// NOTE(review): consider `uint64_t` over `int64_t` to avoid shifting into
// the sign bit at i == 63 — verify upstream.
}
// Post-patch min(): index of the lowest set bit, or -1 (as unsigned) when
// the set is empty (countr_zero(0) == 64).
unsigned min() const {
unsigned m = llvm::countr_zero(storage);
return m == 64 ? -1 : m;
}
unsigned min() const { return llvm::countr_zero(storage); } // pre-patch min().
// One past the index of the highest set bit (0 when empty).
unsigned max() const { return 64 - llvm::countl_zero(storage); }
// Number of set bits.
unsigned count() const { return llvm::popcount(storage); }
bool empty() const { return storage == 0; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
.take_back(getRegionDefinedSpace(regionIdx).count());
}
ValueRange getYieldedValues(unsigned regionIdx);

// Returns a vector of regions that are the `sub-cases` of the given case region.
// E.g., `case %it1, _, %it3` is a subcase of `case %it1, %it2, %it3`.
SmallVector<Region *> getSubCasesOf(unsigned regionIdx);
}];

let hasVerifier = 1;
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2745,6 +2745,16 @@ LogicalResult CoIterateOp::verifyRegions() {
return success();
}

SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
  // The iteration space covered by the queried case region.
  I64BitSet superSet = getRegionDefinedSpace(regionIdx);
  // A case region is a sub-case when every space it covers is also covered
  // by the queried region. Note that the subset test is non-strict, so the
  // result always contains the queried region itself.
  SmallVector<Region *> subCases;
  for (Region &caseRegion : getCaseRegions())
    if (getRegionDefinedSpace(caseRegion.getRegionNumber())
            .isSubSetOf(superSet))
      subCases.push_back(&caseRegion);

  return subCases;
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Dialect Setups.
//===----------------------------------------------------------------------===//
Expand Down
291 changes: 290 additions & 1 deletion mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"
#include "Utils/SparseTensorIterator.h"

#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand Down Expand Up @@ -49,6 +50,144 @@ convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
return success();
}

// Emits a chain of nested scf.if operations that dispatches the current loop
// coordinate `loopCrd` to the first matching case region in `subCases`:
// each `if` tests whether every iterator involved in the case is exactly at
// `loopCrd`, and its then-branch receives a clone of that case's body.
// Returns the values produced by the branch nest (the incoming `userReduc`
// when `subCases` is empty).
static ValueRange
genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
Value loopCrd,
ArrayRef<std::unique_ptr<SparseIterator>> iters,
ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
if (subCases.empty())
return userReduc;

// The current branch that we are handling.
Region *b = subCases.front();
// The case fires iff every iterator it involves sits exactly on the loop
// coordinate; AND the per-iterator comparisons together.
Value casePred = constantI1(rewriter, loc, true);
I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
for (unsigned i : caseBits.bits()) {
SparseIterator *it = iters[i].get();
Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
it->getCrd(), loopCrd);
casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
}
scf::IfOp ifOp = rewriter.create<scf::IfOp>(
loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());

// Erase the empty then-block created by the scf.if builder; the case body
// is cloned into the region below instead.
rewriter.eraseBlock(&ifOp.getThenRegion().front());
// Set up block arguments: user-provided values -> loop coord -> iterators.
SmallVector<Value> blockArgs(userReduc);
blockArgs.push_back(loopCrd);
for (unsigned idx : caseBits.bits())
llvm::append_range(blockArgs, iters[idx]->getCursor());

// Map the case region's block arguments onto the SSA values above so the
// clone below is wired to the surrounding loop.
IRMapping mapping;
for (auto [from, to] :
llvm::zip_equal(b->front().getArguments(), blockArgs)) {
mapping.map(from, to);
}

// Clone the region; we cannot erase it yet because the same region might be
// a sub-case for multiple lattice points.
rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
ifOp.getThenRegion().begin(), mapping);

// Replace the cloned sparse_tensor.yield terminator with scf.yield.
auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
ValueRange yields = spY.getResults();
rewriter.eraseOp(spY);
rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
rewriter.create<scf::YieldOp>(loc, yields);

// Generate the remaining cases recursively in the else branch.
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
subCases.drop_front(), userReduc);
if (!res.empty())
rewriter.create<scf::YieldOp>(loc, res);

rewriter.setInsertionPointAfter(ifOp);
return ifOp.getResults();
}

// Emits either an scf.for (when `it` supports random access via genForCond)
// or an scf.while loop driven by `it`, delegating loop-body emission to
// `bodyBuilder`. `iterFirst` selects where the iterator's cursor SSA values
// are placed relative to the user-provided reduction values in the loop's
// argument/yield lists. Returns the user-visible (reduction) results.
static ValueRange genLoopWithIterator(
PatternRewriter &rewriter, Location loc, SparseIterator *it,
ValueRange reduc, bool iterFirst,
function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
Region &loopBody, SparseIterator *it,
ValueRange reduc)>
bodyBuilder) {
if (it->iteratableByFor()) {
auto [lo, hi] = it->genForCond(rewriter, loc);
Value step = constantIndex(rewriter, loc, 1);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
{
OpBuilder::InsertionGuard guard(rewriter);
// Erase the implicit yield operation created by ForOp when there is no
// yielding values.
if (!forOp.getBody()->empty())
rewriter.eraseOp(&forOp.getBody()->front());
assert(forOp.getBody()->empty());

// The induction variable serves as the iterator cursor in a for loop.
it->linkNewScope(forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
it, forOp.getRegionIterArgs());

rewriter.setInsertionPointToEnd(forOp.getBody());
rewriter.create<scf::YieldOp>(loc, ret);
}
return forOp.getResults();
}
SmallVector<Value> ivs;
// TODO: always put iterator SSA values at the end of argument list to be
// consistent with coiterate operation.
// NOTE(review): the cursor is appended *first* when `iterFirst` is false
// here, while the yield construction below appends it *last* in that same
// case — confirm the two orderings agree with each other and with the
// `drop_front(cursor.size())` on the results at the bottom.
if (!iterFirst)
llvm::append_range(ivs, it->getCursor());
// Appends the user-provided values.
llvm::append_range(ivs, reduc);
if (iterFirst)
llvm::append_range(ivs, it->getCursor());

TypeRange types = ValueRange(ivs).getTypes();
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
{
OpBuilder::InsertionGuard guard(rewriter);
// Generates loop conditions.
SmallVector<Location> l(types.size(), loc);
Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
rewriter.setInsertionPointToStart(before);
ValueRange bArgs = before->getArguments();
// `remArgs` (the arguments left after genWhileCond consumes the cursor) is
// intentionally unused; all arguments are forwarded to the after-region.
auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

// Delegates loop body generation.
Region &dstRegion = whileOp.getAfter();
Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
ValueRange aArgs = whileOp.getAfterArguments();
if (iterFirst) {
aArgs = it->linkNewScope(aArgs);
} else {
aArgs = aArgs.take_front(reduc.size());
// NOTE(review): `aArgs` was truncated to `reduc.size()` on the previous
// line, so `aArgs.drop_front(reduc.size())` is an empty range here —
// the cursor is presumably meant to come from the full after-argument
// list; verify against the block-argument layout chosen above.
it->linkNewScope(aArgs.drop_front(reduc.size()));
}

rewriter.setInsertionPointToStart(after);
SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
rewriter.setInsertionPointToEnd(after);

// Forward loops
SmallVector<Value> yields;
ValueRange nx = it->forward(rewriter, loc);
if (iterFirst)
llvm::append_range(yields, nx);
llvm::append_range(yields, ret);
if (!iterFirst)
llvm::append_range(yields, nx);
rewriter.create<scf::YieldOp>(loc, yields);
}
// NOTE(review): assumes the cursor values occupy the *front* of the result
// list — confirm this holds for both values of `iterFirst`.
return whileOp.getResults().drop_front(it->getCursor().size());
}

namespace {

/// Sparse codegen rule for number of entries operator.
Expand Down Expand Up @@ -136,6 +275,8 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
} else {
SmallVector<Value> ivs;
// TODO: put iterator at the end of argument list to be consistent with
// coiterate operation.
llvm::append_range(ivs, it->getCursor());
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);
Expand Down Expand Up @@ -189,6 +330,153 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
}
};

/// Lowers sparse_tensor.coiterate to a sequence of scf loops, one loop per
/// case region, dispatching to sub-case bodies via nested scf.if branches.
/// Partial support only: 1-D iteration spaces, no universal index (i.e., no
/// all-dense or complement cases).
class SparseCoIterateOpConverter
: public OneToNOpConversionPattern<CoIterateOp> {
using OneToNOpConversionPattern::OneToNOpConversionPattern;

LogicalResult
matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
assert(op.getSpaceDim() == 1 && "Not implemented");
Location loc = op.getLoc();

// Record which iteration spaces consist solely of dense levels.
I64BitSet denseBits(0);
for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
denseBits.set(idx);

// If there exists a case that only contains dense spaces (i.e., the case
// bits are a subset of the dense bits), or there is a fully empty case
// (due to complements), we need a universal pointer to forward the
// coiteration loop. That is not supported yet.
bool needUniv =
any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
// A case for complement.
if (caseBits.count() == 0)
return true;
// An all-dense case.
return caseBits.isSubSetOf(denseBits);
});
assert(!needUniv && "Not implemented");
(void)needUniv;

for (Region &region : op.getCaseRegions()) {
// Do a one-shot type conversion on all region blocks, since the same
// region might be inlined multiple times (once per super-case).
Block *block = &region.getBlocks().front();
OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
blockTypeMapping)))
return rewriter.notifyMatchFailure(
op, "failed to convert coiterate region argurment types");

rewriter.applySignatureConversion(block, blockTypeMapping);
}

// Materialize one iteration space + iterator per coiterated operand.
SmallVector<SparseIterationSpace> spaces;
SmallVector<std::unique_ptr<SparseIterator>> iters;
for (auto [spaceTp, spaceVals] : llvm::zip_equal(
op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
// TODO: do we really need tid?
spaces.push_back(SparseIterationSpace::fromValues(
cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
// Extract the iterator.
iters.push_back(spaces.back().extractIterator(rewriter, loc));
}

auto getFilteredIters = [&iters](I64BitSet caseBits) {
// Retrieves a vector of pointers to the iterators used in the case.
SmallVector<SparseIterator *> validIters;
for (auto idx : caseBits.bits())
validIters.push_back(iters[idx].get());
return validIters;
};

// Get the flattened user-provided loop reduction values.
SmallVector<Value> userReduc;
for (ValueRange r : adaptor.getInitArgs())
llvm::append_range(userReduc, r);

// TODO: we need to sort the cases such that they appear in lexical order.
// Although sparsification always generates cases in that order, it might
// not be the case for human-written code.

// Generates a loop sequence, one loop per case; the reductions of each
// loop feed the next one via `userReduc`.
for (auto [r, caseBits] :
llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
assert(caseBits.count() > 0 && "Complement space not implemented");

// Retrieves a vector of pointers to the iterators used in the case.
SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);

if (validIters.size() > 1) {
// Multiple sparse iterators co-advance: emit a coiteration while loop.
auto [loop, loopCrd] =
genCoIteration(rewriter, loc, validIters, userReduc,
/*uniIdx=*/nullptr, /*userReducFirst=*/true);

// 1st. find all the cases that are subsets of the current case
// condition, for which we generate one branch per case inside the loop.
// The subset test is non-strict, so the result is never empty: it
// contains at least the current region itself.
// TODO: these cases should be sorted.
SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
assert(!subCases.empty());

ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
iters, subCases, userReduc);

SmallVector<Value> nextIterYields(res);
// 2nd. forward the loop: each iterator sitting exactly on the loop
// coordinate advances; the others stay put.
for (SparseIterator *it : validIters) {
Value cmp = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
it->forwardIf(rewriter, loc, cmp);
llvm::append_range(nextIterYields, it->getCursor());
}
rewriter.create<scf::YieldOp>(loc, nextIterYields);

// Exit the loop, relink the iterator SSA values to the loop results.
rewriter.setInsertionPointAfter(loop);
ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
for (SparseIterator *it : validIters)
iterVals = it->linkNewScope(iterVals);
assert(iterVals.empty());

// Carry the reduction results into the next case's loop.
ValueRange curResult = loop->getResults().take_front(userReduc.size());
userReduc.assign(curResult.begin(), curResult.end());
} else {
// This is a simple iteration loop over a single space.
assert(caseBits.count() == 1);

Block *block = &r.getBlocks().front();
ValueRange curResult = genLoopWithIterator(
rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
/*bodyBuilder=*/
[block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
SparseIterator *it,
ValueRange reduc) -> SmallVector<Value> {
// Case-region arguments: reductions -> coordinate -> cursor.
SmallVector<Value> blockArgs(reduc);
blockArgs.push_back(it->deref(rewriter, loc));
llvm::append_range(blockArgs, it->getCursor());

// Inline the (signature-converted) case body into the loop and
// strip its sparse_tensor.yield terminator.
Block *dstBlock = &dstRegion.getBlocks().front();
rewriter.inlineBlockBefore(
block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
SmallVector<Value> result(yield.getResults());
rewriter.eraseOp(yield);
return result;
});

userReduc.assign(curResult.begin(), curResult.end());
}
}

// The results of the last case's loop are the results of the coiterate op.
rewriter.replaceOp(op, userReduc);
return success();
}
};

} // namespace

mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
Expand All @@ -210,5 +498,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(

IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
SparseIterateOpConverter>(converter, patterns.getContext());
SparseIterateOpConverter, SparseCoIterateOpConverter>(
converter, patterns.getContext());
}
Loading
Loading