11
11
12
12
#include " mlir/Support/LLVM.h"
13
13
#include " mlir/Support/LogicalResult.h"
14
+ #include " llvm/ADT/APFloat.h"
14
15
#include " llvm/ADT/APInt.h"
15
16
#include " llvm/ADT/ArrayRef.h"
16
17
#include " llvm/ADT/Hashing.h"
17
- #include " llvm/ADT/SmallVector.h"
18
+ #include " llvm/ADT/SmallString.h"
19
+ #include " llvm/ADT/Twine.h"
20
+ #include " llvm/Support/raw_ostream.h"
18
21
19
22
namespace mlir {
20
23
@@ -27,98 +30,202 @@ namespace polynomial {
27
30
// / would want to specify 128-bit polynomials statically in the source code.
28
31
constexpr unsigned apintBitWidth = 64 ;
29
32
30
- // / A class representing a monomial of a single-variable polynomial with integer
31
- // / coefficients.
32
- class Monomial {
33
+ template <typename CoefficientType>
34
+ class MonomialBase {
33
35
public:
34
- Monomial (int64_t coeff, uint64_t expo)
35
- : coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {}
36
-
37
- Monomial (const APInt &coeff, const APInt &expo)
36
+ MonomialBase (const CoefficientType &coeff, const APInt &expo)
38
37
: coefficient(coeff), exponent(expo) {}
38
+ virtual ~MonomialBase () = 0 ;
39
39
40
- Monomial () : coefficient(apintBitWidth, 0 ), exponent(apintBitWidth, 0 ) {}
40
+ const CoefficientType &getCoefficient () const { return coefficient; }
41
+ CoefficientType &getMutableCoefficient () { return coefficient; }
42
+ const APInt &getExponent () const { return exponent; }
43
+ void setCoefficient (const CoefficientType &coeff) { coefficient = coeff; }
44
+ void setExponent (const APInt &exp) { exponent = exp; }
41
45
42
- bool operator ==(const Monomial &other) const {
46
+ bool operator ==(const MonomialBase &other) const {
43
47
return other.coefficient == coefficient && other.exponent == exponent;
44
48
}
45
- bool operator !=(const Monomial &other) const {
49
+ bool operator !=(const MonomialBase &other) const {
46
50
return other.coefficient != coefficient || other.exponent != exponent;
47
51
}
48
52
49
53
// / Monomials are ordered by exponent.
50
- bool operator <(const Monomial &other) const {
54
+ bool operator <(const MonomialBase &other) const {
51
55
return (exponent.ult (other.exponent ));
52
56
}
53
57
54
- friend ::llvm::hash_code hash_value (const Monomial &arg);
58
+ virtual bool isMonic () const = 0;
59
+ virtual void
60
+ coefficientToString (llvm::SmallString<16 > &coeffString) const = 0 ;
55
61
56
- public:
57
- APInt coefficient ;
62
+ template < typename T>
63
+ friend ::llvm::hash_code hash_value ( const MonomialBase<T> &arg) ;
58
64
59
- // Always unsigned
65
+ protected:
66
+ CoefficientType coefficient;
60
67
APInt exponent;
61
68
};
62
69
63
- // / A single-variable polynomial with integer coefficients.
64
- // /
65
- // / Eg: x^1024 + x + 1
66
- // /
67
- // / The symbols used as the polynomial's indeterminate don't matter, so long as
68
- // / it is used consistently throughout the polynomial.
69
- class Polynomial {
70
+ // / A class representing a monomial of a single-variable polynomial with integer
71
+ // / coefficients.
72
+ class IntMonomial : public MonomialBase <APInt> {
70
73
public:
71
- Polynomial () = delete ;
74
+ IntMonomial (int64_t coeff, uint64_t expo)
75
+ : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
72
76
73
- explicit Polynomial (ArrayRef<Monomial> terms) : terms(terms){};
77
+ IntMonomial ()
78
+ : MonomialBase(APInt(apintBitWidth, 0 ), APInt(apintBitWidth, 0 )) {}
74
79
75
- // Returns a Polynomial from a list of monomials.
76
- // Fails if two monomials have the same exponent.
77
- static FailureOr<Polynomial> fromMonomials (ArrayRef<Monomial> monomials);
80
+ ~IntMonomial () = default ;
78
81
79
- // / Returns a polynomial with coefficients given by `coeffs`. The value
80
- // / coeffs[i] is converted to a monomial with exponent i.
81
- static Polynomial fromCoefficients (ArrayRef<int64_t > coeffs);
82
+ bool isMonic () const override { return coefficient == 1 ; }
83
+
84
+ void coefficientToString (llvm::SmallString<16 > &coeffString) const override {
85
+ coefficient.toStringSigned (coeffString);
86
+ }
87
+ };
88
+
89
+ // / A class representing a monomial of a single-variable polynomial with integer
90
+ // / coefficients.
91
+ class FloatMonomial : public MonomialBase <APFloat> {
92
+ public:
93
+ FloatMonomial (double coeff, uint64_t expo)
94
+ : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
95
+
96
+ FloatMonomial () : MonomialBase(APFloat((double )0 ), APInt(apintBitWidth, 0 )) {}
97
+
98
+ ~FloatMonomial () = default ;
99
+
100
+ bool isMonic () const override { return coefficient == APFloat (1.0 ); }
101
+
102
+ void coefficientToString (llvm::SmallString<16 > &coeffString) const override {
103
+ coefficient.toString (coeffString);
104
+ }
105
+ };
106
+
107
+ template <typename Monomial>
108
+ class PolynomialBase {
109
+ public:
110
+ PolynomialBase () = delete ;
111
+
112
+ explicit PolynomialBase (ArrayRef<Monomial> terms) : terms(terms){};
82
113
83
114
explicit operator bool () const { return !terms.empty (); }
84
- bool operator ==(const Polynomial &other) const {
115
+ bool operator ==(const PolynomialBase &other) const {
85
116
return other.terms == terms;
86
117
}
87
- bool operator !=(const Polynomial &other) const {
118
+ bool operator !=(const PolynomialBase &other) const {
88
119
return !(other.terms == terms);
89
120
}
90
121
91
- // Prints polynomial to 'os'.
92
- void print (raw_ostream &os) const ;
93
122
void print (raw_ostream &os, ::llvm::StringRef separator,
94
- ::llvm::StringRef exponentiation) const ;
123
+ ::llvm::StringRef exponentiation) const {
124
+ bool first = true ;
125
+ for (const Monomial &term : getTerms ()) {
126
+ if (first) {
127
+ first = false ;
128
+ } else {
129
+ os << separator;
130
+ }
131
+ std::string coeffToPrint;
132
+ if (term.isMonic () && term.getExponent ().uge (1 )) {
133
+ coeffToPrint = " " ;
134
+ } else {
135
+ llvm::SmallString<16 > coeffString;
136
+ term.coefficientToString (coeffString);
137
+ coeffToPrint = coeffString.str ();
138
+ }
139
+
140
+ if (term.getExponent () == 0 ) {
141
+ os << coeffToPrint;
142
+ } else if (term.getExponent () == 1 ) {
143
+ os << coeffToPrint << " x" ;
144
+ } else {
145
+ llvm::SmallString<16 > expString;
146
+ term.getExponent ().toStringSigned (expString);
147
+ os << coeffToPrint << " x" << exponentiation << expString;
148
+ }
149
+ }
150
+ }
151
+
152
+ // Prints polynomial to 'os'.
153
+ void print (raw_ostream &os) const { print (os, " + " , " **" ); }
154
+
95
155
void dump () const ;
96
156
97
157
// Prints polynomial so that it can be used as a valid identifier
98
- std::string toIdentifier () const ;
158
+ std::string toIdentifier () const {
159
+ std::string result;
160
+ llvm::raw_string_ostream os (result);
161
+ print (os, " _" , " " );
162
+ return os.str ();
163
+ }
99
164
100
- unsigned getDegree () const ;
165
+ unsigned getDegree () const {
166
+ return terms.back ().getExponent ().getZExtValue ();
167
+ }
101
168
102
169
ArrayRef<Monomial> getTerms () const { return terms; }
103
170
104
- friend ::llvm::hash_code hash_value (const Polynomial &arg);
171
+ template <typename T>
172
+ friend ::llvm::hash_code hash_value (const PolynomialBase<T> &arg);
105
173
106
174
private:
107
175
// The monomial terms for this polynomial.
108
176
SmallVector<Monomial> terms;
109
177
};
110
178
111
- // Make Polynomial hashable.
112
- inline ::llvm::hash_code hash_value (const Polynomial &arg) {
179
+ // / A single-variable polynomial with integer coefficients.
180
+ // /
181
+ // / Eg: x^1024 + x + 1
182
+ class IntPolynomial : public PolynomialBase <IntMonomial> {
183
+ public:
184
+ explicit IntPolynomial (ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
185
+
186
+ // Returns a Polynomial from a list of monomials.
187
+ // Fails if two monomials have the same exponent.
188
+ static FailureOr<IntPolynomial>
189
+ fromMonomials (ArrayRef<IntMonomial> monomials);
190
+
191
+ // / Returns a polynomial with coefficients given by `coeffs`. The value
192
+ // / coeffs[i] is converted to a monomial with exponent i.
193
+ static IntPolynomial fromCoefficients (ArrayRef<int64_t > coeffs);
194
+ };
195
+
196
+ // / A single-variable polynomial with double coefficients.
197
+ // /
198
+ // / Eg: 1.0 x^1024 + 3.5 x + 1e-05
199
+ class FloatPolynomial : public PolynomialBase <FloatMonomial> {
200
+ public:
201
+ explicit FloatPolynomial (ArrayRef<FloatMonomial> terms)
202
+ : PolynomialBase(terms) {}
203
+
204
+ // Returns a Polynomial from a list of monomials.
205
+ // Fails if two monomials have the same exponent.
206
+ static FailureOr<FloatPolynomial>
207
+ fromMonomials (ArrayRef<FloatMonomial> monomials);
208
+
209
+ // / Returns a polynomial with coefficients given by `coeffs`. The value
210
+ // / coeffs[i] is converted to a monomial with exponent i.
211
+ static FloatPolynomial fromCoefficients (ArrayRef<double > coeffs);
212
+ };
213
+
214
+ // Make Polynomials hashable.
215
+ template <typename T>
216
+ inline ::llvm::hash_code hash_value (const PolynomialBase<T> &arg) {
113
217
return ::llvm::hash_combine_range (arg.terms .begin (), arg.terms .end ());
114
218
}
115
219
116
- inline ::llvm::hash_code hash_value (const Monomial &arg) {
220
+ template <typename T>
221
+ inline ::llvm::hash_code hash_value (const MonomialBase<T> &arg) {
117
222
return llvm::hash_combine (::llvm::hash_value (arg.coefficient ),
118
223
::llvm::hash_value (arg.exponent));
119
224
}
120
225
121
- inline raw_ostream &operator <<(raw_ostream &os, const Polynomial &polynomial) {
226
+ template <typename T>
227
+ inline raw_ostream &operator <<(raw_ostream &os,
228
+ const PolynomialBase<T> &polynomial) {
122
229
polynomial.print (os);
123
230
return os;
124
231
}
0 commit comments