22#include < cstdint>
33#include < llvm/ADT/APFloat.h>
44#include < llvm/ADT/APInt.h>
5+ #include < llvm/ADT/APSInt.h>
56#include < llvm/ADT/ArrayRef.h>
7+ #include < llvm/ADT/TypeSwitch.h>
68#include < llvm/Support/Casting.h>
79#include < llvm/Support/ErrorHandling.h>
810#include < mlir/Dialect/PDL/IR/Builtins.h>
@@ -56,52 +58,200 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
5658 return rewriter.getArrayAttr (values);
5759}
5860
59- LogicalResult add (mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
60- llvm::ArrayRef<mlir::PDLValue> args) {
61- assert (args.size () == 2 && " Expected 2 arguments" );
61+ template <UnaryOpKind T>
62+ LogicalResult unaryOp (PatternRewriter &rewriter, PDLResultList &results,
63+ ArrayRef<PDLValue> args) {
64+ assert (args.size () == 1 && " Expected one operand for unary operation" );
65+ auto operandAttr = args[0 ].cast <Attribute>();
66+
67+ if (auto operandIntAttr = dyn_cast_or_null<IntegerAttr>(operandAttr)) {
68+ auto integerType = cast<IntegerType>(operandIntAttr.getType ());
69+ auto bitWidth = integerType.getIntOrFloatBitWidth ();
70+
71+ if constexpr (T == UnaryOpKind::exp2) {
72+ uint64_t resultVal =
73+ integerType.isUnsigned () || integerType.isSignless ()
74+ ? std::pow (2 , operandIntAttr.getValue ().getZExtValue ())
75+ : std::pow (2 , operandIntAttr.getValue ().getSExtValue ());
76+
77+ APInt resultInt (bitWidth, resultVal, integerType.isSigned ());
78+
79+ bool isOverflow = integerType.isSigned ()
80+ ? resultInt.slt (operandIntAttr.getValue ())
81+ : resultInt.ult (operandIntAttr.getValue ());
82+
83+ if (isOverflow)
84+ return failure ();
85+
86+ results.push_back (rewriter.getIntegerAttr (integerType, resultInt));
87+ } else if constexpr (T == UnaryOpKind::log2) {
88+ auto getIntegerAsAttr = [&](const APSInt &value) {
89+ int32_t log2Value = value.exactLogBase2 ();
90+ assert (log2Value >= 0 &&
91+ " log2 of an integer is expected to return an exact integer." );
92+ return rewriter.getIntegerAttr (
93+ integerType,
94+ APSInt (APInt (bitWidth, log2Value), integerType.isUnsigned ()));
95+ };
96+ // for log2 we treat signless integer as signed
97+ if (integerType.isSignless ())
98+ results.push_back (
99+ getIntegerAsAttr (APSInt (operandIntAttr.getValue (), false )));
100+ else
101+ results.push_back (getIntegerAsAttr (operandIntAttr.getAPSInt ()));
102+ } else {
103+ llvm::llvm_unreachable_internal (
104+ " encountered an unsupported unary operator" );
105+ return failure ();
106+ }
107+ return success ();
108+ }
109+
110+ if (auto operandFloatAttr = dyn_cast_or_null<FloatAttr>(operandAttr)) {
111+ // auto floatType = operandFloatAttr.getType();
112+
113+ if constexpr (T == UnaryOpKind::exp2) {
114+ // auto maxVal = APFloat::getLargest(llvm::APFloat::IEEEhalf());
115+ // auto minVal = APFloat::getSmallest(llvm::APFloat::IEEEhalf());
116+
117+ auto type = operandFloatAttr.getType ();
118+
119+ return TypeSwitch<Type, LogicalResult>(type)
120+ .template Case <Float64Type>([&results, &rewriter,
121+ &operandFloatAttr](auto floatType) {
122+ APFloat resultAPFloat (
123+ std::exp2 (operandFloatAttr.getValue ().convertToDouble ()));
124+
125+ // check overflow
126+ if (!resultAPFloat.isNormal ())
127+ return failure ();
128+
129+ results.push_back (rewriter.getFloatAttr (floatType, resultAPFloat));
130+ return success ();
131+ })
132+ .template Case <Float32Type, Float16Type, BFloat16Type>(
133+ [&results, &rewriter, &operandFloatAttr](auto floatType) {
134+ APFloat resultAPFloat (
135+ std::exp2 (operandFloatAttr.getValue ().convertToFloat ()));
136+
137+ // check overflow and underflow
138+ // If overflow happens, resultAPFloat is inf
139+ // If underflow happens, resultAPFloat is 0
140+ if (!resultAPFloat.isNormal ())
141+ return failure ();
142+
143+ results.push_back (
144+ rewriter.getFloatAttr (floatType, resultAPFloat));
145+ return success ();
146+ })
147+ .Default ([](Type /* type*/ ) { return failure (); });
148+ } else if constexpr (T == UnaryOpKind::log2) {
149+ auto minF32 = APFloat::getSmallest (llvm::APFloat::IEEEsingle ());
150+
151+ APFloat resultFloat ((float )operandFloatAttr.getValue ().getExactLog2 ());
152+ results.push_back (
153+ rewriter.getFloatAttr (operandFloatAttr.getType (), resultFloat));
154+ }
155+ return success ();
156+ }
157+ return failure ();
158+ }
159+
160+ template <BinaryOpKind T>
161+ LogicalResult binaryOp (PatternRewriter &rewriter, PDLResultList &results,
162+ llvm::ArrayRef<PDLValue> args) {
163+ assert (args.size () == 2 && " Expected two operands for binary operation" );
62164 auto lhsAttr = args[0 ].cast <Attribute>();
63165 auto rhsAttr = args[1 ].cast <Attribute>();
64166
65- // Integer
66167 if (auto lhsIntAttr = dyn_cast_or_null<IntegerAttr>(lhsAttr)) {
67168 auto rhsIntAttr = dyn_cast_or_null<IntegerAttr>(rhsAttr);
68- if (!rhsIntAttr || lhsIntAttr.getType () != rhsIntAttr.getType ())
169+ if (!rhsIntAttr || lhsIntAttr.getType () != rhsIntAttr.getType ()) {
69170 return failure ();
171+ }
70172
71173 auto integerType = lhsIntAttr.getType ();
72-
73- bool isOverflow;
74- llvm::APInt resultAPInt;
75- if (integerType.isUnsignedInteger () || integerType.isSignlessInteger ()) {
76- resultAPInt =
77- lhsIntAttr.getValue ().uadd_ov (rhsIntAttr.getValue (), isOverflow);
174+ APInt resultAPInt;
175+ bool isOverflow = false ;
176+ if constexpr (T == BinaryOpKind::add) {
177+ if (integerType.isSignlessInteger () || integerType.isUnsignedInteger ()) {
178+ resultAPInt =
179+ lhsIntAttr.getValue ().uadd_ov (rhsIntAttr.getValue (), isOverflow);
180+ } else {
181+ resultAPInt =
182+ lhsIntAttr.getValue ().sadd_ov (rhsIntAttr.getValue (), isOverflow);
183+ }
184+ } else if constexpr (T == BinaryOpKind::sub) {
185+ if (integerType.isSignlessInteger () || integerType.isUnsignedInteger ()) {
186+ resultAPInt =
187+ lhsIntAttr.getValue ().usub_ov (rhsIntAttr.getValue (), isOverflow);
188+ } else {
189+ resultAPInt =
190+ lhsIntAttr.getValue ().ssub_ov (rhsIntAttr.getValue (), isOverflow);
191+ }
192+ } else if constexpr (T == BinaryOpKind::mul) {
193+ if (integerType.isSignlessInteger () || integerType.isUnsignedInteger ()) {
194+ resultAPInt =
195+ lhsIntAttr.getValue ().umul_ov (rhsIntAttr.getValue (), isOverflow);
196+ } else {
197+ resultAPInt =
198+ lhsIntAttr.getValue ().smul_ov (rhsIntAttr.getValue (), isOverflow);
199+ }
200+ } else if constexpr (T == BinaryOpKind::div) {
201+ if (integerType.isSignlessInteger () || integerType.isUnsignedInteger ()) {
202+ resultAPInt = lhsIntAttr.getValue ().udiv (rhsIntAttr.getValue ());
203+ } else {
204+ resultAPInt =
205+ lhsIntAttr.getValue ().sdiv_ov (rhsIntAttr.getValue (), isOverflow);
206+ }
207+ } else if constexpr (T == BinaryOpKind::mod) {
208+ if (integerType.isSignlessInteger () || integerType.isUnsignedInteger ()) {
209+ resultAPInt = lhsIntAttr.getValue ().urem (rhsIntAttr.getValue ());
210+ } else {
211+ resultAPInt = lhsIntAttr.getValue ().srem (rhsIntAttr.getValue ());
212+ }
78213 } else {
79- resultAPInt =
80- lhsIntAttr.getValue ().sadd_ov (rhsIntAttr.getValue (), isOverflow);
214+ assert (false && " Unsupported binary operator" );
81215 }
82216
83- if (isOverflow) {
217+ if (isOverflow)
84218 return failure ();
85- }
86219
87220 results.push_back (rewriter.getIntegerAttr (integerType, resultAPInt));
88221 return success ();
89222 }
90223
91- // Float
92224 if (auto lhsFloatAttr = dyn_cast_or_null<FloatAttr>(lhsAttr)) {
93225 auto rhsFloatAttr = dyn_cast_or_null<FloatAttr>(rhsAttr);
94- if (!rhsFloatAttr || lhsFloatAttr.getType () != rhsFloatAttr.getType ())
226+ if (!rhsFloatAttr || lhsFloatAttr.getType () != rhsFloatAttr.getType ()) {
95227 return failure ();
228+ }
96229
97230 APFloat lhsVal = lhsFloatAttr.getValue ();
98231 APFloat rhsVal = rhsFloatAttr.getValue ();
99232 APFloat resultVal (lhsVal);
100233 auto floatType = lhsFloatAttr.getType ();
101234
102- bool isOverflow =
103- resultVal.add (rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
104- if (isOverflow) {
235+ APFloat::opStatus operationStatus;
236+ if constexpr (T == BinaryOpKind::add) {
237+ operationStatus =
238+ resultVal.add (rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
239+ } else if constexpr (T == BinaryOpKind::sub) {
240+ operationStatus =
241+ resultVal.subtract (rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
242+ } else if constexpr (T == BinaryOpKind::mul) {
243+ operationStatus =
244+ resultVal.multiply (rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
245+ } else if constexpr (T == BinaryOpKind::div) {
246+ operationStatus =
247+ resultVal.divide (rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
248+ } else if constexpr (T == BinaryOpKind::mod) {
249+ operationStatus = resultVal.mod (rhsVal);
250+ } else {
251+ assert (false && " Unsupported binary operator" );
252+ }
253+
254+ if (operationStatus != APFloat::opOK) {
105255 return failure ();
106256 }
107257
@@ -110,6 +260,41 @@ LogicalResult add(mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
110260 }
111261 return failure ();
112262}
263+
264+ LogicalResult add (mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
265+ llvm::ArrayRef<mlir::PDLValue> args) {
266+ return binaryOp<BinaryOpKind::add>(rewriter, results, args);
267+ }
268+
269+ LogicalResult sub (mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
270+ llvm::ArrayRef<mlir::PDLValue> args) {
271+ return binaryOp<BinaryOpKind::sub>(rewriter, results, args);
272+ }
273+
274+ LogicalResult mul (PatternRewriter &rewriter, PDLResultList &results,
275+ llvm::ArrayRef<PDLValue> args) {
276+ return binaryOp<BinaryOpKind::mul>(rewriter, results, args);
277+ }
278+
279+ LogicalResult div (PatternRewriter &rewriter, PDLResultList &results,
280+ llvm::ArrayRef<PDLValue> args) {
281+ return binaryOp<BinaryOpKind::div>(rewriter, results, args);
282+ }
283+
284+ LogicalResult mod (PatternRewriter &rewriter, PDLResultList &results,
285+ ArrayRef<PDLValue> args) {
286+ return binaryOp<BinaryOpKind::mod>(rewriter, results, args);
287+ }
288+
289+ LogicalResult exp2 (PatternRewriter &rewriter, PDLResultList &results,
290+ llvm::ArrayRef<PDLValue> args) {
291+ return unaryOp<UnaryOpKind::exp2>(rewriter, results, args);
292+ }
293+
294+ LogicalResult log2 (PatternRewriter &rewriter, PDLResultList &results,
295+ llvm::ArrayRef<PDLValue> args) {
296+ return unaryOp<UnaryOpKind::log2>(rewriter, results, args);
297+ }
113298} // namespace builtin
114299
115300void registerBuiltins (PDLPatternModule &pdlPattern) {
@@ -128,6 +313,27 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
128313 pdlPattern.registerConstraintFunctionWithResults (
129314 " __builtin_addEntryToDictionaryAttr_constraint" ,
130315 addEntryToDictionaryAttr);
131- pdlPattern.registerConstraintFunctionWithResults (" __builtin_add" , add);
316+ pdlPattern.registerRewriteFunction (" __builtin_mulRewrite" , mul);
317+ pdlPattern.registerRewriteFunction (" __builtin_divRewrite" , div);
318+ pdlPattern.registerRewriteFunction (" __builtin_modRewrite" , mod);
319+ pdlPattern.registerRewriteFunction (" __builtin_addRewrite" , add);
320+ pdlPattern.registerRewriteFunction (" __builtin_subRewrite" , sub);
321+ pdlPattern.registerRewriteFunction (" __builtin_log2Rewrite" , log2);
322+ pdlPattern.registerRewriteFunction (" __builtin_exp2Rewrite" , exp2);
323+
324+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_mulConstraint" ,
325+ mul);
326+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_divConstraint" ,
327+ div);
328+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_modConstraint" ,
329+ mod);
330+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_addConstraint" ,
331+ add);
332+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_subConstraint" ,
333+ sub);
334+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_log2Constraint" ,
335+ log2);
336+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_exp2Constraint" ,
337+ exp2);
132338}
133- } // namespace mlir::pdl
339+ } // namespace mlir::pdl
0 commit comments