Skip to content

[mlir][sparse] unify block arguments order between iterate/coiterate operations. #105567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 23, 2024

Conversation

PeimingLiu
Copy link
Member

@PeimingLiu PeimingLiu commented Aug 21, 2024

…operations.

stack-info: PR: #105567, branch: users/PeimingLiu/stack/3
@llvmbot
Copy link
Member

llvmbot commented Aug 21, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Stacked PRs:

  • ->#105567
  • #105566
  • #105565

[mlir][sparse] unify block arguments order between iterate/coiterate operations.


Full diff: https://github.com/llvm/llvm-project/pull/105567.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+3-4)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+17-14)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+11-25)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
       return getIterSpace().getType().getSpaceDim();
     }
     BlockArgument getIterator() {
-      return getRegion().getArguments().front();
+      return getRegion().getArguments().back();
     }
     std::optional<BlockArgument> getLvlCrd(Level lvl) {
       if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
       return std::nullopt;
     }
     Block::BlockArgListType getCrds() {
-      // The first block argument is iterator, the remaining arguments are
-      // referenced coordinates.
-      return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+      // User-provided iteration arguments -> coords -> iterator.
+      return getRegion().getArguments().slice(getNumRegionIterArgs(), getCrdUsedLvls().count());
     }
     unsigned getNumRegionIterArgs() {
       return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
         parser.getNameLoc(),
         "mismatch in number of sparse iterators and sparse spaces");
 
-  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+  SmallVector<OpAsmParser::Argument> coords;
+  if (failed(parseUsedCoordList(parser, state, coords)))
     return failure();
-  size_t numCrds = blockArgs.size();
+  size_t numCrds = coords.size();
 
   // Parse "iter_args(%arg = %init, ...)"
   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
     if (parser.parseAssignmentList(blockArgs, initArgs))
       return failure();
 
+  blockArgs.append(coords);
+
   SmallVector<Type> iterSpaceTps;
   // parse ": sparse_tensor.iter_space -> ret"
   if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
 
   if (hasIterArgs) {
     // Strip off leading args that used for coordinates.
-    MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
     if (args.size() != initArgs.size() || args.size() != state.types.size()) {
       return parser.emitError(
           parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState,
   odsState.addTypes(initArgs.getTypes());
   Block *bodyBlock = builder.createBlock(bodyRegion);
 
-  // First argument, sparse iterator
-  bodyBlock->addArgument(
-      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
-      odsState.location);
+  // Starts with a list of user-provided loop arguments.
+  for (Value v : initArgs)
+    bodyBlock->addArgument(v.getType(), v.getLoc());
 
-  // Followed by a list of used coordinates.
+  // Follows by a list of used coordinates.
   for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
     bodyBlock->addArgument(builder.getIndexType(), odsState.location);
 
-  // Followed by a list of user-provided loop arguments.
-  for (Value v : initArgs)
-    bodyBlock->addArgument(v.getType(), v.getLoc());
+  // Ends with sparse iterator
+  bodyBlock->addArgument(
+      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
+      odsState.location);
 }
 
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
     return parser.emitError(parser.getNameLoc(),
                             "expected only one iterator/iteration space");
 
-  iters.append(iterArgs);
+  iterArgs.append(iters);
   Region *body = result.addRegion();
-  if (parser.parseRegion(*body, iters))
+  if (parser.parseRegion(*body, iterArgs))
     return failure();
 
   IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
@@ -2580,7 +2583,7 @@ MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
 }
 
 Block::BlockArgListType IterateOp::getRegionIterArgs() {
-  return getRegion().getArguments().take_back(getNumRegionIterArgs());
+  return getRegion().getArguments().take_front(getNumRegionIterArgs());
 }
 
 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index f7fcabb0220b50..71a229bea990c0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -111,7 +111,7 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
 
 static ValueRange genLoopWithIterator(
     PatternRewriter &rewriter, Location loc, SparseIterator *it,
-    ValueRange reduc, bool iterFirst,
+    ValueRange reduc,
     function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
                                     Region &loopBody, SparseIterator *it,
                                     ValueRange reduc)>
@@ -138,15 +138,9 @@ static ValueRange genLoopWithIterator(
     }
     return forOp.getResults();
   }
-  SmallVector<Value> ivs;
-  // TODO: always put iterator SSA values at the end of argument list to be
-  // consistent with coiterate operation.
-  if (!iterFirst)
-    llvm::append_range(ivs, it->getCursor());
-  // Appends the user-provided values.
-  llvm::append_range(ivs, reduc);
-  if (iterFirst)
-    llvm::append_range(ivs, it->getCursor());
+
+  SmallVector<Value> ivs(reduc);
+  llvm::append_range(ivs, it->getCursor());
 
   TypeRange types = ValueRange(ivs).getTypes();
   auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
@@ -164,12 +158,8 @@ static ValueRange genLoopWithIterator(
     Region &dstRegion = whileOp.getAfter();
     Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
     ValueRange aArgs = whileOp.getAfterArguments();
-    if (iterFirst) {
-      aArgs = it->linkNewScope(aArgs);
-    } else {
-      aArgs = aArgs.take_front(reduc.size());
-      it->linkNewScope(aArgs.drop_front(reduc.size()));
-    }
+    it->linkNewScope(aArgs.drop_front(reduc.size()));
+    aArgs = aArgs.take_front(reduc.size());
 
     rewriter.setInsertionPointToStart(after);
     SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
@@ -177,12 +167,8 @@ static ValueRange genLoopWithIterator(
 
     // Forward loops
     SmallVector<Value> yields;
-    ValueRange nx = it->forward(rewriter, loc);
-    if (iterFirst)
-      llvm::append_range(yields, nx);
     llvm::append_range(yields, ret);
-    if (!iterFirst)
-      llvm::append_range(yields, nx);
+    llvm::append_range(yields, it->forward(rewriter, loc));
     rewriter.create<scf::YieldOp>(loc, yields);
   }
   return whileOp.getResults().drop_front(it->getCursor().size());
@@ -258,13 +244,13 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
 
     Block *block = op.getBody();
     ValueRange ret = genLoopWithIterator(
-        rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+        rewriter, loc, it.get(), ivs,
         [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
                 SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
-          SmallVector<Value> blockArgs(it->getCursor());
+          SmallVector<Value> blockArgs(reduc);
           // TODO: Also appends coordinates if used.
           // blockArgs.push_back(it->deref(rewriter, loc));
-          llvm::append_range(blockArgs, reduc);
+          llvm::append_range(blockArgs, it->getCursor());
 
           Block *dstBlock = &loopBody.getBlocks().front();
           rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
@@ -404,7 +390,7 @@ class SparseCoIterateOpConverter
 
         Block *block = &r.getBlocks().front();
         ValueRange curResult = genLoopWithIterator(
-            rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
+            rewriter, loc, validIters.front(), userReduc,
             /*bodyBuilder=*/
             [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
                     SparseIterator *it,

@PeimingLiu PeimingLiu changed the base branch from users/PeimingLiu/stack/2 to main August 22, 2024 23:30
PeimingLiu pushed a commit that referenced this pull request Aug 22, 2024
…operations.

stack-info: PR: #105567, branch: users/PeimingLiu/stack/3
@PeimingLiu PeimingLiu force-pushed the users/PeimingLiu/stack/3 branch from 3f83d7a to 8c40bac Compare August 22, 2024 23:30
@PeimingLiu PeimingLiu changed the base branch from main to users/PeimingLiu/stack/2 August 22, 2024 23:30
@PeimingLiu PeimingLiu force-pushed the users/PeimingLiu/stack/2 branch from 5d73f23 to 984d8d5 Compare August 23, 2024 17:47
PeimingLiu pushed a commit that referenced this pull request Aug 23, 2024
…operations.

stack-info: PR: #105567, branch: users/PeimingLiu/stack/3
@PeimingLiu PeimingLiu force-pushed the users/PeimingLiu/stack/3 branch from 8c40bac to 58bae5c Compare August 23, 2024 17:48
Base automatically changed from users/PeimingLiu/stack/2 to main August 23, 2024 18:21
@PeimingLiu PeimingLiu force-pushed the users/PeimingLiu/stack/3 branch from 58bae5c to 950043e Compare August 23, 2024 18:21
@PeimingLiu PeimingLiu merged commit b48ef8d into main Aug 23, 2024
11 checks passed
@PeimingLiu PeimingLiu deleted the users/PeimingLiu/stack/3 branch August 23, 2024 21:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants