Skip to content

Commit 4176ce6

Browse files
authored
[mlir][sparse] fix logical error when generating sort_coo. (llvm#66690)
To fix issue: llvm#66664
1 parent cacdb90 commit 4176ce6

File tree

3 files changed

+69
-380
lines changed

3 files changed

+69
-380
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -563,11 +563,14 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
563563
// p = (lo+hi)/2 // pivot index
564564
// i = lo
565565
// j = hi-1
566-
// while (i < j) do {
566+
// while (true) do {
567567
// while (xs[i] < xs[p]) i ++;
568568
// i_eq = (xs[i] == xs[p]);
569569
// while (xs[j] > xs[p]) j --;
570570
// j_eq = (xs[j] == xs[p]);
571+
//
572+
// if (i >= j) return j + 1;
573+
//
571574
// if (i < j) {
572575
// swap(xs[i], xs[j])
573576
// if (i == p) {
@@ -581,8 +584,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
581584
// }
582585
// }
583586
// }
584-
// return p
585-
// }
587+
// }
586588
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
587589
func::FuncOp func, uint64_t nx, uint64_t ny,
588590
bool isCoo, uint32_t nTrailingP = 0) {
@@ -605,22 +607,22 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
605607
Value i = lo;
606608
Value j = builder.create<arith::SubIOp>(loc, hi, c1);
607609
createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
608-
SmallVector<Value, 3> operands{i, j, p}; // Exactly three values.
609-
SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
610+
Value trueVal = constantI1(builder, loc, true); // The value for while (true)
611+
SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
612+
SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
613+
trueVal.getType()};
610614
scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
611615

612616
// The before-region of the WhileOp.
613-
Block *before =
614-
builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
617+
Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
618+
{loc, loc, loc, loc});
615619
builder.setInsertionPointToEnd(before);
616-
Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
617-
before->getArgument(0),
618-
before->getArgument(1));
619-
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
620+
builder.create<scf::ConditionOp>(loc, before->getArgument(3),
621+
before->getArguments());
620622

621623
// The after-region of the WhileOp.
622624
Block *after =
623-
builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
625+
builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
624626
builder.setInsertionPointToEnd(after);
625627
i = after->getArgument(0);
626628
j = after->getArgument(1);
@@ -637,7 +639,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
637639
j = jresult;
638640

639641
// If i < j:
640-
cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
642+
Value cond =
643+
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
641644
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
642645
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
643646
SmallVector<Value> swapOperands{i, j};
@@ -675,11 +678,15 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
675678
builder.setInsertionPointAfter(ifOp2);
676679
builder.create<scf::YieldOp>(
677680
loc,
678-
ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});
681+
ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
682+
/*cont=*/constantI1(builder, loc, true)});
679683

680-
// False branch for if i < j:
684+
// False branch for if i < j (i.e., i >= j):
681685
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
682-
builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});
686+
p = builder.create<arith::AddIOp>(loc, j,
687+
constantOne(builder, loc, j.getType()));
688+
builder.create<scf::YieldOp>(
689+
loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});
683690

684691
// Return for the whileOp.
685692
builder.setInsertionPointAfter(ifOp);
@@ -927,6 +934,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
927934
Location loc = func.getLoc();
928935
Value lo = args[loIdx];
929936
Value hi = args[hiIdx];
937+
SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
938+
930939
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
931940
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
932941
ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
@@ -935,14 +944,25 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
935944
TypeRange{IndexType::get(context)},
936945
args.drop_back(nTrailingP))
937946
.getResult(0);
938-
Value pP1 =
939-
builder.create<arith::AddIOp>(loc, p, constantIndex(builder, loc, 1));
947+
940948
Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
941949
Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
950+
// Partition already sorts array with len <= 2
951+
Value c2 = constantIndex(builder, loc, 2);
952+
Value len = builder.create<arith::SubIOp>(loc, hi, lo);
953+
Value lenGtTwo =
954+
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
955+
scf::IfOp ifLenGtTwo =
956+
builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
957+
builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
958+
// Returns an empty range to mark the entire region is fully sorted.
959+
builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
960+
961+
// Else len > 2, need recursion.
962+
builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
942963
Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
943964
lenLow, lenHigh);
944965

945-
SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
946966
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
947967

948968
Value c0 = constantIndex(builder, loc, 0);
@@ -961,14 +981,17 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
961981
// the bigger partition to be processed by the enclosed while-loop.
962982
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
963983
mayRecursion(lo, p, lenLow);
964-
builder.create<scf::YieldOp>(loc, ValueRange{pP1, hi});
984+
builder.create<scf::YieldOp>(loc, ValueRange{p, hi});
965985

966986
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
967-
mayRecursion(pP1, hi, lenHigh);
987+
mayRecursion(p, hi, lenHigh);
968988
builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
969989

970990
builder.setInsertionPointAfter(ifOp);
971-
return std::make_pair(ifOp.getResult(0), ifOp.getResult(1));
991+
builder.create<scf::YieldOp>(loc, ifOp.getResults());
992+
993+
builder.setInsertionPointAfter(ifLenGtTwo);
994+
return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
972995
}
973996

974997
/// Creates a function to perform insertion sort on the values in the range of

0 commit comments

Comments
 (0)