Skip to content

Commit aab665f

Browse files
authored
Merge pull request #10 from EnzymeAD/Clang-Format
Clang Format
2 parents 392e55e + 3c31772 commit aab665f

19 files changed

+604
-461
lines changed

enzyme/Enzyme/MLIR/Dialect/Dialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include "Dialect/EnzymeOps.cpp.inc"
2121
#define GET_TYPEDEF_CLASSES
2222
#include "Dialect/EnzymeOpsTypes.cpp.inc"
23-
//#include "Dialect/EnzymeTypes.cpp.inc"
23+
// #include "Dialect/EnzymeTypes.cpp.inc"
2424

2525
using namespace mlir;
2626
using namespace mlir::enzyme;

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
#include "llvm/ADT/TypeSwitch.h"
3232

33-
3433
#define DEBUG_TYPE "enzyme"
3534

3635
using namespace mlir;
@@ -54,9 +53,7 @@ ForwardDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
5453
return success();
5554
}
5655

57-
58-
LogicalResult
59-
DiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
56+
LogicalResult DiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
6057
// TODO: Verify that the result type is same as the type of the referenced
6158
// func.func op.
6259
auto global =

enzyme/Enzyme/MLIR/Dialect/Ops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include "Dialect/EnzymeOps.h.inc"
2222
#define GET_TYPEDEF_CLASSES
2323
#include "Dialect/EnzymeOpsTypes.h.inc"
24-
//#include "Dialect/EnzymeTypes.h.inc"
24+
// #include "Dialect/EnzymeTypes.h.inc"
2525

2626
#include "Dialect/EnzymeEnums.h.inc"
2727

enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -80,92 +80,107 @@ struct AddFOpInterface
8080
}
8181
};
8282

83-
void addToGradient(Value oldGradient, Value addedGradient, OpBuilder & builder, MGradientUtilsReverse *gutils){
83+
void addToGradient(Value oldGradient, Value addedGradient, OpBuilder &builder,
84+
MGradientUtilsReverse *gutils) {
8485
Value gradient = addedGradient;
85-
if(gutils->hasInvertPointer(oldGradient)){
86+
if (gutils->hasInvertPointer(oldGradient)) {
8687
Value operandGradient = gutils->invertPointerM(oldGradient, builder);
87-
gradient = builder.create<arith::AddFOp>(oldGradient.getLoc(), operandGradient, addedGradient);
88+
gradient = builder.create<arith::AddFOp>(oldGradient.getLoc(),
89+
operandGradient, addedGradient);
8890
}
8991
gutils->mapInvertPointer(oldGradient, gradient, builder);
9092
}
9193

92-
void defaultClearGradient(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils){
94+
void defaultClearGradient(Operation *op, OpBuilder &builder,
95+
MGradientUtilsReverse *gutils) {
9396
Value result = op->getOpResult(0);
94-
if (gutils->invertedPointersGlobal.contains(result)){
97+
if (gutils->invertedPointersGlobal.contains(result)) {
9598
FloatType floatType = result.getType().cast<FloatType>();
9699
APFloat apf(floatType.getFloatSemantics(), 0);
97100

98-
Value gradient = builder.create<arith::ConstantFloatOp>(op->getLoc(), apf, floatType);
101+
Value gradient =
102+
builder.create<arith::ConstantFloatOp>(op->getLoc(), apf, floatType);
99103
gutils->mapInvertPointer(result, gradient, builder);
100104
}
101105
}
102106

103-
struct AddFOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel<AddFOpInterfaceReverse, arith::AddFOp> {
104-
void createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector<Value> caches) const {
107+
struct AddFOpInterfaceReverse
108+
: public ReverseAutoDiffOpInterface::ExternalModel<AddFOpInterfaceReverse,
109+
arith::AddFOp> {
110+
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
111+
MGradientUtilsReverse *gutils,
112+
SmallVector<Value> caches) const {
105113
// Derivative of r = a + b -> dr = da + db
106114
auto addOp = cast<arith::AddFOp>(op);
107115

108-
if(gutils->hasInvertPointer(addOp)){
116+
if (gutils->hasInvertPointer(addOp)) {
109117
Value addedGradient = gutils->invertPointerM(addOp, builder);
110118
addToGradient(addOp.getLhs(), addedGradient, builder, gutils);
111119
addToGradient(addOp.getRhs(), addedGradient, builder, gutils);
112120
}
113121
}
114122

115-
SmallVector<Value> cacheValues(Operation *op, MGradientUtilsReverse *gutils) const {
123+
SmallVector<Value> cacheValues(Operation *op,
124+
MGradientUtilsReverse *gutils) const {
116125
return SmallVector<Value>();
117126
}
118127

119-
void createShadowValues(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const {
120-
121-
}
128+
void createShadowValues(Operation *op, OpBuilder &builder,
129+
MGradientUtilsReverse *gutils) const {}
122130
};
123131

124-
struct MulFOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel<MulFOpInterfaceReverse, arith::MulFOp> {
125-
void createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector<Value> caches) const {
132+
struct MulFOpInterfaceReverse
133+
: public ReverseAutoDiffOpInterface::ExternalModel<MulFOpInterfaceReverse,
134+
arith::MulFOp> {
135+
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
136+
MGradientUtilsReverse *gutils,
137+
SmallVector<Value> caches) const {
126138
auto mulOp = cast<arith::MulFOp>(op);
127139

128-
if(gutils->hasInvertPointer(mulOp)){
140+
if (gutils->hasInvertPointer(mulOp)) {
129141
Value own_gradient = gutils->invertPointerM(mulOp, builder);
130142
for (int i = 0; i < 2; i++) {
131143
if (!gutils->isConstantValue(mulOp.getOperand(i))) {
132144
Value cache = caches[i];
133145
Value retrievedValue = gutils->popCache(cache, builder);
134-
Value addedGradient = builder.create<arith::MulFOp>(mulOp.getLoc(), own_gradient, retrievedValue);
135-
146+
Value addedGradient = builder.create<arith::MulFOp>(
147+
mulOp.getLoc(), own_gradient, retrievedValue);
148+
136149
addToGradient(mulOp.getOperand(i), addedGradient, builder, gutils);
137150
}
138151
}
139152
}
140153
}
141154

142-
SmallVector<Value> cacheValues(Operation *op, MGradientUtilsReverse *gutils) const {
155+
SmallVector<Value> cacheValues(Operation *op,
156+
MGradientUtilsReverse *gutils) const {
143157
auto mulOp = cast<arith::MulFOp>(op);
144-
if(gutils->hasInvertPointer(mulOp)){
158+
if (gutils->hasInvertPointer(mulOp)) {
145159
OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
146160
SmallVector<Value> caches;
147161
for (int i = 0; i < 2; i++) {
148-
Value otherOperand = mulOp.getOperand((i+1)%2);
149-
Value cache = gutils->initAndPushCache(gutils->getNewFromOriginal(otherOperand), cacheBuilder);
162+
Value otherOperand = mulOp.getOperand((i + 1) % 2);
163+
Value cache = gutils->initAndPushCache(
164+
gutils->getNewFromOriginal(otherOperand), cacheBuilder);
150165
caches.push_back(cache);
151166
}
152167
return caches;
153168
}
154169
return SmallVector<Value>();
155170
}
156171

157-
void createShadowValues(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const {
158-
159-
}
172+
void createShadowValues(Operation *op, OpBuilder &builder,
173+
MGradientUtilsReverse *gutils) const {}
160174
};
161175

162-
}
176+
} // namespace
163177

164-
void mlir::enzyme::registerArithDialectAutoDiffInterface(DialectRegistry &registry) {
178+
void mlir::enzyme::registerArithDialectAutoDiffInterface(
179+
DialectRegistry &registry) {
165180
registry.addExtension(+[](MLIRContext *context, arith::ArithDialect *) {
166181
arith::AddFOp::attachInterface<AddFOpInterfaceReverse>(*context);
167182
arith::MulFOp::attachInterface<MulFOpInterfaceReverse>(*context);
168-
183+
169184
arith::AddFOp::attachInterface<AddFOpInterface>(*context);
170185
arith::MulFOp::attachInterface<MulFOpInterface>(*context);
171186
});

enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class FloatTypeInterface
3434
loc, APFloat(fltType.getFloatSemantics(), 0), fltType);
3535
}
3636

37-
Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a, Value b) const {
37+
Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
38+
Value b) const {
3839
return builder.create<arith::AddFOp>(loc, a, b);
3940
}
4041

@@ -43,9 +44,7 @@ class FloatTypeInterface
4344
return self;
4445
}
4546

46-
bool requiresShadow(Type self) const{
47-
return false;
48-
}
47+
bool requiresShadow(Type self) const { return false; }
4948
};
5049
} // namespace
5150

enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ class PointerTypeInterface
7777
return builder.create<LLVM::NullOp>(loc, self);
7878
}
7979

80-
Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a, Value b) const {
80+
Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
81+
Value b) const {
8182
llvm_unreachable("TODO");
8283
}
8384

@@ -86,9 +87,7 @@ class PointerTypeInterface
8687
return self;
8788
}
8889

89-
bool requiresShadow(Type self) const{
90-
return true;
91-
}
90+
bool requiresShadow(Type self) const { return true; }
9291
};
9392
} // namespace
9493

0 commit comments

Comments
 (0)