7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
9
// This is a simple 2D matrix class that supports reading, writing, resizing,
10
- // swapping rows, and swapping columns.
10
+ // swapping rows, and swapping columns. It can hold integers (MPInt) or rational
11
+ // numbers (Fraction).
11
12
//
12
13
// ===----------------------------------------------------------------------===//
13
14
14
15
#ifndef MLIR_ANALYSIS_PRESBURGER_MATRIX_H
15
16
#define MLIR_ANALYSIS_PRESBURGER_MATRIX_H
16
17
17
- #include " mlir/Analysis/Presburger/MPInt.h"
18
18
#include " mlir/Support/LLVM.h"
19
+ #include " mlir/Analysis/Presburger/Fraction.h"
20
+ #include " mlir/Analysis/Presburger/Matrix.h"
19
21
#include " llvm/ADT/ArrayRef.h"
20
22
#include " llvm/Support/raw_ostream.h"
21
23
@@ -32,7 +34,12 @@ namespace presburger {
32
34
// / (i, j) is stored at data[i*nReservedColumns + j]. The reserved but unused
33
35
// / columns always have all zero values. The reserved rows are just reserved
34
36
// / space in the underlying SmallVector's capacity.
37
+ // / This class only works for the types MPInt and Fraction, since the method
38
+ // / implementations are in the Matrix.cpp file. Only these two types have
39
+ // / been explicitly instantiated there.
40
+ template <typename T>
35
41
class Matrix {
42
+ static_assert (std::is_same_v<T,MPInt> || std::is_same_v<T,Fraction>, " T must be MPInt or Fraction." );
36
43
public:
37
44
Matrix () = delete ;
38
45
@@ -49,21 +56,21 @@ class Matrix {
49
56
static Matrix identity (unsigned dimension);
50
57
51
58
// / Access the element at the specified row and column.
52
- MPInt &at (unsigned row, unsigned column) {
59
+ T &at (unsigned row, unsigned column) {
53
60
assert (row < nRows && " Row outside of range" );
54
61
assert (column < nColumns && " Column outside of range" );
55
62
return data[row * nReservedColumns + column];
56
63
}
57
64
58
- MPInt at (unsigned row, unsigned column) const {
65
+ T at (unsigned row, unsigned column) const {
59
66
assert (row < nRows && " Row outside of range" );
60
67
assert (column < nColumns && " Column outside of range" );
61
68
return data[row * nReservedColumns + column];
62
69
}
63
70
64
- MPInt &operator ()(unsigned row, unsigned column) { return at (row, column); }
71
+ T &operator ()(unsigned row, unsigned column) { return at (row, column); }
65
72
66
- MPInt operator ()(unsigned row, unsigned column) const {
73
+ T operator ()(unsigned row, unsigned column) const {
67
74
return at (row, column);
68
75
}
69
76
@@ -87,11 +94,11 @@ class Matrix {
87
94
void reserveRows (unsigned rows);
88
95
89
96
// / Get a [Mutable]ArrayRef corresponding to the specified row.
90
- MutableArrayRef<MPInt > getRow (unsigned row);
91
- ArrayRef<MPInt > getRow (unsigned row) const ;
97
+ MutableArrayRef<T > getRow (unsigned row);
98
+ ArrayRef<T > getRow (unsigned row) const ;
92
99
93
100
// / Set the specified row to `elems`.
94
- void setRow (unsigned row, ArrayRef<MPInt > elems);
101
+ void setRow (unsigned row, ArrayRef<T > elems);
95
102
96
103
// / Insert columns having positions pos, pos + 1, ... pos + count - 1.
97
104
// / Columns that were at positions 0 to pos - 1 will stay where they are;
@@ -125,23 +132,23 @@ class Matrix {
125
132
126
133
void copyRow (unsigned sourceRow, unsigned targetRow);
127
134
128
- void fillRow (unsigned row, const MPInt &value);
129
- void fillRow (unsigned row, int64_t value) { fillRow (row, MPInt (value)); }
135
+ void fillRow (unsigned row, const T &value);
136
+ void fillRow (unsigned row, int64_t value) { fillRow (row, T (value)); }
130
137
131
138
// / Add `scale` multiples of the source row to the target row.
132
- void addToRow (unsigned sourceRow, unsigned targetRow, const MPInt &scale);
139
+ void addToRow (unsigned sourceRow, unsigned targetRow, const T &scale);
133
140
void addToRow (unsigned sourceRow, unsigned targetRow, int64_t scale) {
134
- addToRow (sourceRow, targetRow, MPInt (scale));
141
+ addToRow (sourceRow, targetRow, T (scale));
135
142
}
136
143
// / Add `scale` multiples of the rowVec row to the specified row.
137
- void addToRow (unsigned row, ArrayRef<MPInt > rowVec, const MPInt &scale);
144
+ void addToRow (unsigned row, ArrayRef<T > rowVec, const T &scale);
138
145
139
146
// / Add `scale` multiples of the source column to the target column.
140
147
void addToColumn (unsigned sourceColumn, unsigned targetColumn,
141
- const MPInt &scale);
148
+ const T &scale);
142
149
void addToColumn (unsigned sourceColumn, unsigned targetColumn,
143
150
int64_t scale) {
144
- addToColumn (sourceColumn, targetColumn, MPInt (scale));
151
+ addToColumn (sourceColumn, targetColumn, T (scale));
145
152
}
146
153
147
154
// / Negate the specified column.
@@ -152,18 +159,18 @@ class Matrix {
152
159
153
160
// / Divide the first `nCols` of the specified row by their GCD.
154
161
// / Returns the GCD of the first `nCols` of the specified row.
155
- MPInt normalizeRow (unsigned row, unsigned nCols);
162
+ T normalizeRow (unsigned row, unsigned nCols);
156
163
// / Divide the columns of the specified row by their GCD.
157
164
// / Returns the GCD of the columns of the specified row.
158
- MPInt normalizeRow (unsigned row);
165
+ T normalizeRow (unsigned row);
159
166
160
167
// / The given vector is interpreted as a row vector v. Post-multiply v with
161
168
// / this matrix, say M, and return vM.
162
- SmallVector<MPInt , 8 > preMultiplyWithRow (ArrayRef<MPInt > rowVec) const ;
169
+ SmallVector<T , 8 > preMultiplyWithRow (ArrayRef<T > rowVec) const ;
163
170
164
171
// / The given vector is interpreted as a column vector v. Pre-multiply v with
165
172
// / this matrix, say M, and return Mv.
166
- SmallVector<MPInt , 8 > postMultiplyWithColumn (ArrayRef<MPInt > colVec) const ;
173
+ SmallVector<T , 8 > postMultiplyWithColumn (ArrayRef<T > colVec) const ;
167
174
168
175
// / Given the current matrix M, returns the matrices H, U such that H is the
169
176
// / column hermite normal form of M, i.e. H = M * U, where U is unimodular and
@@ -192,7 +199,7 @@ class Matrix {
192
199
unsigned appendExtraRow ();
193
200
// / Same as above, but copy the given elements into the row. The length of
194
201
// / `elems` must be equal to the number of columns.
195
- unsigned appendExtraRow (ArrayRef<MPInt > elems);
202
+ unsigned appendExtraRow (ArrayRef<T > elems);
196
203
197
204
// / Print the matrix.
198
205
void print (raw_ostream &os) const ;
@@ -211,7 +218,7 @@ class Matrix {
211
218
212
219
// / Stores the data. data.size() is equal to nRows * nReservedColumns.
213
220
// / data.capacity() / nReservedColumns is the number of reserved rows.
214
- SmallVector<MPInt , 16 > data;
221
+ SmallVector<T , 16 > data;
215
222
};
216
223
217
224
} // namespace presburger
0 commit comments