Skip to content

Commit 375a641

Browse files
authored
Return bool from isZero (EnzymeAD#2403)
* Return bool from isZero * fmt * fix
1 parent d48bb60 commit 375a641

File tree

6 files changed

+50
-72
lines changed

6 files changed

+50
-72
lines changed

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
230230
auto ET = inp.getType();
231231
auto ETintf = dyn_cast<AutoDiffTypeInterface>(ET);
232232

233-
if (ETintf && !isMutable(ET) && ETintf.isZero(inp).succeeded()) {
233+
if (ETintf && !isMutable(ET) && ETintf.isZero(inp)) {
234234
// skip and promote to const
235235
auto new_const = mlir::enzyme::ActivityAttr::get(
236236
rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);
@@ -254,7 +254,7 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
254254
auto ET = inp.getType();
255255
auto ETintf = dyn_cast<AutoDiffTypeInterface>(ET);
256256

257-
if (ETintf && !isMutable(ET) && ETintf.isZero(inp).succeeded()) {
257+
if (ETintf && !isMutable(ET) && ETintf.isZero(inp)) {
258258
// skip and promote to const
259259
auto new_const = mlir::enzyme::ActivityAttr::get(
260260
rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);

enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp

Lines changed: 34 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,12 @@ class FloatTypeInterface : public AutoDiffTypeInterface::ExternalModel<
7272
return failure();
7373
}
7474

75-
LogicalResult isZero(Type self, Value val) const {
76-
if (matchPattern(val, m_AnyZeroFloat())) {
77-
return success();
78-
}
79-
return failure();
75+
bool isZero(Type self, Value val) const {
76+
return matchPattern(val, m_AnyZeroFloat());
8077
}
8178

82-
LogicalResult isZeroAttr(Type self, Attribute attr) const {
83-
if (matchPattern(attr, m_AnyZeroFloat())) {
84-
return success();
85-
}
86-
return failure();
79+
bool isZeroAttr(Type self, Attribute attr) const {
80+
return matchPattern(attr, m_AnyZeroFloat());
8781
}
8882

8983
int64_t getApproxSize(Type self) const {
@@ -150,39 +144,39 @@ class TensorTypeInterface
150144
return failure();
151145
}
152146

153-
LogicalResult isZero(Type self, Value val) const {
147+
bool isZero(Type self, Value val) const {
154148
auto tenType = cast<TensorType>(self);
155149
auto ET = tenType.getElementType();
156150
DenseElementsAttr eAttr;
157151

158152
if (!matchPattern(val, m_Constant(&eAttr)))
159-
return failure();
153+
return false;
160154

161-
if (eAttr.isSplat()) {
162-
// recurse on the individual element type
163-
auto splatVal = eAttr.getSplatValue<Attribute>();
164-
auto ADET = dyn_cast<AutoDiffTypeInterface>(ET);
165-
return ADET ? ADET.isZeroAttr(splatVal) : failure();
166-
}
167-
168-
return failure();
155+
if (!eAttr.isSplat())
156+
return false;
157+
// recurse on the individual element type
158+
auto splatVal = eAttr.getSplatValue<Attribute>();
159+
auto ADET = dyn_cast<AutoDiffTypeInterface>(ET);
160+
return ADET && ADET.isZeroAttr(splatVal);
169161
}
170-
LogicalResult isZeroAttr(Type self, Attribute attr) const {
162+
163+
bool isZeroAttr(Type self, Attribute attr) const {
171164
auto eAttr = dyn_cast<DenseElementsAttr>(attr);
172165
if (!eAttr)
173-
return failure();
166+
return false;
167+
168+
if (!eAttr.isSplat())
169+
return false;
174170

175171
auto ET = eAttr.getType().getElementType();
176172
auto ADET = dyn_cast<AutoDiffTypeInterface>(ET);
177173

178174
if (!ADET)
179-
return failure();
175+
return false;
180176

181-
if (eAttr.isSplat()) {
182-
return ADET.isZeroAttr(eAttr.getSplatValue<Attribute>());
183-
} else
184-
return failure();
177+
return ADET.isZeroAttr(eAttr.getSplatValue<Attribute>());
185178
}
179+
186180
int64_t getApproxSize(Type self) const {
187181
auto tenType = cast<TensorType>(self);
188182
auto elType = cast<AutoDiffTypeInterface>(tenType.getElementType());
@@ -228,19 +222,14 @@ class IntegerTypeInterface
228222
return failure();
229223
}
230224

231-
LogicalResult isZero(Type self, Value val) const {
232-
if (matchPattern(val, m_Zero())) {
233-
return success();
234-
}
235-
return failure();
225+
bool isZero(Type self, Value val) const {
226+
return matchPattern(val, m_Zero());
236227
}
237228

238-
LogicalResult isZeroAttr(Type self, Attribute attr) const {
239-
if (matchPattern(attr, m_Zero())) {
240-
return success();
241-
}
242-
return failure();
229+
bool isZeroAttr(Type self, Attribute attr) const {
230+
return matchPattern(attr, m_Zero());
243231
}
232+
244233
int64_t getApproxSize(Type self) const {
245234
return self.getIntOrFloatBitWidth();
246235
}
@@ -278,36 +267,36 @@ class ComplexTypeInterface
278267
return failure();
279268
}
280269

281-
LogicalResult isZero(Type self, Value val) const {
270+
bool isZero(Type self, Value val) const {
282271
ArrayAttr arrayAttr;
283272

284273
if (!matchPattern(val, m_Constant(&arrayAttr))) {
285-
return failure();
274+
return false;
286275
}
287276
// reuse attributr check
288277
return this->isZeroAttr(self, arrayAttr);
289278
}
290279

291-
LogicalResult isZeroAttr(Type self, Attribute attr) const {
280+
bool isZeroAttr(Type self, Attribute attr) const {
292281
auto arrayAttr = dyn_cast<ArrayAttr>(attr);
293282
if (!arrayAttr || arrayAttr.size() != 2)
294-
return failure();
283+
return false;
295284

296285
// get the element type
297286
auto compType = cast<ComplexType>(self);
298287
auto elType = compType.getElementType();
299288
auto eltIntf = dyn_cast<AutoDiffTypeInterface>(elType);
300289

301290
if (!eltIntf)
302-
return failure();
291+
return false;
303292

304293
// recurse and accumulate info per attribute
305-
bool acc = true;
306294
for (auto eltAttr : arrayAttr) {
307-
acc = acc && mlir::succeeded(eltIntf.isZeroAttr(eltAttr));
295+
if (!eltIntf.isZeroAttr(eltAttr))
296+
return false;
308297
}
309298

310-
return success(acc);
299+
return true;
311300
}
312301

313302
int64_t getApproxSize(Type self) const {

enzyme/Enzyme/MLIR/Implementations/ComplexAutoDiffOpInterfaceImpl.cpp

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,32 +27,21 @@ using namespace mlir::enzyme;
2727
namespace {
2828
#include "Implementations/ComplexDerivatives.inc"
2929

30-
bool isZero(mlir::Value v) {
31-
ArrayAttr lhs;
32-
matchPattern(v, m_Constant(&lhs));
33-
if (lhs) {
34-
for (auto e : lhs) {
35-
if (!cast<FloatAttr>(e).getValue().isZero())
36-
return false;
37-
}
38-
return true;
39-
}
40-
return false;
41-
}
42-
4330
struct ComplexAddSimplifyMathInterface
4431
: public MathSimplifyInterface::ExternalModel<
4532
ComplexAddSimplifyMathInterface, complex::AddOp> {
4633
mlir::LogicalResult simplifyMath(Operation *src,
4734
PatternRewriter &rewriter) const {
4835
auto op = cast<complex::AddOp>(src);
4936

50-
if (isZero(op.getLhs())) {
37+
auto ATy = cast<AutoDiffTypeInterface>(op.getLhs().getType());
38+
39+
if (ATy.isZero(op.getLhs())) {
5140
rewriter.replaceOp(op, op.getRhs());
5241
return success();
5342
}
5443

55-
if (isZero(op.getRhs())) {
44+
if (ATy.isZero(op.getRhs())) {
5645
rewriter.replaceOp(op, op.getLhs());
5746
return success();
5847
}
@@ -68,12 +57,14 @@ struct ComplexSubSimplifyMathInterface
6857
PatternRewriter &rewriter) const {
6958
auto op = cast<complex::SubOp>(src);
7059

71-
if (isZero(op.getRhs())) {
60+
auto ATy = cast<AutoDiffTypeInterface>(op.getLhs().getType());
61+
62+
if (ATy.isZero(op.getRhs())) {
7263
rewriter.replaceOp(op, op.getLhs());
7364
return success();
7465
}
7566

76-
if (isZero(op.getLhs())) {
67+
if (ATy.isZero(op.getLhs())) {
7768
rewriter.replaceOpWithNewOp<complex::NegOp>(op, op.getRhs());
7869
return success();
7970
}

enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,8 @@ class PointerTypeInterface
6969
return failure();
7070
}
7171

72-
LogicalResult isZero(Type self, Value val) const { return failure(); }
73-
LogicalResult isZeroAttr(Type self, Attribute attr) const {
74-
return failure();
75-
}
72+
bool isZero(Type self, Value val) const { return false; }
73+
bool isZeroAttr(Type self, Attribute attr) const { return false; }
7674
};
7775
} // namespace
7876

enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ class MemRefTypeInterface
235235
return success();
236236
}
237237

238-
LogicalResult isZero(Type self, Value val) const { return failure(); }
239-
LogicalResult isZeroAttr(Type self, Attribute val) const { return failure(); }
238+
bool isZero(Type self, Value val) const { return false; }
239+
bool isZeroAttr(Type self, Attribute val) const { return false; }
240240
};
241241
} // namespace
242242

enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,15 @@ def AutoDiffTypeInterface : TypeInterface<"AutoDiffTypeInterface"> {
5353
/*desc=*/[{
5454
Check if the value with the given type is 0
5555
}],
56-
/*retTy=*/"::mlir::LogicalResult",
56+
/*retTy=*/"bool",
5757
/*methodName=*/"isZero",
5858
/*args=*/(ins "::mlir::Value":$val)
5959
>,
6060
InterfaceMethod<
6161
/*desc=*/[{
6262
Check if the mlir Attribute with the given type is 0.
6363
}],
64-
/*retTy=*/"::mlir::LogicalResult",
64+
/*retTy=*/"bool",
6565
/*methodName=*/"isZeroAttr",
6666
/*args=*/(ins "::mlir::Attribute":$attr)
6767
>,

0 commit comments

Comments
 (0)