Skip to content

Commit efca035

Browse files
[MLIR][Presburger] Template Matrix to allow MPInt and Fraction (#65272)
The method implementations remain in the .cpp file; explicit instantiations have been added for these two types. makeMatrix has been duplicated to makeIntMatrix and makeFracMatrix.
1 parent e18fa6e commit efca035

File tree

15 files changed

+163
-127
lines changed

15 files changed

+163
-127
lines changed

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ class IntegerRelation {
366366
/// bounded. The span of the returned vectors is guaranteed to contain all
367367
/// such vectors. The returned vectors are NOT guaranteed to be linearly
368368
/// independent. This function should not be called on empty sets.
369-
Matrix getBoundedDirections() const;
369+
Matrix<MPInt> getBoundedDirections() const;
370370

371371
/// Find an integer sample point satisfying the constraints using a
372372
/// branch and bound algorithm with generalized basis reduction, with some
@@ -792,10 +792,10 @@ class IntegerRelation {
792792
PresburgerSpace space;
793793

794794
/// Coefficients of affine equalities (in == 0 form).
795-
Matrix equalities;
795+
Matrix<MPInt> equalities;
796796

797797
/// Coefficients of affine inequalities (in >= 0 form).
798-
Matrix inequalities;
798+
Matrix<MPInt> inequalities;
799799
};
800800

801801
/// An IntegerPolyhedron represents the set of points from a PresburgerSpace

mlir/include/mlir/Analysis/Presburger/LinearTransform.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ namespace presburger {
2222

2323
class LinearTransform {
2424
public:
25-
explicit LinearTransform(Matrix &&oMatrix);
26-
explicit LinearTransform(const Matrix &oMatrix);
25+
explicit LinearTransform(Matrix<MPInt> &&oMatrix);
26+
explicit LinearTransform(const Matrix<MPInt> &oMatrix);
2727

2828
// Returns a linear transform T such that MT is M in column echelon form.
2929
// Also returns the number of non-zero columns in MT.
@@ -32,7 +32,7 @@ class LinearTransform {
3232
// strictly below that of the previous column, and all columns which have only
3333
// zeros are at the end.
3434
static std::pair<unsigned, LinearTransform>
35-
makeTransformToColumnEchelon(const Matrix &m);
35+
makeTransformToColumnEchelon(const Matrix<MPInt> &m);
3636

3737
// Returns an IntegerRelation having a constraint vector vT for every
3838
// constraint vector v in rel, where T is this transform.
@@ -50,8 +50,12 @@ class LinearTransform {
5050
return matrix.postMultiplyWithColumn(colVec);
5151
}
5252

53+
// Compute the determinant of the transform by converting it to row echelon
54+
// form and then taking the product of the diagonal.
55+
MPInt determinant();
56+
5357
private:
54-
Matrix matrix;
58+
Matrix<MPInt> matrix;
5559
};
5660

5761
} // namespace presburger

mlir/include/mlir/Analysis/Presburger/Matrix.h

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This is a simple 2D matrix class that supports reading, writing, resizing,
10-
// swapping rows, and swapping columns.
10+
// swapping rows, and swapping columns. It can hold integers (MPInt) or rational
11+
// numbers (Fraction).
1112
//
1213
//===----------------------------------------------------------------------===//
1314

1415
#ifndef MLIR_ANALYSIS_PRESBURGER_MATRIX_H
1516
#define MLIR_ANALYSIS_PRESBURGER_MATRIX_H
1617

17-
#include "mlir/Analysis/Presburger/MPInt.h"
1818
#include "mlir/Support/LLVM.h"
19+
#include "mlir/Analysis/Presburger/Fraction.h"
20+
#include "mlir/Analysis/Presburger/Matrix.h"
1921
#include "llvm/ADT/ArrayRef.h"
2022
#include "llvm/Support/raw_ostream.h"
2123

@@ -32,7 +34,12 @@ namespace presburger {
3234
/// (i, j) is stored at data[i*nReservedColumns + j]. The reserved but unused
3335
/// columns always have all zero values. The reserved rows are just reserved
3436
/// space in the underlying SmallVector's capacity.
37+
/// This class only works for the types MPInt and Fraction, since the method
38+
/// implementations are in the Matrix.cpp file. Only these two types have
39+
/// been explicitly instantiated there.
40+
template<typename T>
3541
class Matrix {
42+
static_assert(std::is_same_v<T,MPInt> || std::is_same_v<T,Fraction>, "T must be MPInt or Fraction.");
3643
public:
3744
Matrix() = delete;
3845

@@ -49,21 +56,21 @@ class Matrix {
4956
static Matrix identity(unsigned dimension);
5057

5158
/// Access the element at the specified row and column.
52-
MPInt &at(unsigned row, unsigned column) {
59+
T &at(unsigned row, unsigned column) {
5360
assert(row < nRows && "Row outside of range");
5461
assert(column < nColumns && "Column outside of range");
5562
return data[row * nReservedColumns + column];
5663
}
5764

58-
MPInt at(unsigned row, unsigned column) const {
65+
T at(unsigned row, unsigned column) const {
5966
assert(row < nRows && "Row outside of range");
6067
assert(column < nColumns && "Column outside of range");
6168
return data[row * nReservedColumns + column];
6269
}
6370

64-
MPInt &operator()(unsigned row, unsigned column) { return at(row, column); }
71+
T &operator()(unsigned row, unsigned column) { return at(row, column); }
6572

66-
MPInt operator()(unsigned row, unsigned column) const {
73+
T operator()(unsigned row, unsigned column) const {
6774
return at(row, column);
6875
}
6976

@@ -87,11 +94,11 @@ class Matrix {
8794
void reserveRows(unsigned rows);
8895

8996
/// Get a [Mutable]ArrayRef corresponding to the specified row.
90-
MutableArrayRef<MPInt> getRow(unsigned row);
91-
ArrayRef<MPInt> getRow(unsigned row) const;
97+
MutableArrayRef<T> getRow(unsigned row);
98+
ArrayRef<T> getRow(unsigned row) const;
9299

93100
/// Set the specified row to `elems`.
94-
void setRow(unsigned row, ArrayRef<MPInt> elems);
101+
void setRow(unsigned row, ArrayRef<T> elems);
95102

96103
/// Insert columns having positions pos, pos + 1, ... pos + count - 1.
97104
/// Columns that were at positions 0 to pos - 1 will stay where they are;
@@ -125,23 +132,23 @@ class Matrix {
125132

126133
void copyRow(unsigned sourceRow, unsigned targetRow);
127134

128-
void fillRow(unsigned row, const MPInt &value);
129-
void fillRow(unsigned row, int64_t value) { fillRow(row, MPInt(value)); }
135+
void fillRow(unsigned row, const T &value);
136+
void fillRow(unsigned row, int64_t value) { fillRow(row, T(value)); }
130137

131138
/// Add `scale` multiples of the source row to the target row.
132-
void addToRow(unsigned sourceRow, unsigned targetRow, const MPInt &scale);
139+
void addToRow(unsigned sourceRow, unsigned targetRow, const T &scale);
133140
void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) {
134-
addToRow(sourceRow, targetRow, MPInt(scale));
141+
addToRow(sourceRow, targetRow, T(scale));
135142
}
136143
/// Add `scale` multiples of the rowVec row to the specified row.
137-
void addToRow(unsigned row, ArrayRef<MPInt> rowVec, const MPInt &scale);
144+
void addToRow(unsigned row, ArrayRef<T> rowVec, const T &scale);
138145

139146
/// Add `scale` multiples of the source column to the target column.
140147
void addToColumn(unsigned sourceColumn, unsigned targetColumn,
141-
const MPInt &scale);
148+
const T &scale);
142149
void addToColumn(unsigned sourceColumn, unsigned targetColumn,
143150
int64_t scale) {
144-
addToColumn(sourceColumn, targetColumn, MPInt(scale));
151+
addToColumn(sourceColumn, targetColumn, T(scale));
145152
}
146153

147154
/// Negate the specified column.
@@ -152,18 +159,18 @@ class Matrix {
152159

153160
/// Divide the first `nCols` of the specified row by their GCD.
154161
/// Returns the GCD of the first `nCols` of the specified row.
155-
MPInt normalizeRow(unsigned row, unsigned nCols);
162+
T normalizeRow(unsigned row, unsigned nCols);
156163
/// Divide the columns of the specified row by their GCD.
157164
/// Returns the GCD of the columns of the specified row.
158-
MPInt normalizeRow(unsigned row);
165+
T normalizeRow(unsigned row);
159166

160167
/// The given vector is interpreted as a row vector v. Post-multiply v with
161168
/// this matrix, say M, and return vM.
162-
SmallVector<MPInt, 8> preMultiplyWithRow(ArrayRef<MPInt> rowVec) const;
169+
SmallVector<T, 8> preMultiplyWithRow(ArrayRef<T> rowVec) const;
163170

164171
/// The given vector is interpreted as a column vector v. Pre-multiply v with
165172
/// this matrix, say M, and return Mv.
166-
SmallVector<MPInt, 8> postMultiplyWithColumn(ArrayRef<MPInt> colVec) const;
173+
SmallVector<T, 8> postMultiplyWithColumn(ArrayRef<T> colVec) const;
167174

168175
/// Given the current matrix M, returns the matrices H, U such that H is the
169176
/// column hermite normal form of M, i.e. H = M * U, where U is unimodular and
@@ -192,7 +199,7 @@ class Matrix {
192199
unsigned appendExtraRow();
193200
/// Same as above, but copy the given elements into the row. The length of
194201
/// `elems` must be equal to the number of columns.
195-
unsigned appendExtraRow(ArrayRef<MPInt> elems);
202+
unsigned appendExtraRow(ArrayRef<T> elems);
196203

197204
/// Print the matrix.
198205
void print(raw_ostream &os) const;
@@ -211,7 +218,7 @@ class Matrix {
211218

212219
/// Stores the data. data.size() is equal to nRows * nReservedColumns.
213220
/// data.capacity() / nReservedColumns is the number of reserved rows.
214-
SmallVector<MPInt, 16> data;
221+
SmallVector<T, 16> data;
215222
};
216223

217224
} // namespace presburger

mlir/include/mlir/Analysis/Presburger/PWMAFunction.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ enum class OrderingKind { EQ, NE, LT, LE, GT, GE };
4040
/// value of the function at a specified point.
4141
class MultiAffineFunction {
4242
public:
43-
MultiAffineFunction(const PresburgerSpace &space, const Matrix &output)
43+
MultiAffineFunction(const PresburgerSpace &space, const Matrix<MPInt> &output)
4444
: space(space), output(output),
4545
divs(space.getNumVars() - space.getNumRangeVars()) {
4646
assertIsConsistent();
4747
}
4848

49-
MultiAffineFunction(const PresburgerSpace &space, const Matrix &output,
49+
MultiAffineFunction(const PresburgerSpace &space, const Matrix<MPInt> &output,
5050
const DivisionRepr &divs)
5151
: space(space), output(output), divs(divs) {
5252
assertIsConsistent();
@@ -65,7 +65,7 @@ class MultiAffineFunction {
6565
PresburgerSpace getOutputSpace() const { return space.getRangeSpace(); }
6666

6767
/// Get a matrix with each row representing row^th output expression.
68-
const Matrix &getOutputMatrix() const { return output; }
68+
const Matrix<MPInt> &getOutputMatrix() const { return output; }
6969
/// Get the `i^th` output expression.
7070
ArrayRef<MPInt> getOutputExpr(unsigned i) const { return output.getRow(i); }
7171

@@ -124,7 +124,7 @@ class MultiAffineFunction {
124124
/// The function's output is a tuple of integers, with the ith element of the
125125
/// tuple defined by the affine expression given by the ith row of this output
126126
/// matrix.
127-
Matrix output;
127+
Matrix<MPInt> output;
128128

129129
/// Storage for division representation for each local variable in space.
130130
DivisionRepr divs;

mlir/include/mlir/Analysis/Presburger/Simplex.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ class SimplexBase {
338338
unsigned nSymbol;
339339

340340
/// The matrix representing the tableau.
341-
Matrix tableau;
341+
Matrix<MPInt> tableau;
342342

343343
/// This is true if the tableau has been detected to be empty, false
344344
/// otherwise.
@@ -861,7 +861,7 @@ class Simplex : public SimplexBase {
861861

862862
/// Reduce the given basis, starting at the specified level, using general
863863
/// basis reduction.
864-
void reduceBasis(Matrix &basis, unsigned level);
864+
void reduceBasis(Matrix<MPInt> &basis, unsigned level);
865865
};
866866

867867
/// Takes a snapshot of the simplex state on construction and rolls back to the

mlir/include/mlir/Analysis/Presburger/Utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class DivisionRepr {
182182
/// Each row of the Matrix represents a single division dividend. The
183183
/// `i^th` row represents the dividend of the variable at `divOffset + i`
184184
/// in the constraint system (and the `i^th` local variable).
185-
Matrix dividends;
185+
Matrix<MPInt> dividends;
186186

187187
/// Denominators of each division. If a denominator of a division is `0`, the
188188
/// division variable is considered to not have a division representation.

mlir/lib/Analysis/FlatLinearValueConstraints.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1292,7 +1292,7 @@ mlir::getMultiAffineFunctionFromMap(AffineMap map,
12921292
"AffineMap cannot produce divs without local representation");
12931293

12941294
// TODO: We shouldn't have to do this conversion.
1295-
Matrix mat(map.getNumResults(), map.getNumInputs() + divs.getNumDivs() + 1);
1295+
Matrix<MPInt> mat(map.getNumResults(), map.getNumInputs() + divs.getNumDivs() + 1);
12961296
for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i)
12971297
for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j)
12981298
mat(i, j) = flattenedExprs[i][j];

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMax() const {
304304
// Get lexmax by flipping range sign in the PWMA constraints.
305305
for (auto &flippedPiece :
306306
flippedSymbolicIntegerLexMax.lexopt.getAllPieces()) {
307-
Matrix mat = flippedPiece.output.getOutputMatrix();
307+
Matrix<MPInt> mat = flippedPiece.output.getOutputMatrix();
308308
for (unsigned i = 0, e = mat.getNumRows(); i < e; i++)
309309
mat.negateRow(i);
310310
MultiAffineFunction maf(flippedPiece.output.getSpace(), mat);
@@ -738,7 +738,7 @@ bool IntegerRelation::isEmptyByGCDTest() const {
738738
//
739739
// It is sufficient to check the perpendiculars of the constraints, as the set
740740
// of perpendiculars which are bounded must span all bounded directions.
741-
Matrix IntegerRelation::getBoundedDirections() const {
741+
Matrix<MPInt> IntegerRelation::getBoundedDirections() const {
742742
// Note that it is necessary to add the equalities too (which the constructor
743743
// does) even though we don't need to check if they are bounded; whether an
744744
// inequality is bounded or not depends on what other constraints, including
@@ -759,7 +759,7 @@ Matrix IntegerRelation::getBoundedDirections() const {
759759
// The direction vector is given by the coefficients and does not include the
760760
// constant term, so the matrix has one fewer column.
761761
unsigned dirsNumCols = getNumCols() - 1;
762-
Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
762+
Matrix<MPInt> dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
763763

764764
// Copy the bounded inequalities.
765765
unsigned row = 0;
@@ -845,7 +845,7 @@ IntegerRelation::findIntegerSample() const {
845845
// m is a matrix containing, in each row, a vector in which S is
846846
// bounded, such that the linear span of all these dimensions contains all
847847
// bounded dimensions in S.
848-
Matrix m = getBoundedDirections();
848+
Matrix<MPInt> m = getBoundedDirections();
849849
// In column echelon form, each row of m occupies only the first rank(m)
850850
// columns and has zeros on the other columns. The transform T that brings S
851851
// to column echelon form is unimodular as well, so this is a suitable

mlir/lib/Analysis/Presburger/LinearTransform.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
using namespace mlir;
1313
using namespace presburger;
1414

15-
LinearTransform::LinearTransform(Matrix &&oMatrix) : matrix(oMatrix) {}
16-
LinearTransform::LinearTransform(const Matrix &oMatrix) : matrix(oMatrix) {}
15+
LinearTransform::LinearTransform(Matrix<MPInt> &&oMatrix) : matrix(oMatrix) {}
16+
LinearTransform::LinearTransform(const Matrix<MPInt> &oMatrix) : matrix(oMatrix) {}
1717

1818
std::pair<unsigned, LinearTransform>
19-
LinearTransform::makeTransformToColumnEchelon(const Matrix &m) {
19+
LinearTransform::makeTransformToColumnEchelon(const Matrix<MPInt> &m) {
2020
// Compute the hermite normal form of m. This, is by definition, is in column
2121
// echelon form.
2222
auto [h, u] = m.computeHermiteNormalForm();

0 commit comments

Comments
 (0)