Skip to content

[mlir][polynomial] implement add for polynomial data structure #92169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 65 additions & 18 deletions mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace polynomial {
/// would want to specify 128-bit polynomials statically in the source code.
constexpr unsigned apintBitWidth = 64;

template <typename CoefficientType>
template <class Derived, typename CoefficientType>
class MonomialBase {
public:
MonomialBase(const CoefficientType &coeff, const APInt &expo)
Expand All @@ -55,12 +55,21 @@ class MonomialBase {
return (exponent.ult(other.exponent));
}

Derived add(const Derived &other) {
assert(exponent == other.exponent);
CoefficientType newCoeff = coefficient + other.coefficient;
Derived result;
result.setCoefficient(newCoeff);
result.setExponent(exponent);
return result;
}

virtual bool isMonic() const = 0;
virtual void
coefficientToString(llvm::SmallString<16> &coeffString) const = 0;

template <typename T>
friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
template <class D, typename T>
friend ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg);

protected:
CoefficientType coefficient;
Expand All @@ -69,15 +78,15 @@ class MonomialBase {

/// A class representing a monomial of a single-variable polynomial with integer
/// coefficients.
class IntMonomial : public MonomialBase<APInt> {
class IntMonomial : public MonomialBase<IntMonomial, APInt> {
public:
IntMonomial(int64_t coeff, uint64_t expo)
: MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}

IntMonomial()
: MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}

~IntMonomial() = default;
~IntMonomial() override = default;

bool isMonic() const override { return coefficient == 1; }

Expand All @@ -88,14 +97,14 @@ class IntMonomial : public MonomialBase<APInt> {

/// A class representing a monomial of a single-variable polynomial with integer
/// coefficients.
class FloatMonomial : public MonomialBase<APFloat> {
class FloatMonomial : public MonomialBase<FloatMonomial, APFloat> {
public:
FloatMonomial(double coeff, uint64_t expo)
: MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}

FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}

~FloatMonomial() = default;
~FloatMonomial() override = default;

bool isMonic() const override { return coefficient == APFloat(1.0); }

Expand All @@ -104,7 +113,7 @@ class FloatMonomial : public MonomialBase<APFloat> {
}
};

template <typename Monomial>
template <class Derived, typename Monomial>
class PolynomialBase {
public:
PolynomialBase() = delete;
Expand Down Expand Up @@ -149,6 +158,44 @@ class PolynomialBase {
}
}

Derived add(const Derived &other) {
SmallVector<Monomial> newTerms;
auto it1 = terms.begin();
auto it2 = other.terms.begin();
while (it1 != terms.end() || it2 != other.terms.end()) {
if (it1 == terms.end()) {
newTerms.emplace_back(*it2);
it2++;
continue;
}

if (it2 == other.terms.end()) {
newTerms.emplace_back(*it1);
it1++;
continue;
}

while (it1->getExponent().ult(it2->getExponent())) {
newTerms.emplace_back(*it1);
it1++;
if (it1 == terms.end())
break;
}

while (it2->getExponent().ult(it1->getExponent())) {
newTerms.emplace_back(*it2);
it2++;
if (it2 == terms.end())
break;
}

newTerms.emplace_back(it1->add(*it2));
it1++;
it2++;
}
return Derived(newTerms);
}

// Prints polynomial to 'os'.
void print(raw_ostream &os) const { print(os, " + ", "**"); }

Expand All @@ -168,8 +215,8 @@ class PolynomialBase {

ArrayRef<Monomial> getTerms() const { return terms; }

template <typename T>
friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
template <class D, typename T>
friend ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg);

private:
// The monomial terms for this polynomial.
Expand All @@ -179,7 +226,7 @@ class PolynomialBase {
/// A single-variable polynomial with integer coefficients.
///
/// Eg: x^1024 + x + 1
class IntPolynomial : public PolynomialBase<IntMonomial> {
class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial> {
public:
explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}

Expand All @@ -196,7 +243,7 @@ class IntPolynomial : public PolynomialBase<IntMonomial> {
/// A single-variable polynomial with double coefficients.
///
/// Eg: 1.0 x^1024 + 3.5 x + 1e-05
class FloatPolynomial : public PolynomialBase<FloatMonomial> {
class FloatPolynomial : public PolynomialBase<FloatPolynomial, FloatMonomial> {
public:
explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
: PolynomialBase(terms) {}
Expand All @@ -212,20 +259,20 @@ class FloatPolynomial : public PolynomialBase<FloatMonomial> {
};

// Make Polynomials hashable.
template <typename T>
inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
template <class D, typename T>
inline ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg) {
return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
}

template <typename T>
inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
template <class D, typename T>
inline ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg) {
return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
::llvm::hash_value(arg.exponent));
}

template <typename T>
template <class D, typename T>
inline raw_ostream &operator<<(raw_ostream &os,
const PolynomialBase<T> &polynomial) {
const PolynomialBase<D, T> &polynomial) {
polynomial.print(os);
return os;
}
Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_subdirectory(Index)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
add_subdirectory(OpenACC)
add_subdirectory(Polynomial)
add_subdirectory(SCF)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
Expand Down
8 changes: 8 additions & 0 deletions mlir/unittests/Dialect/Polynomial/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
add_mlir_unittest(MLIRPolynomialTests
PolynomialMathTest.cpp
)
target_link_libraries(MLIRPolynomialTests
PRIVATE
MLIRIR
MLIRPolynomialDialect
)
44 changes: 44 additions & 0 deletions mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//===- PolynomialMathTest.cpp - Polynomial math Tests ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
#include "gtest/gtest.h"

using namespace mlir;
using namespace mlir::polynomial;

TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) {
IntPolynomial x = IntPolynomial::fromCoefficients({1, 2, 3});
IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 7});
EXPECT_EQ(expected, x.add(y));
}

TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) {
IntMonomial term2t = IntMonomial(2, 1);
IntPolynomial x = IntPolynomial::fromMonomials({term2t}).value();
IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
IntPolynomial expected = IntPolynomial::fromCoefficients({2, 5, 4});
EXPECT_EQ(expected, x.add(y));
EXPECT_EQ(expected, y.add(x));
}

TEST(AddTest, checkSameDegreeAdditionOfFloatPolynomial) {
FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5, 3.5});
FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 8});
EXPECT_EQ(expected, x.add(y));
}

TEST(AddTest, checkDifferentDegreeAdditionOfFloatPolynomial) {
FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5});
FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 4.5});
EXPECT_EQ(expected, x.add(y));
EXPECT_EQ(expected, y.add(x));
}
Loading