@@ -80,92 +80,107 @@ struct AddFOpInterface
8080 }
8181};
8282
83- void addToGradient (Value oldGradient, Value addedGradient, OpBuilder & builder, MGradientUtilsReverse *gutils){
83+ void addToGradient (Value oldGradient, Value addedGradient, OpBuilder &builder,
84+ MGradientUtilsReverse *gutils) {
8485 Value gradient = addedGradient;
85- if (gutils->hasInvertPointer (oldGradient)){
86+ if (gutils->hasInvertPointer (oldGradient)) {
8687 Value operandGradient = gutils->invertPointerM (oldGradient, builder);
87- gradient = builder.create <arith::AddFOp>(oldGradient.getLoc (), operandGradient, addedGradient);
88+ gradient = builder.create <arith::AddFOp>(oldGradient.getLoc (),
89+ operandGradient, addedGradient);
8890 }
8991 gutils->mapInvertPointer (oldGradient, gradient, builder);
9092}
9193
92- void defaultClearGradient (Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils){
94+ void defaultClearGradient (Operation *op, OpBuilder &builder,
95+ MGradientUtilsReverse *gutils) {
9396 Value result = op->getOpResult (0 );
94- if (gutils->invertedPointersGlobal .contains (result)){
97+ if (gutils->invertedPointersGlobal .contains (result)) {
9598 FloatType floatType = result.getType ().cast <FloatType>();
9699 APFloat apf (floatType.getFloatSemantics (), 0 );
97100
98- Value gradient = builder.create <arith::ConstantFloatOp>(op->getLoc (), apf, floatType);
101+ Value gradient =
102+ builder.create <arith::ConstantFloatOp>(op->getLoc (), apf, floatType);
99103 gutils->mapInvertPointer (result, gradient, builder);
100104 }
101105}
102106
103- struct AddFOpInterfaceReverse : public ReverseAutoDiffOpInterface ::ExternalModel<AddFOpInterfaceReverse, arith::AddFOp> {
104- void createReverseModeAdjoint (Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector<Value> caches) const {
107+ struct AddFOpInterfaceReverse
108+ : public ReverseAutoDiffOpInterface::ExternalModel<AddFOpInterfaceReverse,
109+ arith::AddFOp> {
110+ void createReverseModeAdjoint (Operation *op, OpBuilder &builder,
111+ MGradientUtilsReverse *gutils,
112+ SmallVector<Value> caches) const {
105113 // Derivative of r = a + b -> dr = da + db
106114 auto addOp = cast<arith::AddFOp>(op);
107115
108- if (gutils->hasInvertPointer (addOp)){
116+ if (gutils->hasInvertPointer (addOp)) {
109117 Value addedGradient = gutils->invertPointerM (addOp, builder);
110118 addToGradient (addOp.getLhs (), addedGradient, builder, gutils);
111119 addToGradient (addOp.getRhs (), addedGradient, builder, gutils);
112120 }
113121 }
114122
115- SmallVector<Value> cacheValues (Operation *op, MGradientUtilsReverse *gutils) const {
123+ SmallVector<Value> cacheValues (Operation *op,
124+ MGradientUtilsReverse *gutils) const {
116125 return SmallVector<Value>();
117126 }
118127
119- void createShadowValues (Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const {
120-
121- }
128+ void createShadowValues (Operation *op, OpBuilder &builder,
129+ MGradientUtilsReverse *gutils) const {}
122130};
123131
124- struct MulFOpInterfaceReverse : public ReverseAutoDiffOpInterface ::ExternalModel<MulFOpInterfaceReverse, arith::MulFOp> {
125- void createReverseModeAdjoint (Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector<Value> caches) const {
132+ struct MulFOpInterfaceReverse
133+ : public ReverseAutoDiffOpInterface::ExternalModel<MulFOpInterfaceReverse,
134+ arith::MulFOp> {
135+ void createReverseModeAdjoint (Operation *op, OpBuilder &builder,
136+ MGradientUtilsReverse *gutils,
137+ SmallVector<Value> caches) const {
126138 auto mulOp = cast<arith::MulFOp>(op);
127139
128- if (gutils->hasInvertPointer (mulOp)){
140+ if (gutils->hasInvertPointer (mulOp)) {
129141 Value own_gradient = gutils->invertPointerM (mulOp, builder);
130142 for (int i = 0 ; i < 2 ; i++) {
131143 if (!gutils->isConstantValue (mulOp.getOperand (i))) {
132144 Value cache = caches[i];
133145 Value retrievedValue = gutils->popCache (cache, builder);
134- Value addedGradient = builder.create <arith::MulFOp>(mulOp.getLoc (), own_gradient, retrievedValue);
135-
146+ Value addedGradient = builder.create <arith::MulFOp>(
147+ mulOp.getLoc (), own_gradient, retrievedValue);
148+
136149 addToGradient (mulOp.getOperand (i), addedGradient, builder, gutils);
137150 }
138151 }
139152 }
140153 }
141154
142- SmallVector<Value> cacheValues (Operation *op, MGradientUtilsReverse *gutils) const {
155+ SmallVector<Value> cacheValues (Operation *op,
156+ MGradientUtilsReverse *gutils) const {
143157 auto mulOp = cast<arith::MulFOp>(op);
144- if (gutils->hasInvertPointer (mulOp)){
158+ if (gutils->hasInvertPointer (mulOp)) {
145159 OpBuilder cacheBuilder (gutils->getNewFromOriginal (op));
146160 SmallVector<Value> caches;
147161 for (int i = 0 ; i < 2 ; i++) {
148- Value otherOperand = mulOp.getOperand ((i+1 )%2 );
149- Value cache = gutils->initAndPushCache (gutils->getNewFromOriginal (otherOperand), cacheBuilder);
162+ Value otherOperand = mulOp.getOperand ((i + 1 ) % 2 );
163+ Value cache = gutils->initAndPushCache (
164+ gutils->getNewFromOriginal (otherOperand), cacheBuilder);
150165 caches.push_back (cache);
151166 }
152167 return caches;
153168 }
154169 return SmallVector<Value>();
155170 }
156171
157- void createShadowValues (Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const {
158-
159- }
172+ void createShadowValues (Operation *op, OpBuilder &builder,
173+ MGradientUtilsReverse *gutils) const {}
160174};
161175
162- }
176+ } // namespace
163177
164- void mlir::enzyme::registerArithDialectAutoDiffInterface (DialectRegistry ®istry) {
178+ void mlir::enzyme::registerArithDialectAutoDiffInterface (
179+ DialectRegistry ®istry) {
165180 registry.addExtension (+[](MLIRContext *context, arith::ArithDialect *) {
166181 arith::AddFOp::attachInterface<AddFOpInterfaceReverse>(*context);
167182 arith::MulFOp::attachInterface<MulFOpInterfaceReverse>(*context);
168-
183+
169184 arith::AddFOp::attachInterface<AddFOpInterface>(*context);
170185 arith::MulFOp::attachInterface<MulFOpInterface>(*context);
171186 });
0 commit comments