diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index 709e4f8438356c..6add6e8aa9b244 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -536,7 +536,6 @@ class IntegerRelation : public PresburgerSpace { Matrix inequalities; }; -struct SymbolicLexMin; /// An IntegerPolyhedron is a PresburgerSpace subject to affine /// constraints. Affine constraints can be inequalities or equalities in the /// form: @@ -594,28 +593,6 @@ class IntegerPolyhedron : public IntegerRelation { /// column position (i.e., not relative to the kind of identifier) of the /// first added identifier. unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override; - - /// Compute the symbolic integer lexmin of the polyhedron. - /// This finds, for every assignment to the symbols, the lexicographically - /// minimum value attained by the dimensions. For example, the symbolic lexmin - /// of the set - /// - /// (x, y)[a, b, c] : (a <= x, b <= x, x <= c) - /// - /// can be written as - /// - /// x = a if b <= a, a <= c - /// x = b if a < b, b <= c - /// - /// This function is stored in the `lexmin` function in the result. - /// Some assignments to the symbols might make the set empty. - /// Such points are not part of the function's domain. - /// In the above example, this happens when max(a, b) > c. - /// - /// For some values of the symbols, the lexmin may be unbounded. - /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate - /// `PresburgerSet`, `unboundedDomain`. - SymbolicLexMin findSymbolicIntegerLexMin() const; }; } // namespace presburger diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h index e2ad543070a4b9..940b88d8148f43 100644 --- a/mlir/include/mlir/Analysis/Presburger/Matrix.h +++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h @@ -151,9 +151,6 @@ class Matrix { /// Add an extra row at the bottom of the matrix and return its position. unsigned appendExtraRow(); - /// Same as above, but copy the given elements into the row. The length of - /// `elems` must be equal to the number of columns. - unsigned appendExtraRow(ArrayRef elems); /// Print the matrix. void print(raw_ostream &os) const; diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h index f4bffe5b4e7a49..ce0d77da9bc2cf 100644 --- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -106,11 +106,6 @@ class MultiAffineFunction : protected IntegerPolyhedron { /// outside the domain, an empty optional is returned. Optional> valueAt(ArrayRef point) const; - /// Truncate the output dimensions to the first `count` dimensions. - /// - /// TODO: refactor so that this can be accomplished through removeIdRange. - void truncateOutput(unsigned count); - void print(raw_ostream &os) const; void dump() const; @@ -170,11 +165,6 @@ class PWMAFunction : public PresburgerSpace { /// value at every point in the domain. bool isEqual(const PWMAFunction &other) const; - /// Truncate the output dimensions to the first `count` dimensions. - /// - /// TODO: refactor so that this can be accomplished through removeIdRange. - void truncateOutput(unsigned count); - void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h index 67a4b5f68e202d..66d408dbf8b69c 100644 --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -18,7 +18,6 @@ #include "mlir/Analysis/Presburger/Fraction.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/Matrix.h" -#include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/Utils.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/ArrayRef.h" @@ -42,9 +41,8 @@ class GBRSimplex; /// these constraints that are redundant, i.e. a subset of constraints that /// doesn't constrain the affine set further after adding the non-redundant /// constraints. The LexSimplex class provides support for computing the -/// lexicographic minimum of an IntegerRelation. The SymbolicLexMin class -/// provides support for computing symbolic lexicographic minimums. All of these -/// classes can be constructed from an IntegerRelation, and all inherit common +/// lexicographical minimum of an IntegerRelation. Both these classes can be +/// constructed from an IntegerRelation, and both inherit common /// functionality from SimplexBase. /// /// The implementations of the Simplex and SimplexBase classes, other than the @@ -74,22 +72,19 @@ class GBRSimplex; /// respectively. As described above, the first column is the common /// denominator. The second column represents the constant term, explained in /// more detail below. These two are _fixed columns_; they always retain their -/// position as the first and second columns. Additionally, LexSimplexBase -/// stores a so-call big M parameter (explained below) in the third column, so -/// LexSimplexBase has three fixed columns. Finally, SymbolicLexSimplex has -/// `nSymbol` variables designated as symbols. These occupy the next `nSymbol` -/// columns, viz. the columns [3, 3 + nSymbol). For more information on symbols, -/// see LexSimplexBase and SymbolicLexSimplex. +/// position as the first and second columns. Additionally, LexSimplex stores +/// a so-call big M parameter (explained below) in the third column, so +/// LexSimplex has three fixed columns. /// -/// LexSimplexBase does not directly support variables which can be negative, so -/// we introduce the so-called big M parameter, an artificial variable that is +/// LexSimplex does not directly support variables which can be negative, so we +/// introduce the so-called big M parameter, an artificial variable that is /// considered to have an arbitrarily large value. We then transform the /// variables, say x, y, z, ... to M, M + x, M + y, M + z. Since M has been /// added to these variables, they are now known to have non-negative values. -/// For more details, see the documentation for LexSimplexBase. The big M -/// parameter is not considered a real unknown and is not stored in the `var` -/// data structure; rather the tableau just has an extra fixed column for it -/// just like the constant term. +/// For more details, see the documentation for LexSimplex. The big M parameter +/// is not considered a real unknown and is not stored in the `var` data +/// structure; rather the tableau just has an extra fixed column for it just +/// like the constant term. /// /// The vectors var and con store information about the variables and /// constraints respectively, namely, whether they are in row or column @@ -151,8 +146,8 @@ class GBRSimplex; /// operation from the end until we reach the snapshot's location. SimplexBase /// also supports taking a snapshot including the exact set of basis unknowns; /// if this functionality is used, then on rolling back the exact basis will -/// also be restored. This is used by LexSimplexBase because the lex algorithm, -/// unlike `Simplex`, is sensitive to the exact basis used at a point. +/// also be restored. This is used by LexSimplex because its algorithm, unlike +/// Simplex, is sensitive to the exact basis used at a point. class SimplexBase { public: SimplexBase() = delete; @@ -216,8 +211,7 @@ class SimplexBase { /// constant term, whereas LexSimplex has an extra fixed column for the /// so-called big M parameter. For more information see the documentation for /// LexSimplex. - SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset, - unsigned nSymbol); + SimplexBase(unsigned nVar, bool mustUseBigM); enum class Orientation { Row, Column }; @@ -229,14 +223,11 @@ class SimplexBase { /// always be non-negative and if it cannot be made non-negative without /// violating other constraints, the tableau is empty. struct Unknown { - Unknown(Orientation oOrientation, bool oRestricted, unsigned oPos, - bool oIsSymbol = false) - : pos(oPos), orientation(oOrientation), restricted(oRestricted), - isSymbol(oIsSymbol) {} + Unknown(Orientation oOrientation, bool oRestricted, unsigned oPos) + : pos(oPos), orientation(oOrientation), restricted(oRestricted) {} unsigned pos; Orientation orientation; bool restricted : 1; - bool isSymbol : 1; void print(raw_ostream &os) const { os << (orientation == Orientation::Row ? "r" : "c"); @@ -335,10 +326,6 @@ class SimplexBase { /// nRedundant rows. unsigned nRedundant; - /// The number of parameters. This must be consistent with the number of - /// Unknowns in `var` below that have `isSymbol` set to true. - unsigned nSymbol; - /// The matrix representing the tableau. Matrix tableau; @@ -376,45 +363,62 @@ class SimplexBase { /// introduce an artifical variable M that is considered to have a value of /// +infinity and instead of the variables x, y, z, we internally use variables /// M + x, M + y, M + z, which are now guaranteed to be non-negative. See the -/// documentation for SimplexBase for more details. M is also considered to be -/// an integer that is divisible by everything. -/// -/// The whole algorithm is performed with M treated as a symbol; -/// it is just considered to be infinite throughout and it never appears in the -/// final outputs. We will deal with sample values throughout that may in -/// general be some affine expression involving M, like pM + q or aM + b. We can -/// compare these with each other. They have a total order: -/// -/// aM + b < pM + q iff a < p or (a == p and b < q). +/// documentation for Simplex for more details. The whole algorithm is performed +/// without having to fix a "big enough" value of the big M parameter; it is +/// just considered to be infinite throughout and it never appears in the final +/// outputs. We will deal with sample values throughout that may in general be +/// some linear expression involving M like pM + q or aM + b. We can compare +/// these with each other. They have a total order: +/// aM + b < pM + q iff a < p or (a == p and b < q). /// In particular, aM + b < 0 iff a < 0 or (a == 0 and b < 0). /// -/// When performing symbolic optimization, sample values will be affine -/// expressions in M and the symbols. For example, we could have sample values -/// aM + bS + c and pM + qS + r, where S is a symbol. Now we have -/// aM + bS + c < pM + qS + r iff (a < p) or (a == p and bS + c < qS + r). -/// bS + c < qS + r can be always true, always false, or neither, -/// depending on the set of values S can take. The symbols are always stored -/// in columns [3, 3 + nSymbols). For more details, see the -/// documentation for SymbolicLexSimplex. -/// /// Initially all the constraints to be added are added as rows, with no attempt /// to keep the tableau consistent. Pivots are only performed when some query /// is made, such as a call to getRationalLexMin. Care is taken to always /// maintain a lexicopositive basis transform, explained below. /// -/// Let the variables be x = (x_1, ... x_n). -/// Let the symbols be s = (s_1, ... s_m). Let the basis unknowns at a -/// particular point be y = (y_1, ... y_n). We know that x = A*y + T*s + b for -/// some n x n matrix A, n x m matrix s, and n x 1 column vector b. We want -/// every column in A to be lexicopositive, i.e., have at least one non-zero -/// element, with the first such element being positive. This property is -/// preserved throughout the operation of LexSimplexBase. Note that on -/// construction, the basis transform A is the identity matrix and so every -/// column is lexicopositive. Note that for LexSimplexBase, for the tableau to -/// be consistent we must have non-negative sample values not only for the -/// constraints but also for the variables. So if the tableau is consistent then -/// x >= 0 and y >= 0, by which we mean every element in these vectors is -/// non-negative. (note that this is a different concept from lexicopositivity!) +/// Let the variables be x = (x_1, ... x_n). Let the basis unknowns at a +/// particular point be y = (y_1, ... y_n). We know that x = A*y + b for some +/// n x n matrix A and n x 1 column vector b. We want every column in A to be +/// lexicopositive, i.e., have at least one non-zero element, with the first +/// such element being positive. This property is preserved throughout the +/// operation of LexSimplex. Note that on construction, the basis transform A is +/// the indentity matrix and so every column is lexicopositive. Note that for +/// LexSimplex, for the tableau to be consistent we must have non-negative +/// sample values not only for the constraints but also for the variables. +/// So if the tableau is consistent then x >= 0 and y >= 0, by which we mean +/// every element in these vectors is non-negative. (note that this is a +/// different concept from lexicopositivity!) +/// +/// When we arrive at a basis such the basis transform is lexicopositive and the +/// tableau is consistent, the sample point is the lexiographically minimum +/// point in the polytope. We will show that A*y is zero or lexicopositive when +/// y >= 0. Adding a lexicopositive vector to b will make it lexicographically +/// bigger, so A*y + b is lexicographically bigger than b for any y >= 0 except +/// y = 0. This shows that no point lexicographically smaller than x = b can be +/// obtained. Since we already know that x = b is valid point in the space, this +/// shows that x = b is the lexicographic minimum. +/// +/// Proof that A*y is lexicopositive or zero when y > 0. Recall that every +/// column of A is lexicopositive. Begin by considering A_1, the first row of A. +/// If this row is all zeros, then (A*y)_1 = (A_1)*y = 0; proceed to the next +/// row. If we run out of rows, A*y is zero and we are done; otherwise, we +/// encounter some row A_i that has a non-zero element. Every column is +/// lexicopositive and so has some positive element before any negative elements +/// occur, so the element in this row for any column, if non-zero, must be +/// positive. Consider (A*y)_i = (A_i)*y. All the elements in both vectors are +/// non-negative, so if this is non-zero then it must be positive. Then the +/// first non-zero element of A*y is positive so A*y is lexicopositive. +/// +/// Otherwise, if (A_i)*y is zero, then for every column j that had a non-zero +/// element in A_i, y_j is zero. Thus these columns have no contribution to A*y +/// and we can completely ignore these columns of A. We now continue downwards, +/// looking for rows of A that have a non-zero element other than in the ignored +/// columns. If we find one, say A_k, once again these elements must be positive +/// since they are the first non-zero element in each of these columns, so if +/// (A_k)*y is not zero then we have that A*y is lexicopositive and if not we +/// ignore more columns; eventually if all these dot products become zero then +/// A*y is zero and we are done. class LexSimplexBase : public SimplexBase { public: ~LexSimplexBase() override = default; @@ -431,37 +435,25 @@ class LexSimplexBase : public SimplexBase { unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); } protected: - LexSimplexBase(unsigned nVar, unsigned symbolOffset, unsigned nSymbol) - : SimplexBase(nVar, /*mustUseBigM=*/true, symbolOffset, nSymbol) {} + LexSimplexBase(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/true) {} explicit LexSimplexBase(const IntegerRelation &constraints) - : LexSimplexBase(constraints.getNumIds(), - constraints.getIdKindOffset(IdKind::Symbol), - constraints.getNumSymbolIds()) { + : LexSimplexBase(constraints.getNumIds()) { intersectIntegerRelation(constraints); } - /// Add new symbolic variables to the end of the list of variables. - void appendSymbol(); - /// Try to move the specified row to column orientation while preserving the - /// lexicopositivity of the basis transform. The row must have a negative - /// sample value. If this is not possible, return failure. This only occurs - /// when the constraints have no solution; the tableau will be marked empty in - /// such a case. + /// lexicopositivity of the basis transform. If this is not possible, return + /// failure. This only occurs when the constraints have no solution; the + /// tableau will be marked empty in such a case. LogicalResult moveRowUnknownToColumn(unsigned row); - /// Given a row that has a non-integer sample value, add an inequality to cut - /// away this fractional sample value from the polytope without removing any - /// integer points. The integer lexmin, if one existed, remains the same on - /// return. + /// Given a row that has a non-integer sample value, add an inequality such + /// that this fractional sample value is cut away from the polytope. The added + /// inequality will be such that no integer points are removed. /// - /// This assumes that the symbolic part of the sample is integral, - /// i.e., if the symbolic sample is (c + aM + b_1*s_1 + ... b_n*s_n)/d, - /// where s_1, ... s_n are symbols, this assumes that - /// (b_1*s_1 + ... + b_n*s_n)/s is integral. - /// - /// Return failure if the tableau became empty, and success if it didn't. - /// Failure status indicates that the polytope was integer empty. + /// Returns whether the cut constraint could be enforced, i.e. failure if the + /// cut made the polytope empty, and success if it didn't. Failure status + /// indicates that the polytope didn't have any integer points. LogicalResult addCut(unsigned row); /// Undo the addition of the last constraint. This is only called while @@ -469,19 +461,14 @@ class LexSimplexBase : public SimplexBase { void undoLastConstraint() final; /// Given two potential pivot columns for a row, return the one that results - /// in the lexicographically smallest sample vector. The row's sample value - /// must be negative. If symbols are involved, the sample value must be - /// negative for all possible assignments to the symbols. + /// in the lexicographically smallest sample vector. unsigned getLexMinPivotColumn(unsigned row, unsigned colA, unsigned colB) const; }; -/// A class for lexicographic optimization without any symbols. This also -/// provides support for integer-exact redundancy and separateness checks. class LexSimplex : public LexSimplexBase { public: - explicit LexSimplex(unsigned nVar) - : LexSimplexBase(nVar, /*symbolOffset=*/0, /*nSymbol=*/0) {} + explicit LexSimplex(unsigned nVar) : LexSimplexBase(nVar) {} explicit LexSimplex(const IntegerRelation &constraints) : LexSimplexBase(constraints) { assert(constraints.getNumSymbolIds() == 0 && @@ -515,7 +502,7 @@ class LexSimplex : public LexSimplexBase { MaybeOptimum> getRationalSample() const; /// Make the tableau configuration consistent. - LogicalResult restoreRationalConsistency(); + void restoreRationalConsistency(); /// Return whether the specified row is violated; bool rowIsViolated(unsigned row) const; @@ -527,122 +514,11 @@ class LexSimplex : public LexSimplexBase { /// Get a row corresponding to a var that has a non-integral sample value, if /// one exists. Otherwise, return an empty optional. Optional maybeGetNonIntegralVarRow() const; -}; - -/// Represents the result of a symbolic lexicographic minimization computation. -struct SymbolicLexMin { - SymbolicLexMin(unsigned nSymbols, unsigned nNonSymbols) - : lexmin(PresburgerSpace::getSetSpace(nSymbols), nNonSymbols), - unboundedDomain( - PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(nSymbols))) {} - - /// This maps assignments of symbols to the corresponding lexmin. - /// Takes no value when no integer sample exists for the assignment or if the - /// lexmin is unbounded. - PWMAFunction lexmin; - /// Contains all assignments to the symbols that made the lexmin unbounded. - /// Note that the symbols of the input set to the symbolic lexmin are dims - /// of this PrebsurgerSet. - PresburgerSet unboundedDomain; -}; - -/// A class to perform symbolic lexicographic optimization, -/// i.e., to find, for every assignment to the symbols the specified -/// `symbolDomain`, the lexicographically minimum value integer value attained -/// by the non-symbol variables. -/// -/// The input is a set parametrized by some symbols, i.e., the constant terms -/// of the constraints in the set are affine expressions in the symbols, and -/// every assignment to the symbols defines a non-symbolic set. -/// -/// Accordingly, the sample values of the rows in our tableau will be affine -/// expressions in the symbols, and every assignment to the symbols will define -/// a non-symbolic LexSimplex. We then run the algorithm of -/// LexSimplex::findIntegerLexMin simultaneously for every value of the symbols -/// in the domain. -/// -/// Often, the pivot to be performed is the same for all values of the symbols, -/// in which case we just do it. For example, if the symbolic sample of a row is -/// negative for all values in the symbol domain, the row needs to be pivoted -/// irrespective of the precise value of the symbols. To answer queries like -/// "Is this symbolic sample always negative in the symbol domain?", we maintain -/// a `LexSimplex domainSimplex` correponding to the symbol domain. -/// -/// In other cases, it may be that the symbolic sample is violated at some -/// values in the symbol domain and not violated at others. In this case, -/// the pivot to be performed does depend on the value of the symbols. We -/// handle this by splitting the symbol domain. We run the algorithm for the -/// case where the row isn't violated, and then come back and run the case -/// where it is. -class SymbolicLexSimplex : public LexSimplexBase { -public: - /// `constraints` is the set for which the symbolic lexmin will be computed. - /// `symbolDomain` is the set of values of the symbols for which the lexmin - /// will be computed. `symbolDomain` should have a dim id for every symbol in - /// `constraints`, and no other ids. - SymbolicLexSimplex(const IntegerPolyhedron &constraints, - const IntegerPolyhedron &symbolDomain) - : LexSimplexBase(constraints), domainPoly(symbolDomain), - domainSimplex(symbolDomain) { - assert(domainPoly.getNumIds() == constraints.getNumSymbolIds()); - assert(domainPoly.getNumDimIds() == constraints.getNumSymbolIds()); - } - - /// The lexmin will be stored as a function `lexmin` from symbols to - /// non-symbols in the result. - /// - /// For some values of the symbols, the lexmin may be unbounded. - /// These parts of the symbol domain will be stored in `unboundedDomain`. - SymbolicLexMin computeSymbolicIntegerLexMin(); -private: - /// Perform all pivots that do not require branching. - /// - /// Return failure if the tableau became empty, indicating that the polytope - /// is always integer empty in the current symbol domain. - /// Return success otherwise. - LogicalResult doNonBranchingPivots(); - - /// Get a row that is always violated in the current domain, if one exists. - Optional maybeGetAlwaysViolatedRow(); - - /// Get a row corresponding to a variable with non-integral sample value, if - /// one exists. - Optional maybeGetNonIntegralVarRow(); - - /// Given a row that has a non-integer sample value, cut away this fractional - /// sample value witahout removing any integer points, i.e., the integer - /// lexmin, if it exists, remains the same after a call to this function. This - /// may add constraints or local variables to the tableau, as well as to the - /// domain. - /// - /// Returns whether the cut constraint could be enforced, i.e. failure if the - /// cut made the polytope empty, and success if it didn't. Failure status - /// indicates that the polytope is always integer empty in the symbol domain - /// at the time of the call. (This function may modify the symbol domain, but - /// failure statu indicates that the polytope was empty for all symbol values - /// in the initial domain.) - LogicalResult addSymbolicCut(unsigned row); - - /// Get the numerator of the symbolic sample of the specific row. - /// This is an affine expression in the symbols with integer coefficients. - /// The last element is the constant term. This ignores the big M coefficient. - SmallVector getSymbolicSampleNumerator(unsigned row) const; - - /// Return whether all the coefficients of the symbolic sample are integers. - /// - /// This does not consult the domain to check if the specified expression - /// is always integral despite coefficients being fractional. - bool isSymbolicSampleIntegral(unsigned row) const; - - /// Record a lexmin. The tableau must be consistent with all variables - /// having symbolic samples with integer coefficients. - void recordOutput(SymbolicLexMin &result) const; - - /// The symbol domain. - IntegerPolyhedron domainPoly; - /// Simplex corresponding to the symbol domain. - LexSimplex domainSimplex; + /// Given two potential pivot columns for a row, return the one that results + /// in the lexicographically smallest sample vector. + unsigned getLexMinPivotColumn(unsigned row, unsigned colA, + unsigned colB) const; }; /// The Simplex class uses the Normal pivot rule and supports integer emptiness @@ -664,9 +540,7 @@ class Simplex : public SimplexBase { enum class Direction { Up, Down }; Simplex() = delete; - explicit Simplex(unsigned nVar) - : SimplexBase(nVar, /*mustUseBigM=*/false, /*symbolOffset=*/0, - /*nSymbol=*/0) {} + explicit Simplex(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/false) {} explicit Simplex(const IntegerRelation &constraints) : Simplex(constraints.getNumIds()) { intersectIntegerRelation(constraints); diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 5e527b5467f548..bfa9a6539077da 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -14,7 +14,6 @@ #include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/LinearTransform.h" -#include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/PresburgerRelation.h" #include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Analysis/Presburger/Utils.h" @@ -146,21 +145,6 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) { removeEqualityRange(counts.getNumEqs(), getNumEqualities()); } -SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const { - // Compute the symbolic lexmin of the dims and locals, with the symbols being - // the actual symbols of this set. - SymbolicLexMin result = - SymbolicLexSimplex( - *this, PresburgerSpace::getSetSpace(/*numDims=*/getNumSymbolIds())) - .computeSymbolicIntegerLexMin(); - - // We want to return only the lexmin over the dims, so strip the locals from - // the computed lexmin. - result.lexmin.truncateOutput(result.lexmin.getNumOutputs() - - getNumLocalIds()); - return result; -} - unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) { assert(pos <= getNumIdKind(kind)); diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp index 680e4509b7cc88..219d490e7368af 100644 --- a/mlir/lib/Analysis/Presburger/Matrix.cpp +++ b/mlir/lib/Analysis/Presburger/Matrix.cpp @@ -66,14 +66,6 @@ unsigned Matrix::appendExtraRow() { return nRows - 1; } -unsigned Matrix::appendExtraRow(ArrayRef elems) { - assert(elems.size() == nColumns && "elems must match row length!"); - unsigned row = appendExtraRow(); - for (unsigned col = 0; col < nColumns; ++col) - at(row, col) = elems[col]; - return row; -} - void Matrix::resizeHorizontally(unsigned newNColumns) { if (newNColumns < nColumns) removeColumns(newNColumns, nColumns - newNColumns); diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp index 711e99aab35b45..b995bc00a19c84 100644 --- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp +++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp @@ -114,18 +114,6 @@ void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA, IntegerPolyhedron::eliminateRedundantLocalId(posA, posB); } -void MultiAffineFunction::truncateOutput(unsigned count) { - assert(count <= output.getNumRows()); - output.resizeVertically(count); -} - -void PWMAFunction::truncateOutput(unsigned count) { - assert(count <= numOutputs); - for (MultiAffineFunction &piece : pieces) - piece.truncateOutput(count); - numOutputs = count; -} - bool MultiAffineFunction::isEqualWhereDomainsOverlap( MultiAffineFunction other) const { if (!isSpaceCompatible(other)) diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index f3bf42f40b177d..57e8f485742d2e 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -18,24 +18,15 @@ using Direction = Simplex::Direction; const int nullIndex = std::numeric_limits::max(); -SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset, - unsigned nSymbol) +SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM) : usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar), - nRedundant(0), nSymbol(nSymbol), tableau(0, nCol), empty(false) { - assert(symbolOffset + nSymbol <= nVar); - + nRedundant(0), tableau(0, nCol), empty(false) { colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex); for (unsigned i = 0; i < nVar; ++i) { var.emplace_back(Orientation::Column, /*restricted=*/false, /*pos=*/getNumFixedCols() + i); colUnknown.push_back(i); } - - // Move the symbols to be in columns [3, 3 + nSymbol). - for (unsigned i = 0; i < nSymbol; ++i) { - var[symbolOffset + i].isSymbol = true; - swapColumns(var[symbolOffset + i].pos, getNumFixedCols() + i); - } } const Simplex::Unknown &SimplexBase::unknownFromIndex(int index) const { @@ -105,13 +96,9 @@ unsigned SimplexBase::addRow(ArrayRef coeffs, bool makeRestricted) { // where M is the big M parameter. As such, when the user tries to add // a row ax + by + cz + d, we express it in terms of our internal variables // as -(a + b + c)M + a(M + x) + b(M + y) + c(M + z) + d. - // - // Symbols don't use the big M parameter since they do not get lex - // optimized. int64_t bigMCoeff = 0; for (unsigned i = 0; i < coeffs.size() - 1; ++i) - if (!var[i].isSymbol) - bigMCoeff -= coeffs[i]; + bigMCoeff -= coeffs[i]; // The coefficient to the big M parameter is stored in column 2. tableau(nRow - 1, 2) = bigMCoeff; } @@ -177,97 +164,19 @@ Direction flippedDirection(Direction direction) { } } // namespace -/// We simply make the tableau consistent while maintaining a lexicopositive -/// basis transform, and then return the sample value. If the tableau becomes -/// empty, we return empty. -/// -/// Let the variables be x = (x_1, ... x_n). -/// Let the basis unknowns be y = (y_1, ... y_n). -/// We have that x = A*y + b for some n x n matrix A and n x 1 column vector b. -/// -/// As we will show below, A*y is either zero or lexicopositive. -/// Adding a lexicopositive vector to b will make it lexicographically -/// greater, so A*y + b is always equal to or lexicographically greater than b. -/// Thus, since we can attain x = b, that is the lexicographic minimum. -/// -/// We have that that every column in A is lexicopositive, i.e., has at least -/// one non-zero element, with the first such element being positive. Since for -/// the tableau to be consistent we must have non-negative sample values not -/// only for the constraints but also for the variables, we also have x >= 0 and -/// y >= 0, by which we mean every element in these vectors is non-negative. -/// -/// Proof that if every column in A is lexicopositive, and y >= 0, then -/// A*y is zero or lexicopositive. Begin by considering A_1, the first row of A. -/// If this row is all zeros, then (A*y)_1 = (A_1)*y = 0; proceed to the next -/// row. If we run out of rows, A*y is zero and we are done; otherwise, we -/// encounter some row A_i that has a non-zero element. Every column is -/// lexicopositive and so has some positive element before any negative elements -/// occur, so the element in this row for any column, if non-zero, must be -/// positive. Consider (A*y)_i = (A_i)*y. All the elements in both vectors are -/// non-negative, so if this is non-zero then it must be positive. Then the -/// first non-zero element of A*y is positive so A*y is lexicopositive. -/// -/// Otherwise, if (A_i)*y is zero, then for every column j that had a non-zero -/// element in A_i, y_j is zero. Thus these columns have no contribution to A*y -/// and we can completely ignore these columns of A. We now continue downwards, -/// looking for rows of A that have a non-zero element other than in the ignored -/// columns. If we find one, say A_k, once again these elements must be positive -/// since they are the first non-zero element in each of these columns, so if -/// (A_k)*y is not zero then we have that A*y is lexicopositive and if not we -/// add these to the set of ignored columns and continue to the next row. If we -/// run out of rows, then A*y is zero and we are done. MaybeOptimum> LexSimplex::findRationalLexMin() { - if (restoreRationalConsistency().failed()) - return OptimumKind::Empty; + restoreRationalConsistency(); return getRationalSample(); } -/// Given a row that has a non-integer sample value, add an inequality such -/// that this fractional sample value is cut away from the polytope. The added -/// inequality will be such that no integer points are removed. i.e., the -/// integer lexmin, if it exists, is the same with and without this constraint. -/// -/// Let the row be -/// (c + coeffM*M + a_1*s_1 + ... + a_m*s_m + b_1*y_1 + ... + b_n*y_n)/d, -/// where s_1, ... s_m are the symbols and -/// y_1, ... y_n are the other basis unknowns. -/// -/// For this to be an integer, we want -/// coeffM*M + a_1*s_1 + ... + a_m*s_m + b_1*y_1 + ... + b_n*y_n = -c (mod d) -/// Note that this constraint must always hold, independent of the basis, -/// becuse the row unknown's value always equals this expression, even if *we* -/// later compute the sample value from a different expression based on a -/// different basis. -/// -/// Let us assume that M has a factor of d in it. Imposing this constraint on M -/// does not in any way hinder us from finding a value of M that is big enough. -/// Moreover, this function is only called when the symbolic part of the sample, -/// a_1*s_1 + ... + a_m*s_m, is known to be an integer. -/// -/// Also, we can safely reduce the coefficients modulo d, so we have: -/// -/// (b_1%d)y_1 + ... + (b_n%d)y_n = (-c%d) + k*d for some integer `k` -/// -/// Note that all coefficient modulos here are non-negative. Also, all the -/// unknowns are non-negative here as both constraints and variables are -/// non-negative in LexSimplexBase. (We used the big M trick to make the -/// variables non-negative). Therefore, the LHS here is non-negative. -/// Since 0 <= (-c%d) < d, k is the quotient of dividing the LHS by d and -/// is therefore non-negative as well. -/// -/// So we have -/// ((b_1%d)y_1 + ... + (b_n%d)y_n - (-c%d))/d >= 0. -/// -/// The constraint is violated when added (it would be useless otherwise) -/// so we immediately try to move it to a column. LogicalResult LexSimplexBase::addCut(unsigned row) { - int64_t d = tableau(row, 0); + int64_t denom = tableau(row, 0); addZeroRow(/*makeRestricted=*/true); - tableau(nRow - 1, 0) = d; - tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -c%d. - tableau(nRow - 1, 2) = 0; - for (unsigned col = 3 + nSymbol; col < nCol; ++col) - tableau(nRow - 1, col) = mod(tableau(row, col), d); // b_i%d. + tableau(nRow - 1, 0) = denom; + tableau(nRow - 1, 1) = -mod(-tableau(row, 1), denom); + tableau(nRow - 1, 2) = 0; // M has all factors in it. + for (unsigned col = 3; col < nCol; ++col) + tableau(nRow - 1, col) = mod(tableau(row, col), denom); return moveRowUnknownToColumn(nRow - 1); } @@ -276,7 +185,7 @@ Optional LexSimplex::maybeGetNonIntegralVarRow() const { if (u.orientation == Orientation::Column) continue; // If the sample value is of the form (a/d)M + b/d, we need b to be - // divisible by d. We assume M contains all possible + // divisible by d. We assume M is very large and contains all possible // factors and is divisible by everything. unsigned row = u.pos; if (tableau(row, 1) % tableau(row, 0) != 0) @@ -286,34 +195,28 @@ Optional LexSimplex::maybeGetNonIntegralVarRow() const { } MaybeOptimum> LexSimplex::findIntegerLexMin() { - // We first try to make the tableau consistent. - if (restoreRationalConsistency().failed()) - return OptimumKind::Empty; - - // Then, if the sample value is integral, we are done. - while (Optional maybeRow = maybeGetNonIntegralVarRow()) { - // Otherwise, for the variable whose row has a non-integral sample value, - // we add a cut, a constraint that remove this rational point - // while preserving all integer points, thus keeping the lexmin the same. - // We then again try to make the tableau with the new constraint - // consistent. This continues until the tableau becomes empty, in which - // case there is no integer point, or until there are no variables with - // non-integral sample values. - // - // Failure indicates that the tableau became empty, which occurs when the - // polytope is integer empty. - if (addCut(*maybeRow).failed()) - return OptimumKind::Empty; - if (restoreRationalConsistency().failed()) + while (!empty) { + restoreRationalConsistency(); + if (empty) return OptimumKind::Empty; + + if (Optional maybeRow = maybeGetNonIntegralVarRow()) { + // Failure occurs when the polytope is integer empty. + if (failed(addCut(*maybeRow))) + return OptimumKind::Empty; + continue; + } + + MaybeOptimum> sample = getRationalSample(); + assert(!sample.isEmpty() && "If we reached here the sample should exist!"); + if (sample.isUnbounded()) + return OptimumKind::Unbounded; + return llvm::to_vector<8>( + llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger))); } - MaybeOptimum> sample = getRationalSample(); - assert(!sample.isEmpty() && "If we reached here the sample should exist!"); - if (sample.isUnbounded()) - return OptimumKind::Unbounded; - return llvm::to_vector<8>( - llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger))); + // Polytope is integer empty. + return OptimumKind::Empty; } bool LexSimplex::isSeparateInequality(ArrayRef coeffs) { @@ -325,319 +228,6 @@ bool LexSimplex::isSeparateInequality(ArrayRef coeffs) { bool LexSimplex::isRedundantInequality(ArrayRef coeffs) { return isSeparateInequality(getComplementIneq(coeffs)); } - -SmallVector -SymbolicLexSimplex::getSymbolicSampleNumerator(unsigned row) const { - SmallVector sample; - sample.reserve(nSymbol + 1); - for (unsigned col = 3; col < 3 + nSymbol; ++col) - sample.push_back(tableau(row, col)); - sample.push_back(tableau(row, 1)); - return sample; -} - -void LexSimplexBase::appendSymbol() { - appendVariable(); - swapColumns(3 + nSymbol, nCol - 1); - var.back().isSymbol = true; - nSymbol++; -} - -static bool isRangeDivisibleBy(ArrayRef range, int64_t divisor) { - assert(divisor > 0 && "divisor must be positive!"); - return llvm::all_of(range, [divisor](int64_t x) { return x % divisor == 0; }); -} - -bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const { - int64_t denom = tableau(row, 0); - return tableau(row, 1) % denom == 0 && - isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), denom); -} - -/// This proceeds similarly to LexSimplex::addCut(). We are given a row that has -/// a symbolic sample value with fractional coefficients. -/// -/// Let the row be -/// (c + coeffM*M + sum_i a_i*s_i + sum_j b_j*y_j)/d, -/// where s_1, ... s_m are the symbols and -/// y_1, ... y_n are the other basis unknowns. -/// -/// As in LexSimplex::addCut, for this to be an integer, we want -/// -/// coeffM*M + sum_j b_j*y_j = -c + sum_i (-a_i*s_i) (mod d) -/// -/// This time, a_1*s_1 + ... + a_m*s_m may not be an integer. We find that -/// -/// sum_i (b_i%d)y_i = ((-c%d) + sum_i (-a_i%d)s_i)%d + k*d for some integer k -/// -/// where we take a modulo of the whole symbolic expression on the right to -/// bring it into the range [0, d - 1]. Therefore, as in LexSimplex::addCut, -/// k is the quotient on dividing the LHS by d, and since LHS >= 0, we have -/// k >= 0 as well. We realize the modulo of the symbolic expression by adding a -/// division variable -/// -/// q = ((-c%d) + sum_i (-a_i%d)s_i)/d -/// -/// to the symbol domain, so the equality becomes -/// -/// sum_i (b_i%d)y_i = (-c%d) + sum_i (-a_i%d)s_i - q*d + k*d for some integer k -/// -/// So the cut is -/// (sum_i (b_i%d)y_i - (-c%d) - sum_i (-a_i%d)s_i + q*d)/d >= 0 -/// This constraint is violated when added so we immediately try to move it to a -/// column. -LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) { - int64_t d = tableau(row, 0); - - // Add the division variable `q` described above to the symbol domain. - // q = ((-c%d) + sum_i (-a_i%d)s_i)/d. - SmallVector domainDivCoeffs; - domainDivCoeffs.reserve(nSymbol + 1); - for (unsigned col = 3; col < 3 + nSymbol; ++col) - domainDivCoeffs.push_back(mod(-tableau(row, col), d)); // (-a_i%d)s_i - domainDivCoeffs.push_back(mod(-tableau(row, 1), d)); // -c%d. - - domainSimplex.addDivisionVariable(domainDivCoeffs, d); - domainPoly.addLocalFloorDiv(domainDivCoeffs, d); - - // Update `this` to account for the additional symbol we just added. - appendSymbol(); - - // Add the cut (sum_i (b_i%d)y_i - (-c%d) + sum_i -(-a_i%d)s_i + q*d)/d >= 0. - addZeroRow(/*makeRestricted=*/true); - tableau(nRow - 1, 0) = d; - tableau(nRow - 1, 2) = 0; - - tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -(-c%d). - for (unsigned col = 3; col < 3 + nSymbol - 1; ++col) - tableau(nRow - 1, col) = -mod(-tableau(row, col), d); // -(-a_i%d)s_i. - tableau(nRow - 1, 3 + nSymbol - 1) = d; // q*d. - - for (unsigned col = 3 + nSymbol; col < nCol; ++col) - tableau(nRow - 1, col) = mod(tableau(row, col), d); // (b_i%d)y_i. - return moveRowUnknownToColumn(nRow - 1); -} - -void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const { - Matrix output(0, domainPoly.getNumIds() + 1); - output.reserveRows(result.lexmin.getNumOutputs()); - for (const Unknown &u : var) { - if (u.isSymbol) - continue; - - if (u.orientation == Orientation::Column) { - // M + u has a sample value of zero so u has a sample value of -M, i.e, - // unbounded. - result.unboundedDomain.unionInPlace(domainPoly); - return; - } - - int64_t denom = tableau(u.pos, 0); - if (tableau(u.pos, 2) < denom) { - // M + u has a sample value of fM + something, where f < 1, so - // u = (f - 1)M + something, which has a negative coefficient for M, - // and so is unbounded. - result.unboundedDomain.unionInPlace(domainPoly); - return; - } - assert(tableau(u.pos, 2) == denom && - "Coefficient of M should not be greater than 1!"); - - SmallVector sample = getSymbolicSampleNumerator(u.pos); - for (int64_t &elem : sample) { - assert(elem % denom == 0 && "coefficients must be integral!"); - elem /= denom; - } - output.appendExtraRow(sample); - } - result.lexmin.addPiece(domainPoly, output); -} - -Optional SymbolicLexSimplex::maybeGetAlwaysViolatedRow() { - // First look for rows that are clearly violated just from the big M - // coefficient, without needing to perform any simplex queries on the domain. - for (unsigned row = 0; row < nRow; ++row) - if (tableau(row, 2) < 0) - return row; - - for (unsigned row = 0; row < nRow; ++row) { - if (tableau(row, 2) > 0) - continue; - if (domainSimplex.isSeparateInequality(getSymbolicSampleNumerator(row))) { - // Sample numerator always takes negative values in the symbol domain. - return row; - } - } - return {}; -} - -Optional SymbolicLexSimplex::maybeGetNonIntegralVarRow() { - for (const Unknown &u : var) { - if (u.orientation == Orientation::Column) - continue; - assert(!u.isSymbol && "Symbol should not be in row orientation!"); - if (!isSymbolicSampleIntegral(u.pos)) - return u.pos; - } - return {}; -} - -/// The non-branching pivots are just the ones moving the rows -/// that are always violated in the symbol domain. -LogicalResult SymbolicLexSimplex::doNonBranchingPivots() { - while (Optional row = maybeGetAlwaysViolatedRow()) - if (moveRowUnknownToColumn(*row).failed()) - return failure(); - return success(); -} - -SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() { - SymbolicLexMin result(nSymbol, var.size() - nSymbol); - - /// The algorithm is more naturally expressed recursively, but we implement - /// it iteratively here to avoid potential issues with stack overflows in the - /// compiler. We explicitly maintain the stack frames in a vector. - /// - /// To "recurse", we store the current "stack frame", i.e., state variables - /// that we will need when we "return", into `stack`, increment `level`, and - /// `continue`. To "tail recurse", we just `continue`. - /// To "return", we decrement `level` and `continue`. - /// - /// When there is no stack frame for the current `level`, this indicates that - /// we have just "recursed" or "tail recursed". When there does exist one, - /// this indicates that we have just "returned" from recursing. There is only - /// one point at which non-tail calls occur so we always "return" there. - unsigned level = 1; - struct StackFrame { - int splitIndex; - unsigned snapshot; - unsigned domainSnapshot; - IntegerRelation::CountsSnapshot domainPolyCounts; - }; - SmallVector stack; - - while (level > 0) { - assert(level >= stack.size()); - if (level > stack.size()) { - if (empty || domainSimplex.findIntegerLexMin().isEmpty()) { - // No integer points; return. - --level; - continue; - } - - if (doNonBranchingPivots().failed()) { - // Could not find pivots for violated constraints; return. - --level; - continue; - } - - unsigned splitRow; - SmallVector symbolicSample; - for (splitRow = 0; splitRow < nRow; ++splitRow) { - if (tableau(splitRow, 2) > 0) - continue; - assert(tableau(splitRow, 2) == 0 && - "Non-branching pivots should have been handled already!"); - - symbolicSample = getSymbolicSampleNumerator(splitRow); - if (domainSimplex.isRedundantInequality(symbolicSample)) - continue; - - // It's neither redundant nor separate, so it takes both positive and - // negative values, and hence constitutes a row for which we need to - // split the domain and separately run each case. - assert(!domainSimplex.isSeparateInequality(symbolicSample) && - "Non-branching pivots should have been handled already!"); - break; - } - - if (splitRow < nRow) { - unsigned domainSnapshot = domainSimplex.getSnapshot(); - IntegerRelation::CountsSnapshot domainPolyCounts = - domainPoly.getCounts(); - - // First, we consider the part of the domain where the row is not - // violated. We don't have to do any pivots for the row in this case, - // but we record the additional constraint that defines this part of - // the domain. - domainSimplex.addInequality(symbolicSample); - domainPoly.addInequality(symbolicSample); - - // Recurse. - // - // On return, the basis as a set is preserved but not the internal - // ordering within rows or columns. Thus, we take note of the index of - // the Unknown that caused the split, which may be in a different - // row when we come back from recursing. We will need this to recurse - // on the other part of the split domain, where the row is violated. - // - // Note that we have to capture the index above and not a reference to - // the Unknown itself, since the array it lives in might get - // reallocated. - int splitIndex = rowUnknown[splitRow]; - unsigned snapshot = getSnapshot(); - stack.push_back( - {splitIndex, snapshot, domainSnapshot, domainPolyCounts}); - ++level; - continue; - } - - // The tableau is rationally consistent for the current domain. - // Now we look for non-integral sample values and add cuts for them. - if (Optional row = maybeGetNonIntegralVarRow()) { - if (addSymbolicCut(*row).failed()) { - // No integral points; return. - --level; - continue; - } - - // Rerun this level with the added cut constraint (tail recurse). - continue; - } - - // Record output and return. - recordOutput(result); - --level; - continue; - } - - if (level == stack.size()) { - // We have "returned" from "recursing". - const StackFrame &frame = stack.back(); - domainPoly.truncate(frame.domainPolyCounts); - domainSimplex.rollback(frame.domainSnapshot); - rollback(frame.snapshot); - const Unknown &u = unknownFromIndex(frame.splitIndex); - - // Drop the frame. We don't need it anymore. - stack.pop_back(); - - // Now we consider the part of the domain where the unknown `splitIndex` - // was negative. - assert(u.orientation == Orientation::Row && - "The split row should have been returned to row orientation!"); - SmallVector splitIneq = - getComplementIneq(getSymbolicSampleNumerator(u.pos)); - if (moveRowUnknownToColumn(u.pos).failed()) { - // The unknown can't be made non-negative; return. - --level; - continue; - } - - // The unknown can be made negative; recurse with the corresponding domain - // constraints. - domainSimplex.addInequality(splitIneq); - domainPoly.addInequality(splitIneq); - - // We are now taking care of the second half of the domain and we don't - // need to do anything else here after returning, so it's a tail recurse. - continue; - } - } - - return result; -} - bool LexSimplex::rowIsViolated(unsigned row) const { if (tableau(row, 2) < 0) return true; @@ -653,20 +243,19 @@ Optional LexSimplex::maybeGetViolatedRow() const { return {}; } -/// We simply look for violated rows and keep trying to move them to column -/// orientation, which always succeeds unless the constraints have no solution -/// in which case we just give up and return. -LogicalResult LexSimplex::restoreRationalConsistency() { - if (empty) - return failure(); - while (Optional maybeViolatedRow = maybeGetViolatedRow()) - if (moveRowUnknownToColumn(*maybeViolatedRow).failed()) - return failure(); - return success(); +// We simply look for violated rows and keep trying to move them to column +// orientation, which always succeeds unless the constraints have no solution +// in which case we just give up and return. +void LexSimplex::restoreRationalConsistency() { + while (Optional maybeViolatedRow = maybeGetViolatedRow()) { + LogicalResult status = moveRowUnknownToColumn(*maybeViolatedRow); + if (failed(status)) + return; + } } // Move the row unknown to column orientation while preserving lexicopositivity -// of the basis transform. The sample value of the row must be negative. +// of the basis transform. // // We only consider pivots where the pivot element is positive. Suppose no such // pivot exists, i.e., some violated row has no positive coefficient for any @@ -729,7 +318,7 @@ LogicalResult LexSimplex::restoreRationalConsistency() { // minimizes the change in sample value. LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) { Optional maybeColumn; - for (unsigned col = 3 + nSymbol; col < nCol; ++col) { + for (unsigned col = 3; col < nCol; ++col) { if (tableau(row, col) <= 0) continue; maybeColumn = @@ -747,7 +336,6 @@ LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) { unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA, unsigned colB) const { - // First, let's consider the non-symbolic case. // A pivot causes the following change. (in the diagram the matrix elements // are shown as rationals and there is no common denominator used) // @@ -771,7 +359,7 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA, // (-p/a)M + (-b/a), i.e. 0 to -(pM + b)/a. Thus the change in the sample // value is -s/a. // - // If the variable is the pivot row, its sample value goes from s to 0, for a + // If the variable is the pivot row, it sampel value goes from s to 0, for a // change of -s. // // If the variable is a non-pivot row, its sample value changes from @@ -785,12 +373,8 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA, // comparisons involved and can be ignored, since -s is strictly positive. // // Thus we take away this common factor and just return 0, 1/a, 1, or c/a as - // appropriate. This allows us to run the entire algorithm treating M - // symbolically, as the pivot to be performed does not depend on the value - // of M, so long as the sample value s is negative. Note that this is not - // because of any special feature of M; by the same argument, we ignore the - // symbols too. The caller ensure that the sample value s is negative for - // all possible values of the symbols. + // appropriate. This allows us to run the entire algorithm without ever having + // to fix a value of M. auto getSampleChangeCoeffForVar = [this, row](unsigned col, const Unknown &u) -> Fraction { int64_t a = tableau(row, col); @@ -905,7 +489,6 @@ void SimplexBase::pivot(Pivot pair) { pivot(pair.row, pair.column); } /// element. void SimplexBase::pivot(unsigned pivotRow, unsigned pivotCol) { assert(pivotCol >= getNumFixedCols() && "Refusing to pivot invalid column"); - assert(!unknownFromColumn(pivotCol).isSymbol); swapRowWithCol(pivotRow, pivotCol); std::swap(tableau(pivotRow, 0), tableau(pivotRow, pivotCol)); @@ -1195,9 +778,6 @@ void SimplexBase::undo(UndoLogEntry entry) { assert(var.back().orientation == Orientation::Column && "Variable to be removed must be in column orientation!"); - if (var.back().isSymbol) - nSymbol--; - // Move this variable to the last column and remove the column from the // tableau. swapColumns(var.back().pos, nCol - 1); diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp index 2cb6ada89397aa..4149d85d8759fd 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -8,7 +8,6 @@ #include "./Utils.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" -#include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/Simplex.h" #include @@ -1135,229 +1134,6 @@ TEST(IntegerPolyhedronTest, findIntegerLexMin) { ">= 0, -11*z + 5*y - 3*x + 7 >= 0)")); } -void expectSymbolicIntegerLexMin( - StringRef polyStr, - ArrayRef, 8>>> - expectedLexminRepr, - ArrayRef expectedUnboundedDomainRepr) { - IntegerPolyhedron poly = parsePoly(polyStr); - - ASSERT_NE(poly.getNumDimIds(), 0u); - ASSERT_NE(poly.getNumSymbolIds(), 0u); - - PWMAFunction expectedLexmin = - parsePWMAF(/*numInputs=*/poly.getNumSymbolIds(), - /*numOutputs=*/poly.getNumDimIds(), expectedLexminRepr); - - PresburgerSet expectedUnboundedDomain = parsePresburgerSetFromPolyStrings( - poly.getNumSymbolIds(), expectedUnboundedDomainRepr); - - SymbolicLexMin result = poly.findSymbolicIntegerLexMin(); - - EXPECT_TRUE(result.lexmin.isEqual(expectedLexmin)); - if (!result.lexmin.isEqual(expectedLexmin)) { - llvm::errs() << "got:\n"; - result.lexmin.dump(); - llvm::errs() << "expected:\n"; - expectedLexmin.dump(); - } - - EXPECT_TRUE(result.unboundedDomain.isEqual(expectedUnboundedDomain)); - if (!result.unboundedDomain.isEqual(expectedUnboundedDomain)) - result.unboundedDomain.dump(); -} - -void expectSymbolicIntegerLexMin( - StringRef polyStr, - ArrayRef, 8>>> - result) { - expectSymbolicIntegerLexMin(polyStr, result, {}); -} - -TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) { - expectSymbolicIntegerLexMin("(x)[a] : (x - a >= 0)", - { - {"(a) : ()", {{1, 0}}}, // a - }); - - expectSymbolicIntegerLexMin( - "(x)[a, b] : (x - a >= 0, x - b >= 0)", - { - {"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a - {"(a, b) : (b - a - 1 >= 0)", {{0, 1, 0}}}, // b - }); - - expectSymbolicIntegerLexMin( - "(x)[a, b, c] : (x -a >= 0, x - b >= 0, x - c >= 0)", - { - {"(a, b, c) : (a - b >= 0, a - c >= 0)", {{1, 0, 0, 0}}}, // a - {"(a, b, c) : (b - a - 1 >= 0, b - c >= 0)", {{0, 1, 0, 0}}}, // b - {"(a, b, c) : (c - a - 1 >= 0, c - b - 1 >= 0)", {{0, 0, 1, 0}}}, // c - }); - - expectSymbolicIntegerLexMin("(x, y)[a] : (x - a >= 0, x + y >= 0)", - { - {"(a) : ()", {{1, 0}, {-1, 0}}}, // (a, -a) - }); - - expectSymbolicIntegerLexMin( - "(x, y)[a] : (x - a >= 0, x + y >= 0, y >= 0)", - { - {"(a) : (a >= 0)", {{1, 0}, {0, 0}}}, // (a, 0) - {"(a) : (-a - 1 >= 0)", {{1, 0}, {-1, 0}}}, // (a, -a) - }); - - expectSymbolicIntegerLexMin( - "(x, y)[a, b, c] : (x - a >= 0, y - b >= 0, c - x - y >= 0)", - { - {"(a, b, c) : (c - a - b >= 0)", - {{1, 0, 0, 0}, {0, 1, 0, 0}}}, // (a, b) - }); - - expectSymbolicIntegerLexMin( - "(x, y, z)[a, b, c] : (c - z >= 0, b - y >= 0, x + y + z - a == 0)", - { - {"(a, b, c) : ()", - {{1, -1, -1, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}}}, // (a - b - c, b, c) - }); - - expectSymbolicIntegerLexMin( - "(x)[a, b] : (a >= 0, b >= 0, x >= 0, a + b + x - 1 >= 0)", - { - {"(a, b) : (a >= 0, b >= 0, a + b - 1 >= 0)", {{0, 0, 0}}}, // 0 - {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1 - }); - - expectSymbolicIntegerLexMin( - "(x)[a, b] : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, 1 - x >= 0, x >= " - "0, a + b + x - 1 >= 0)", - { - {"(a, b) : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, a + b - 1 >= 0)", - {{0, 0, 0}}}, // 0 - {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1 - }); - - expectSymbolicIntegerLexMin( - "(x, y, z)[a, b] : (x - a == 0, y - b == 0, x >= 0, y >= 0, z >= 0, x + " - "y + z - 1 >= 0)", - { - {"(a, b) : (a >= 0, b >= 0, 1 - a - b >= 0)", - {{1, 0, 0}, {0, 1, 0}, {-1, -1, 1}}}, // (a, b, 1 - a - b) - {"(a, b) : (a >= 0, b >= 0, a + b - 2 >= 0)", - {{1, 0, 0}, {0, 1, 0}, {0, 0, 0}}}, // (a, b, 0) - }); - - expectSymbolicIntegerLexMin("(x)[a, b] : (x - a == 0, x - b >= 0)", - { - {"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a - }); - - expectSymbolicIntegerLexMin( - "(q)[a] : (a - 1 - 3*q == 0, q >= 0)", - { - {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", - {{0, 1, 0}}}, // a floordiv 3 - }); - - expectSymbolicIntegerLexMin( - "(r, q)[a] : (a - r - 3*q == 0, q >= 0, 1 - r >= 0, r >= 0)", - { - {"(a) : (a - 0 - 3*(a floordiv 3) == 0, a >= 0)", - {{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3) - {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", - {{0, 0, 1}, {0, 1, 0}}}, // (1 a floordiv 3) - }); - - expectSymbolicIntegerLexMin( - "(r, q)[a] : (a - r - 3*q == 0, q >= 0, 2 - r >= 0, r - 1 >= 0)", - { - {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", - {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3) - {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)", - {{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3) - }); - - expectSymbolicIntegerLexMin( - "(r, q)[a] : (a - r - 3*q == 0, q >= 0, r >= 0)", - { - {"(a) : (a - 3*(a floordiv 3) == 0, a >= 0)", - {{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3) - {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)", - {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3) - {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)", - {{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3) - }); - - expectSymbolicIntegerLexMin( - "(x, y, z, w)[g] : (" - // x, y, z, w are boolean variables. - "1 - x >= 0, x >= 0, 1 - y >= 0, y >= 0," - "1 - z >= 0, z >= 0, 1 - w >= 0, w >= 0," - // We have some constraints on them: - "x + y + z - 1 >= 0," // x or y or z - "x + y + w - 1 >= 0," // x or y or w - "1 - x + 1 - y + 1 - w - 1 >= 0," // ~x or ~y or ~w - // What's the lexmin solution using exactly g true vars? - "g - x - y - z - w == 0)", - { - {"(g) : (g - 1 == 0)", - {{0, 0}, {0, 1}, {0, 0}, {0, 0}}}, // (0, 1, 0, 0) - {"(g) : (g - 2 == 0)", - {{0, 0}, {0, 0}, {0, 1}, {0, 1}}}, // (0, 0, 1, 1) - {"(g) : (g - 3 == 0)", - {{0, 0}, {0, 1}, {0, 1}, {0, 1}}}, // (0, 1, 1, 1) - }); - - // Bezout's lemma: if a, b are constants, - // the set of values that ax + by can take is all multiples of gcd(a, b). - expectSymbolicIntegerLexMin( - // If (x, y) is a solution for a given [a, r], then so is (x - 5, y + 2). - // So the lexmin is unbounded if it exists. - "(x, y)[a, r] : (a >= 0, r - a + 14*x + 35*y == 0)", {}, - // According to Bezout's lemma, 14x + 35y can take on all multiples - // of 7 and no other values. So the solution exists iff r - a is a - // multiple of 7. - {"(a, r) : (a >= 0, r - a - 7*((r - a) floordiv 7) == 0)"}); - - // The lexmins are unbounded. - expectSymbolicIntegerLexMin("(x, y)[a] : (9*x - 4*y - 2*a >= 0)", {}, - {"(a) : ()"}); - - // Test cases adapted from isl. - expectSymbolicIntegerLexMin( - // a = 2b - 2(c - b), c - b >= 0. - // So b is minimized when c = b. - "(b, c)[a] : (a - 4*b + 2*c == 0, c - b >= 0)", - { - {"(a) : (a - 2*(a floordiv 2) == 0)", - {{0, 1, 0}, {0, 1, 0}}}, // (a floordiv 2, a floordiv 2) - }); - - expectSymbolicIntegerLexMin( - // 0 <= b <= 255, 1 <= a - 512b <= 509, - // b + 8 >= 1 + 16*(b + 8 floordiv 16) // i.e. b % 16 != 8 - "(b)[a] : (255 - b >= 0, b >= 0, a - 512*b - 1 >= 0, 512*b -a + 509 >= " - "0, b + 7 - 16*((8 + b) floordiv 16) >= 0)", - { - {"(a) : (255 - (a floordiv 512) >= 0, a >= 0, a - 512*(a floordiv " - "512) - 1 >= 0, 512*(a floordiv 512) - a + 509 >= 0, (a floordiv " - "512) + 7 - 16*((8 + (a floordiv 512)) floordiv 16) >= 0)", - {{0, 1, 0, 0}}}, // (a floordiv 2, a floordiv 2) - }); - - expectSymbolicIntegerLexMin( - "(a, b)[K, N, x, y] : (N - K - 2 >= 0, K + 4 - N >= 0, x - 4 >= 0, x + 6 " - "- 2*N >= 0, K+N - x - 1 >= 0, a - N + 1 >= 0, K+N-1-a >= 0,a + 6 - b - " - "N >= 0, 2*N - 4 - a >= 0," - "2*N - 3*K + a - b >= 0, 4*N - K + 1 - 3*b >= 0, b - N >= 0, a - x - 1 " - ">= 0)", - {{ - "(K, N, x, y) : (x + 6 - 2*N >= 0, 2*N - 5 - x >= 0, x + 1 -3*K + N " - ">= 0, N + K - 2 - x >= 0, x - 4 >= 0)", - {{0, 0, 1, 0, 1}, {0, 1, 0, 0, 0}} // (1 + x, N) - }}); -} - static void expectComputedVolumeIsValidOverapprox(const IntegerPolyhedron &poly, Optional trueVolume,