Skip to content

Commit 91a14db

Browse files
authored
Support polynomial attributes with floating point coefficients (#91137)
In summary: - `Monomial` -> `MonomialBase` with two inheriting `IntMonomial` and `FloatMonomial` for the different coefficient types - `Polynomial` -> `PolynomialBase` with `IntPolynomial` and `FloatPolynomial` inheriting - `PolynomialAttr` -> `IntPolynomialAttr`, and new `FloatPolynomialAttr` attribute, both of which may be input to `polynomial.constant` - Refactoring common parts of attribute parsers. --------- Co-authored-by: Jeremy Kun <j2kun@users.noreply.github.com>
1 parent 7964356 commit 91a14db

File tree

8 files changed

+451
-348
lines changed

8 files changed

+451
-348
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h

Lines changed: 150 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111

1212
#include "mlir/Support/LLVM.h"
1313
#include "mlir/Support/LogicalResult.h"
14+
#include "llvm/ADT/APFloat.h"
1415
#include "llvm/ADT/APInt.h"
1516
#include "llvm/ADT/ArrayRef.h"
1617
#include "llvm/ADT/Hashing.h"
17-
#include "llvm/ADT/SmallVector.h"
18+
#include "llvm/ADT/SmallString.h"
19+
#include "llvm/ADT/Twine.h"
20+
#include "llvm/Support/raw_ostream.h"
1821

1922
namespace mlir {
2023

@@ -27,98 +30,202 @@ namespace polynomial {
2730
/// would want to specify 128-bit polynomials statically in the source code.
2831
constexpr unsigned apintBitWidth = 64;
2932

30-
/// A class representing a monomial of a single-variable polynomial with integer
31-
/// coefficients.
32-
class Monomial {
33+
template <typename CoefficientType>
34+
class MonomialBase {
3335
public:
34-
Monomial(int64_t coeff, uint64_t expo)
35-
: coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {}
36-
37-
Monomial(const APInt &coeff, const APInt &expo)
36+
MonomialBase(const CoefficientType &coeff, const APInt &expo)
3837
: coefficient(coeff), exponent(expo) {}
38+
virtual ~MonomialBase() = 0;
3939

40-
Monomial() : coefficient(apintBitWidth, 0), exponent(apintBitWidth, 0) {}
40+
const CoefficientType &getCoefficient() const { return coefficient; }
41+
CoefficientType &getMutableCoefficient() { return coefficient; }
42+
const APInt &getExponent() const { return exponent; }
43+
void setCoefficient(const CoefficientType &coeff) { coefficient = coeff; }
44+
void setExponent(const APInt &exp) { exponent = exp; }
4145

42-
bool operator==(const Monomial &other) const {
46+
bool operator==(const MonomialBase &other) const {
4347
return other.coefficient == coefficient && other.exponent == exponent;
4448
}
45-
bool operator!=(const Monomial &other) const {
49+
bool operator!=(const MonomialBase &other) const {
4650
return other.coefficient != coefficient || other.exponent != exponent;
4751
}
4852

4953
/// Monomials are ordered by exponent.
50-
bool operator<(const Monomial &other) const {
54+
bool operator<(const MonomialBase &other) const {
5155
return (exponent.ult(other.exponent));
5256
}
5357

54-
friend ::llvm::hash_code hash_value(const Monomial &arg);
58+
virtual bool isMonic() const = 0;
59+
virtual void
60+
coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
5561

56-
public:
57-
APInt coefficient;
62+
template <typename T>
63+
friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
5864

59-
// Always unsigned
65+
protected:
66+
CoefficientType coefficient;
6067
APInt exponent;
6168
};
6269

63-
/// A single-variable polynomial with integer coefficients.
64-
///
65-
/// Eg: x^1024 + x + 1
66-
///
67-
/// The symbols used as the polynomial's indeterminate don't matter, so long as
68-
/// it is used consistently throughout the polynomial.
69-
class Polynomial {
70+
/// A class representing a monomial of a single-variable polynomial with integer
71+
/// coefficients.
72+
class IntMonomial : public MonomialBase<APInt> {
7073
public:
71-
Polynomial() = delete;
74+
IntMonomial(int64_t coeff, uint64_t expo)
75+
: MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
7276

73-
explicit Polynomial(ArrayRef<Monomial> terms) : terms(terms){};
77+
IntMonomial()
78+
: MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
7479

75-
// Returns a Polynomial from a list of monomials.
76-
// Fails if two monomials have the same exponent.
77-
static FailureOr<Polynomial> fromMonomials(ArrayRef<Monomial> monomials);
80+
~IntMonomial() = default;
7881

79-
/// Returns a polynomial with coefficients given by `coeffs`. The value
80-
/// coeffs[i] is converted to a monomial with exponent i.
81-
static Polynomial fromCoefficients(ArrayRef<int64_t> coeffs);
82+
bool isMonic() const override { return coefficient == 1; }
83+
84+
void coefficientToString(llvm::SmallString<16> &coeffString) const override {
85+
coefficient.toStringSigned(coeffString);
86+
}
87+
};
88+
89+
/// A class representing a monomial of a single-variable polynomial with integer
90+
/// coefficients.
91+
class FloatMonomial : public MonomialBase<APFloat> {
92+
public:
93+
FloatMonomial(double coeff, uint64_t expo)
94+
: MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
95+
96+
FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
97+
98+
~FloatMonomial() = default;
99+
100+
bool isMonic() const override { return coefficient == APFloat(1.0); }
101+
102+
void coefficientToString(llvm::SmallString<16> &coeffString) const override {
103+
coefficient.toString(coeffString);
104+
}
105+
};
106+
107+
template <typename Monomial>
108+
class PolynomialBase {
109+
public:
110+
PolynomialBase() = delete;
111+
112+
explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms){};
82113

83114
explicit operator bool() const { return !terms.empty(); }
84-
bool operator==(const Polynomial &other) const {
115+
bool operator==(const PolynomialBase &other) const {
85116
return other.terms == terms;
86117
}
87-
bool operator!=(const Polynomial &other) const {
118+
bool operator!=(const PolynomialBase &other) const {
88119
return !(other.terms == terms);
89120
}
90121

91-
// Prints polynomial to 'os'.
92-
void print(raw_ostream &os) const;
93122
void print(raw_ostream &os, ::llvm::StringRef separator,
94-
::llvm::StringRef exponentiation) const;
123+
::llvm::StringRef exponentiation) const {
124+
bool first = true;
125+
for (const Monomial &term : getTerms()) {
126+
if (first) {
127+
first = false;
128+
} else {
129+
os << separator;
130+
}
131+
std::string coeffToPrint;
132+
if (term.isMonic() && term.getExponent().uge(1)) {
133+
coeffToPrint = "";
134+
} else {
135+
llvm::SmallString<16> coeffString;
136+
term.coefficientToString(coeffString);
137+
coeffToPrint = coeffString.str();
138+
}
139+
140+
if (term.getExponent() == 0) {
141+
os << coeffToPrint;
142+
} else if (term.getExponent() == 1) {
143+
os << coeffToPrint << "x";
144+
} else {
145+
llvm::SmallString<16> expString;
146+
term.getExponent().toStringSigned(expString);
147+
os << coeffToPrint << "x" << exponentiation << expString;
148+
}
149+
}
150+
}
151+
152+
// Prints polynomial to 'os'.
153+
void print(raw_ostream &os) const { print(os, " + ", "**"); }
154+
95155
void dump() const;
96156

97157
// Prints polynomial so that it can be used as a valid identifier
98-
std::string toIdentifier() const;
158+
std::string toIdentifier() const {
159+
std::string result;
160+
llvm::raw_string_ostream os(result);
161+
print(os, "_", "");
162+
return os.str();
163+
}
99164

100-
unsigned getDegree() const;
165+
unsigned getDegree() const {
166+
return terms.back().getExponent().getZExtValue();
167+
}
101168

102169
ArrayRef<Monomial> getTerms() const { return terms; }
103170

104-
friend ::llvm::hash_code hash_value(const Polynomial &arg);
171+
template <typename T>
172+
friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
105173

106174
private:
107175
// The monomial terms for this polynomial.
108176
SmallVector<Monomial> terms;
109177
};
110178

111-
// Make Polynomial hashable.
112-
inline ::llvm::hash_code hash_value(const Polynomial &arg) {
179+
/// A single-variable polynomial with integer coefficients.
180+
///
181+
/// Eg: x^1024 + x + 1
182+
class IntPolynomial : public PolynomialBase<IntMonomial> {
183+
public:
184+
explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
185+
186+
// Returns a Polynomial from a list of monomials.
187+
// Fails if two monomials have the same exponent.
188+
static FailureOr<IntPolynomial>
189+
fromMonomials(ArrayRef<IntMonomial> monomials);
190+
191+
/// Returns a polynomial with coefficients given by `coeffs`. The value
192+
/// coeffs[i] is converted to a monomial with exponent i.
193+
static IntPolynomial fromCoefficients(ArrayRef<int64_t> coeffs);
194+
};
195+
196+
/// A single-variable polynomial with double coefficients.
197+
///
198+
/// Eg: 1.0 x^1024 + 3.5 x + 1e-05
199+
class FloatPolynomial : public PolynomialBase<FloatMonomial> {
200+
public:
201+
explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
202+
: PolynomialBase(terms) {}
203+
204+
// Returns a Polynomial from a list of monomials.
205+
// Fails if two monomials have the same exponent.
206+
static FailureOr<FloatPolynomial>
207+
fromMonomials(ArrayRef<FloatMonomial> monomials);
208+
209+
/// Returns a polynomial with coefficients given by `coeffs`. The value
210+
/// coeffs[i] is converted to a monomial with exponent i.
211+
static FloatPolynomial fromCoefficients(ArrayRef<double> coeffs);
212+
};
213+
214+
// Make Polynomials hashable.
215+
template <typename T>
216+
inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
113217
return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
114218
}
115219

116-
inline ::llvm::hash_code hash_value(const Monomial &arg) {
220+
template <typename T>
221+
inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
117222
return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
118223
::llvm::hash_value(arg.exponent));
119224
}
120225

121-
inline raw_ostream &operator<<(raw_ostream &os, const Polynomial &polynomial) {
226+
template <typename T>
227+
inline raw_ostream &operator<<(raw_ostream &os,
228+
const PolynomialBase<T> &polynomial) {
122229
polynomial.print(os);
123230
return os;
124231
}

0 commit comments

Comments
 (0)