Skip to content

Commit 8dc6abb

Browse files
[mlir][presburger] Implement moveColumns using std::rotate (#168243)
1 parent 28e2004 commit 8dc6abb

File tree

1 file changed

+15
-26
lines changed

1 file changed

+15
-26
lines changed

mlir/lib/Analysis/Presburger/Matrix.cpp

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -255,44 +255,33 @@ void Matrix<T>::fillRow(unsigned row, const T &value) {
255255
}
256256

257257
// moveColumns is implemented by moving the columns adjacent to the source range
258-
// to their final position. When moving right (i.e. dstPos > srcPos), the range
259-
// of the adjacent columns is [srcPos + num, dstPos + num). When moving left
260-
// (i.e. dstPos < srcPos) the range of the adjacent columns is [dstPos, srcPos).
261-
// First, zeroed out columns are inserted in the final positions of the adjacent
262-
// columns. Then, the adjacent columns are moved to their final positions by
263-
// swapping them with the zeroed columns. Finally, the now zeroed adjacent
264-
// columns are deleted.
258+
// to their final position.
265259
template <typename T>
266260
void Matrix<T>::moveColumns(unsigned srcPos, unsigned num, unsigned dstPos) {
267261
if (num == 0)
268262
return;
269263

270-
int offset = dstPos - srcPos;
271-
if (offset == 0)
264+
if (dstPos == srcPos)
272265
return;
273266

274267
assert(srcPos + num <= getNumColumns() &&
275268
"move source range exceeds matrix columns");
276269
assert(dstPos + num <= getNumColumns() &&
277270
"move destination range exceeds matrix columns");
278271

279-
unsigned insertCount = offset > 0 ? offset : -offset;
280-
unsigned finalAdjStart = offset > 0 ? srcPos : srcPos + num;
281-
unsigned curAdjStart = offset > 0 ? srcPos + num : dstPos;
282-
// TODO: This can be done using std::rotate.
283-
// Insert new zero columns in the positions where the adjacent columns are to
284-
// be moved.
285-
insertColumns(finalAdjStart, insertCount);
286-
// Update curAdjStart if insertion of new columns invalidates it.
287-
if (finalAdjStart < curAdjStart)
288-
curAdjStart += insertCount;
289-
290-
// Swap the adjacent columns with inserted zero columns.
291-
for (unsigned i = 0; i < insertCount; ++i)
292-
swapColumns(finalAdjStart + i, curAdjStart + i);
293-
294-
// Delete the now redundant zero columns.
295-
removeColumns(curAdjStart, insertCount);
272+
unsigned numRows = getNumRows();
273+
// std::rotate(start, middle, end) permutes the elements of [start, end] to
274+
// [middle, end) + [start, middle). NOTE: &at(i, srcPos + num) will trigger an
275+
// assert.
276+
if (dstPos > srcPos) {
277+
for (unsigned i = 0; i < numRows; ++i) {
278+
std::rotate(&at(i, srcPos), &at(i, srcPos) + num, &at(i, dstPos) + num);
279+
}
280+
return;
281+
}
282+
for (unsigned i = 0; i < numRows; ++i) {
283+
std::rotate(&at(i, dstPos), &at(i, srcPos), &at(i, srcPos) + num);
284+
}
296285
}
297286

298287
template <typename T>

0 commit comments

Comments
 (0)