Skip to content

Commit 04614be

Browse files
committed
try constBinaryFold
1 parent b340e5c commit 04614be

File tree

3 files changed

+73
-2
lines changed

3 files changed

+73
-2
lines changed

mlir/include/mlir/Dialect/CommonFolders.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
8282
if (!elementResult)
8383
return {};
8484

85-
return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
85+
return DenseElementsAttr::get(cast<ShapedType>(resultType),
86+
llvm::ArrayRef(*elementResult));
8687
}
8788

8889
if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,23 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
8383
let hasCustomAssemblyFormat = 1;
8484
}
8585

86+
def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
87+
"TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> {
88+
let summary = "A typed variant of int_polynomial for constant folding.";
89+
let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomial":$value);
90+
let assemblyFormat = "`<` struct(params) `>`";
91+
let builders = [
92+
AttrBuilderWithInferredContext<(ins "Type":$type,
93+
"const IntPolynomial &":$value), [{
94+
return $_get(type.getContext(), type, value);
95+
}]>
96+
];
97+
let extraClassDeclaration = [{
98+
// used for constFoldBinaryOp
99+
using ValueType = ::mlir::polynomial::IntPolynomial;
100+
}];
101+
}
102+
86103
def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
87104
let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
88105
let description = [{
@@ -105,6 +122,23 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
105122
let hasCustomAssemblyFormat = 1;
106123
}
107124

125+
def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
126+
"TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> {
127+
let summary = "A typed variant of float_polynomial for constant folding.";
128+
let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomial":$value);
129+
let assemblyFormat = "`<` struct(params) `>`";
130+
let builders = [
131+
AttrBuilderWithInferredContext<(ins "Type":$type,
132+
"const FloatPolynomial &":$value), [{
133+
return $_get(type.getContext(), type, value);
134+
}]>
135+
];
136+
let extraClassDeclaration = [{
137+
// used for constFoldBinaryOp
138+
using ValueType = ::mlir::polynomial::FloatPolynomial;
139+
}];
140+
}
141+
108142
def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
109143
let summary = "An attribute specifying a polynomial ring.";
110144
let description = [{
@@ -221,6 +255,7 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
221255
%2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
222256
```
223257
}];
258+
let hasFolder = 1;
224259
}
225260

226261
def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
@@ -442,7 +477,7 @@ def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
442477
]>;
443478

444479
// Not deriving from Polynomial_Op due to need for custom assembly format
445-
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
480+
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure, ConstantLike]> {
446481
let summary = "Define a constant polynomial via an attribute.";
447482
let description = [{
448483
Example:
@@ -459,6 +494,7 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
459494
let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
460495
let results = (outs Polynomial_PolynomialType:$output);
461496
let assemblyFormat = "attr-dict `:` type($output)";
497+
let hasFolder = 1;
462498
}
463499

464500
def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88

99
#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
1010
#include "mlir/Dialect/Arith/IR/Arith.h"
11+
#include "mlir/Dialect/CommonFolders.h"
1112
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
1213
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
1314
#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
1415
#include "mlir/IR/Builders.h"
16+
#include "mlir/IR/BuiltinAttributes.h"
1517
#include "mlir/IR/BuiltinTypes.h"
1618
#include "mlir/IR/Dialect.h"
1719
#include "mlir/IR/PatternMatch.h"
@@ -21,6 +23,38 @@
2123
using namespace mlir;
2224
using namespace mlir::polynomial;
2325

26+
OpFoldResult ConstantOp::fold(ConstantOp::FoldAdaptor adaptor) {
27+
PolynomialType ty = dyn_cast<PolynomialType>(getOutput().getType());
28+
29+
if (isa<FloatPolynomialAttr>(ty.getRing().getPolynomialModulus()))
30+
return TypedFloatPolynomialAttr::get(ty, cast<FloatPolynomialAttr>(getValue()).getPolynomial());
31+
32+
assert(isa<IntPolynomialAttr>(ty.getRing().getPolynomialModulus()) &&
33+
"expected float or integer polynomial");
34+
return TypedIntPolynomialAttr::get(ty,cast<IntPolynomialAttr>(getValue()).getPolynomial());
35+
}
36+
37+
OpFoldResult AddOp::fold(AddOp::FoldAdaptor adaptor) {
38+
// Folded input attributes can either be typed_int_polynomial or
39+
// typed_float_polynomial, and those require different invocations of
40+
// constFoldBinaryOp.
41+
PolynomialType ty = dyn_cast<PolynomialType>(getLhs().getType());
42+
if (!ty) {
43+
ShapedType shapedTy = dyn_cast<ShapedType>(getLhs().getType());
44+
assert(shapedTy && "lhs must be a polynomial or a shaped type");
45+
ty = cast<PolynomialType>(shapedTy.getElementType());
46+
}
47+
48+
if (isa<FloatPolynomialAttr>(ty.getRing().getPolynomialModulus()))
49+
return constFoldBinaryOp<TypedFloatPolynomialAttr>(
50+
adaptor.getOperands(), getLhs().getType(),
51+
[](FloatPolynomial a, const FloatPolynomial &b) { return a.add(b); });
52+
53+
return constFoldBinaryOp<TypedIntPolynomialAttr>(
54+
adaptor.getOperands(), getLhs().getType(),
55+
[](IntPolynomial a, const IntPolynomial &b) { return a.add(b); });
56+
}
57+
2458
void FromTensorOp::build(OpBuilder &builder, OperationState &result,
2559
Value input, RingAttr ring) {
2660
TensorType tensorType = dyn_cast<TensorType>(input.getType());

0 commit comments

Comments
 (0)