-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][sparse] unify block arguments order between iterate/coiterate operations. #105567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
…operations. stack-info: PR: #105567, branch: users/PeimingLiu/stack/3
PeimingLiu
pushed a commit
that referenced
this pull request
Aug 21, 2024
…operations. stack-info: PR: #105567, branch: users/PeimingLiu/stack/3
1a32495
to
937bcd8
Compare
6fd099f
to
3f83d7a
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sparse Author: Peiming Liu (PeimingLiu) ChangesStacked PRs:
[mlir][sparse] unify block arguments order between iterate/coiterate operations.Full diff: https://github.com/llvm/llvm-project/pull/105567.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
return getIterSpace().getType().getSpaceDim();
}
BlockArgument getIterator() {
- return getRegion().getArguments().front();
+ return getRegion().getArguments().back();
}
std::optional<BlockArgument> getLvlCrd(Level lvl) {
if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
return std::nullopt;
}
Block::BlockArgListType getCrds() {
- // The first block argument is iterator, the remaining arguments are
- // referenced coordinates.
- return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+ // User-provided iteration arguments -> coords -> iterator.
+ return getRegion().getArguments().slice(getNumRegionIterArgs(), getCrdUsedLvls().count());
}
unsigned getNumRegionIterArgs() {
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
parser.getNameLoc(),
"mismatch in number of sparse iterators and sparse spaces");
- if (failed(parseUsedCoordList(parser, state, blockArgs)))
+ SmallVector<OpAsmParser::Argument> coords;
+ if (failed(parseUsedCoordList(parser, state, coords)))
return failure();
- size_t numCrds = blockArgs.size();
+ size_t numCrds = coords.size();
// Parse "iter_args(%arg = %init, ...)"
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
if (parser.parseAssignmentList(blockArgs, initArgs))
return failure();
+ blockArgs.append(coords);
+
SmallVector<Type> iterSpaceTps;
// parse ": sparse_tensor.iter_space -> ret"
if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
if (hasIterArgs) {
// Strip off leading args that used for coordinates.
- MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+ MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
return parser.emitError(
parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState,
odsState.addTypes(initArgs.getTypes());
Block *bodyBlock = builder.createBlock(bodyRegion);
- // First argument, sparse iterator
- bodyBlock->addArgument(
- llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
- odsState.location);
+ // Starts with a list of user-provided loop arguments.
+ for (Value v : initArgs)
+ bodyBlock->addArgument(v.getType(), v.getLoc());
- // Followed by a list of used coordinates.
+ // Follows by a list of used coordinates.
for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
bodyBlock->addArgument(builder.getIndexType(), odsState.location);
- // Followed by a list of user-provided loop arguments.
- for (Value v : initArgs)
- bodyBlock->addArgument(v.getType(), v.getLoc());
+ // Ends with sparse iterator
+ bodyBlock->addArgument(
+ llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
+ odsState.location);
}
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
return parser.emitError(parser.getNameLoc(),
"expected only one iterator/iteration space");
- iters.append(iterArgs);
+ iterArgs.append(iters);
Region *body = result.addRegion();
- if (parser.parseRegion(*body, iters))
+ if (parser.parseRegion(*body, iterArgs))
return failure();
IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
@@ -2580,7 +2583,7 @@ MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
}
Block::BlockArgListType IterateOp::getRegionIterArgs() {
- return getRegion().getArguments().take_back(getNumRegionIterArgs());
+ return getRegion().getArguments().take_front(getNumRegionIterArgs());
}
std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index f7fcabb0220b50..71a229bea990c0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -111,7 +111,7 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
static ValueRange genLoopWithIterator(
PatternRewriter &rewriter, Location loc, SparseIterator *it,
- ValueRange reduc, bool iterFirst,
+ ValueRange reduc,
function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
Region &loopBody, SparseIterator *it,
ValueRange reduc)>
@@ -138,15 +138,9 @@ static ValueRange genLoopWithIterator(
}
return forOp.getResults();
}
- SmallVector<Value> ivs;
- // TODO: always put iterator SSA values at the end of argument list to be
- // consistent with coiterate operation.
- if (!iterFirst)
- llvm::append_range(ivs, it->getCursor());
- // Appends the user-provided values.
- llvm::append_range(ivs, reduc);
- if (iterFirst)
- llvm::append_range(ivs, it->getCursor());
+
+ SmallVector<Value> ivs(reduc);
+ llvm::append_range(ivs, it->getCursor());
TypeRange types = ValueRange(ivs).getTypes();
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
@@ -164,12 +158,8 @@ static ValueRange genLoopWithIterator(
Region &dstRegion = whileOp.getAfter();
Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
ValueRange aArgs = whileOp.getAfterArguments();
- if (iterFirst) {
- aArgs = it->linkNewScope(aArgs);
- } else {
- aArgs = aArgs.take_front(reduc.size());
- it->linkNewScope(aArgs.drop_front(reduc.size()));
- }
+ it->linkNewScope(aArgs.drop_front(reduc.size()));
+ aArgs = aArgs.take_front(reduc.size());
rewriter.setInsertionPointToStart(after);
SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
@@ -177,12 +167,8 @@ static ValueRange genLoopWithIterator(
// Forward loops
SmallVector<Value> yields;
- ValueRange nx = it->forward(rewriter, loc);
- if (iterFirst)
- llvm::append_range(yields, nx);
llvm::append_range(yields, ret);
- if (!iterFirst)
- llvm::append_range(yields, nx);
+ llvm::append_range(yields, it->forward(rewriter, loc));
rewriter.create<scf::YieldOp>(loc, yields);
}
return whileOp.getResults().drop_front(it->getCursor().size());
@@ -258,13 +244,13 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
Block *block = op.getBody();
ValueRange ret = genLoopWithIterator(
- rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+ rewriter, loc, it.get(), ivs,
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
- SmallVector<Value> blockArgs(it->getCursor());
+ SmallVector<Value> blockArgs(reduc);
// TODO: Also appends coordinates if used.
// blockArgs.push_back(it->deref(rewriter, loc));
- llvm::append_range(blockArgs, reduc);
+ llvm::append_range(blockArgs, it->getCursor());
Block *dstBlock = &loopBody.getBlocks().front();
rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
@@ -404,7 +390,7 @@ class SparseCoIterateOpConverter
Block *block = &r.getBlocks().front();
ValueRange curResult = genLoopWithIterator(
- rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
+ rewriter, loc, validIters.front(), userReduc,
/*bodyBuilder=*/
[block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
SparseIterator *it,
|
PeimingLiu
pushed a commit
that referenced
this pull request
Aug 22, 2024
…operations. stack-info: PR: #105567, branch: users/PeimingLiu/stack/3
3f83d7a
to
8c40bac
Compare
aartbik
approved these changes
Aug 23, 2024
5d73f23
to
984d8d5
Compare
PeimingLiu
pushed a commit
that referenced
this pull request
Aug 23, 2024
…operations. stack-info: PR: #105567, branch: users/PeimingLiu/stack/3
8c40bac
to
58bae5c
Compare
58bae5c
to
950043e
Compare
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Stacked PRs:
[mlir][sparse] unify block arguments order between iterate/coiterate operations.