77// ===----------------------------------------------------------------------===//
88
99#include " mlir/Analysis/Presburger/Barvinok.h"
10+ #include " llvm/ADT/Sequence.h"
1011
1112using namespace mlir ;
1213using namespace presburger ;
@@ -24,7 +25,7 @@ ConeV mlir::presburger::detail::getDual(ConeH cone) {
2425 // is represented as a row [a1, ..., an, b]
2526 // and that b = 0.
2627
27- for (unsigned i = 0 ; i < numIneq; ++i ) {
28+ for (auto i : llvm::seq< int >( 0 , numIneq) ) {
2829 assert (cone.atIneq (i, numVar) == 0 &&
2930 " H-representation of cone is not centred at the origin!" );
3031 for (unsigned j = 0 ; j < numVar; ++j) {
@@ -63,3 +64,83 @@ MPInt mlir::presburger::detail::getIndex(ConeV cone) {
6364
6465 return cone.determinant ();
6566}
67+
68+ // / Compute the generating function for a unimodular cone.
69+ // / This consists of a single term of the form
70+ // / sign * x^num / prod_j (1 - x^den_j)
71+ // /
72+ // / sign is either +1 or -1.
73+ // / den_j is defined as the set of generators of the cone.
74+ // / num is computed by expressing the vertex as a weighted
75+ // / sum of the generators, and then taking the floor of the
76+ // / coefficients.
77+ GeneratingFunction mlir::presburger::detail::unimodularConeGeneratingFunction (
78+ ParamPoint vertex, int sign, ConeH cone) {
79+ // Consider a cone with H-representation [0 -1].
80+ // [-1 -2]
81+ // Let the vertex be given by the matrix [ 2 2 0], with 2 params.
82+ // [-1 -1/2 1]
83+
84+ // `cone` must be unimodular.
85+ assert (getIndex (getDual (cone)) == 1 && " input cone is not unimodular!" );
86+
87+ unsigned numVar = cone.getNumVars ();
88+ unsigned numIneq = cone.getNumInequalities ();
89+
90+ // Thus its ray matrix, U, is the inverse of the
91+ // transpose of its inequality matrix, `cone`.
92+ // The last column of the inequality matrix is null,
93+ // so we remove it to obtain a square matrix.
94+ FracMatrix transp = FracMatrix (cone.getInequalities ()).transpose ();
95+ transp.removeRow (numVar);
96+
97+ FracMatrix generators (numVar, numIneq);
98+ transp.determinant (/* inverse=*/ &generators); // This is the U-matrix.
99+ // Thus the generators are given by U = [2 -1].
100+ // [-1 0]
101+
102+ // The powers in the denominator of the generating
103+ // function are given by the generators of the cone,
104+ // i.e., the rows of the matrix U.
105+ std::vector<Point> denominator (numIneq);
106+ ArrayRef<Fraction> row;
107+ for (auto i : llvm::seq<int >(0 , numVar)) {
108+ row = generators.getRow (i);
109+ denominator[i] = Point (row);
110+ }
111+
112+ // The vertex is v \in Z^{d x (n+1)}
113+ // We need to find affine functions of parameters λ_i(p)
114+ // such that v = Σ λ_i(p)*u_i,
115+ // where u_i are the rows of U (generators)
116+ // The λ_i are given by the columns of Λ = v^T U^{-1}, and
117+ // we have transp = U^{-1}.
118+ // Then the exponent in the numerator will be
119+ // Σ -floor(-λ_i(p))*u_i.
120+ // Thus we store the (exponent of the) numerator as the affine function -Λ,
121+ // since the generators u_i are already stored as the exponent of the
122+ // denominator. Note that the outer -1 will have to be accounted for, as it is
123+ // not stored. See end for an example.
124+
125+ unsigned numColumns = vertex.getNumColumns ();
126+ unsigned numRows = vertex.getNumRows ();
127+ ParamPoint numerator (numColumns, numRows);
128+ SmallVector<Fraction> ithCol (numRows);
129+ for (auto i : llvm::seq<int >(0 , numColumns)) {
130+ for (auto j : llvm::seq<int >(0 , numRows))
131+ ithCol[j] = vertex (j, i);
132+ numerator.setRow (i, transp.preMultiplyWithRow (ithCol));
133+ numerator.negateRow (i);
134+ }
135+ // Therefore Λ will be given by [ 1 0 ] and the negation of this will be
136+ // [ 1/2 -1 ]
137+ // [ -1 -2 ]
138+ // stored as the numerator.
139+ // Algebraically, the numerator exponent is
140+ // [ -2 ⌊ - N - M/2 + 1 ⌋ + 1 ⌊ 0 + M + 2 ⌋ ] -> first COLUMN of U is [2, -1]
141+ // [ 1 ⌊ - N - M/2 + 1 ⌋ + 0 ⌊ 0 + M + 2 ⌋ ] -> second COLUMN of U is [-1, 0]
142+
143+ return GeneratingFunction (numColumns - 1 , SmallVector<int >(1 , sign),
144+ std::vector ({numerator}),
145+ std::vector ({denominator}));
146+ }
0 commit comments