1- #include " Interfaces/GradientUtils.h"
2- #include " Interfaces/GradientUtilsReverse.h"
31#include " Dialect/Ops.h"
42#include " Interfaces/AutoDiffOpInterface.h"
53#include " Interfaces/AutoDiffTypeInterface.h"
1715#include " llvm/ADT/BreadthFirstIterator.h"
1816#include " mlir/IR/Dominance.h"
1917
20- #include " GradientUtils.h"
2118#include " EnzymeLogic.h"
19+ #include " Interfaces/GradientUtils.h"
20+ #include " Interfaces/GradientUtilsReverse.h"
2221
2322using namespace mlir ;
2423using namespace mlir ::enzyme;
2524
// getDominatorToposort: returns the blocks of a region as a vector, entry block first.
// (Diff excerpt: lines marked - belong to the previous revision, lines marked + to the new one.)
//   gutils - reverse-mode utilities; its oldFunc body is used to obtain the dominator tree
//   region - (new revision) region whose blocks are collected
// A region with exactly one block yields just that block. Otherwise the blocks are appended
// in breadth-first order over the dominator tree, starting at the node of the first block.
// NOTE(review): in the new revision the dominator tree is still requested for
// gutils->oldFunc.getFunctionBody() while the root node is looked up from region.begin() -
// confirm this is intended when region is not the function body.
26- SmallVector<mlir::Block*> getDominatorToposort (MGradientUtilsReverse *gutils){
25+ SmallVector<mlir::Block*> MEnzymeLogic:: getDominatorToposort (MGradientUtilsReverse *gutils, Region& region ){
2726 SmallVector<mlir::Block*> dominatorToposortBlocks;
28- if (gutils-> oldFunc . getFunctionBody () .hasOneBlock ()){
29- dominatorToposortBlocks.push_back (&*(gutils-> oldFunc . getFunctionBody () .begin ()));
27+ if (region .hasOneBlock ()){
28+ dominatorToposortBlocks.push_back (&*(region .begin ()));
3029 }
3130 else {
3231 auto dInfo = mlir::detail::DominanceInfoBase<false >(nullptr );
3332 llvm::DominatorTreeBase<Block, false > & dt = dInfo.getDomTree (&(gutils->oldFunc .getFunctionBody ()));
34- auto root = dt.getNode (&*(gutils-> oldFunc . getFunctionBody () .begin ()));
33+ auto root = dt.getNode (&*(region .begin ()));
3534
// Breadth-first walk: a dominator-tree node is visited before any of its children.
3635 for (llvm::DomTreeNodeBase<mlir::Block> * node : llvm::breadth_first (root)){
3736 dominatorToposortBlocks.push_back (node->getBlock ());
@@ -41,28 +40,43 @@ SmallVector<mlir::Block*> getDominatorToposort(MGradientUtilsReverse *gutils){
4140 return dominatorToposortBlocks;
4241}
4342
// mapInvertArguments: for every entry i of gutils->mapBlockArguments[oBB], passes the
// second member of that entry together with the i-th block argument of reverseBB to
// gutils->mapInvertPointer. The builder handed along points at the start of reverseBB.
//   oBB       - block used as key into gutils->mapBlockArguments
//   reverseBB - block whose arguments are associated with the recorded values
// The new revision only changes the owning class and the gutils type in the signature.
44- void mapInvertArguments (Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
43+ void MEnzymeLogic:: mapInvertArguments (Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils){
4544 for (int i = 0 ; i < (int )gutils->mapBlockArguments [oBB].size (); i++){
4645 auto x = gutils->mapBlockArguments [oBB][i];
4746 OpBuilder builder (reverseBB, reverseBB->begin ());
4847 gutils->mapInvertPointer (x.second , reverseBB->getArgument (i), builder);
4948 }
5049}
5150
// handleReturns: handles a block of the original region that has no successors
// (oBB->getNumSuccessors() == 0); blocks with successors are left untouched.
//   oBB          - block in the original region
//   newBB        - block whose terminator is replaced in the parentRegion case
//   reverseBB    - reverse-mode block that receives control / provides the arguments
//   parentRegion - (new revision) selects between the two behaviours below
// parentRegion == true (the only behaviour of the previous revision):
//   the terminator of newBB is erased, operand 0 of the terminator of oBB is mapped via
//   mapInvertPointer to the last argument of gutils->newFunc, and a cf::BranchOp to
//   reverseBB is created at the end of newBB and recorded in gutils->originalToNewFnOps
//   under the terminator of oBB.
// parentRegion == false:
//   every operand of the terminator of oBB whose type implements AutoDiffTypeInterface is
//   mapped via mapInvertPointer to the next argument of reverseBB, in operand order.
52- void handleReturns (Block * oBB, Block * newBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
51+ void MEnzymeLogic:: handleReturns (Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, bool parentRegion ){
5352 if (oBB->getNumSuccessors () == 0 ){
54- Operation * returnStatement = newBB->getTerminator ();
55- gutils->erase (returnStatement);
53+ if (parentRegion){
54+ Operation * returnStatement = newBB->getTerminator ();
55+ gutils->erase (returnStatement);
5656
57- OpBuilder forwardToBackwardBuilder (newBB, newBB->end ());
58- gutils->mapInvertPointer (oBB->getTerminator ()->getOperand (0 ), gutils->newFunc .getArgument (gutils->newFunc .getNumArguments () - 1 ), forwardToBackwardBuilder); // TODO handle multiple return values
59- Operation * newBranchOp = forwardToBackwardBuilder.create <cf::BranchOp>(oBB->getTerminator ()->getLoc (), reverseBB);
60-
61- gutils->originalToNewFnOps [oBB->getTerminator ()] = newBranchOp;
57+ OpBuilder forwardToBackwardBuilder (newBB, newBB->end ());
58+ gutils->mapInvertPointer (oBB->getTerminator ()->getOperand (0 ), gutils->newFunc .getArgument (gutils->newFunc .getNumArguments () - 1 ), forwardToBackwardBuilder); // TODO handle multiple return values
59+ Operation * newBranchOp = forwardToBackwardBuilder.create <cf::BranchOp>(oBB->getTerminator ()->getLoc (), reverseBB);
60+
61+ gutils->originalToNewFnOps [oBB->getTerminator ()] = newBranchOp;
62+ }
63+ else {
64+ Operation * terminator = oBB->getTerminator ();
65+ OpBuilder builder (reverseBB, reverseBB->begin ());
66+
// i advances only for operands that pass the interface check, so it indexes the
// arguments of reverseBB rather than the operands of the terminator.
67+ int i = 0 ;
68+ for (OpOperand & operand : terminator->getOpOperands ()){
69+ Value val = operand.get ();
70+ if (auto iface = val.getType ().dyn_cast <AutoDiffTypeInterface>()) {
71+ gutils->mapInvertPointer (val, reverseBB->getArgument (i), builder);
72+ i++;
73+ }
74+ }
75+ }
6276 }
6377}
6478
65- bool visitChildCustom (Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse * gutils){
79+ bool MEnzymeLogic:: visitChildCustom (Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils){
6680 std::string nameDiffe = " diffe_" + op->getName ().getDialectNamespace ().str () + " _" + op->getName ().stripDialect ().str ();
6781 std::string nameStore = " store_" + op->getName ().getDialectNamespace ().str () + " _" + op->getName ().stripDialect ().str ();
6882
@@ -128,9 +142,9 @@ bool visitChildCustom(Operation * op, OpBuilder &builder, MDiffeGradientUtilsRev
128142/*
129143Create reverse mode adjoint for an operation.
130144*/
// visitChild: handles one operation op during the reverse pass.
// In the visible code only operations implementing ReverseAutoDiffOpInterface are processed:
// cacheValues(gutils) is called first and its result is passed on to
// createReverseModeAdjoint(builder, gutils, caches).
// The new revision keeps that result in a SmallVector<Value> instead of a ValueRange.
// A loop over the results of op follows; its body is hidden behind the @@ marker below.
131- void visitChild (Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse * gutils){
145+ void MEnzymeLogic:: visitChild (Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils){
132146 if (auto ifaceOp = dyn_cast<ReverseAutoDiffOpInterface>(op)) {
133- ValueRange caches = ifaceOp.cacheValues (gutils);
147+ SmallVector<Value> caches = ifaceOp.cacheValues (gutils);
134148 ifaceOp.createReverseModeAdjoint (builder, gutils, caches);
135149
136150 for (int indexResult = 0 ; indexResult < (int )op->getNumResults (); indexResult++){
@@ -140,7 +154,7 @@ void visitChild(Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse *
140154 }
141155}
142156
143- void visitChildren (Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
157+ void MEnzymeLogic:: visitChildren (Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils){
144158 OpBuilder revBuilder (reverseBB, reverseBB->end ());
145159 if (!oBB->empty ()){
146160 auto first = oBB->rbegin ();
@@ -155,21 +169,22 @@ void visitChildren(Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse *
155169 }
156170}
157171
// handlePredecessors: emits the terminator of reverseBB, the reverse-mode block for oBB.
// (Diff excerpt; parts of the predecessor loop are hidden behind the @@ markers.)
//   oBB           - block of the original region
//   newBB         - (new revision) block that receives the extra index argument
//   reverseBB     - block at whose end the terminator is created
//   buildRetrunOp - (new revision) callback that creates the returning operation
// Case 1, oBB has no predecessors:
//   gutils->invertPointerM is evaluated for every argument of the body of gutils->oldFunc
//   and the results are handed to buildRetrunOp together with the location of the last
//   operation of oBB. The previous revision created a func::ReturnOp directly.
// Case 2, oBB has predecessors:
//   predecessorRevMode of the first predecessor becomes the default target; each further
//   predecessor becomes a case with index 1, 2, ...
//   - exactly one predecessor: a cf::BranchOp to the default target is created.
//   - several predecessors: a cache is created with gutils->insertInit, an index is popped
//     from it with enzyme::PopOp, and a cf::SwitchOp dispatches on that index.
//     New revision: newBB gets an additional argument of gutils->getIndexType() that is
//     pushed to the cache at the start of newBB, and every predecessor terminator (which
//     must implement BranchOpInterface) gets a constant 0, 1, ... appended to the operands
//     of each successor edge that targets newBB. The previous revision pushed the constant
//     to the cache inside the predecessor instead.
// NOTE(review): arguments stores ValueRange views; the declaration of operands is not
// visible here - confirm the storage they refer to is still alive when the SwitchOp is built.
// NOTE(review): defaultBlock has no initializer; it is assigned only inside the loop.
// NOTE(review): the appended constant is created with width 32 while the new argument uses
// gutils->getIndexType() - confirm these types agree.
158- void handlePredecessors (Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
172+ void MEnzymeLogic:: handlePredecessors (Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, buildReturnFunction buildRetrunOp ){
159173 OpBuilder revBuilder (reverseBB, reverseBB->end ());
160174 if (oBB->hasNoPredecessors ()){
161- SmallVector<mlir::Value, 2 > retargs;
175+ SmallVector<mlir::Value> retargs;
162176 for (Value attribute : gutils->oldFunc .getFunctionBody ().getArguments ()) {
163177 Value attributeGradient = gutils->invertPointerM (attribute, revBuilder);
164178 retargs.push_back (attributeGradient);
165179 }
166- revBuilder.create <func::ReturnOp>(oBB->rbegin ()->getLoc (), retargs);
180+ buildRetrunOp (revBuilder, oBB->rbegin ()->getLoc (), retargs);
181+ // revBuilder.create<func::ReturnOp>(oBB->rbegin()->getLoc(), retargs);
167182 }
168183 else {
169184 SmallVector<Block *> blocks;
170185 SmallVector<APInt> indices;
171186 SmallVector<ValueRange> arguments;
172- ValueRange defaultArguments;
187+ SmallVector<Value> defaultArguments;
173188 Block * defaultBlock;
// Case indices start at 1; the first predecessor is the default target and gets no case.
174189 int i = 1 ;
175190 for (Block * predecessor : oBB->getPredecessors ()){
@@ -188,7 +203,7 @@ void handlePredecessors(Block * oBB, Block * reverseBB, MDiffeGradientUtilsRever
188203 operands.push_back (nullValue);
189204 }
190205 else {
191- llvm_unreachable (" non canonial null value found" );
206+ llvm_unreachable (" no canonial null value found" );
192207 }
193208 }
194209 }
@@ -197,37 +212,55 @@ void handlePredecessors(Block * oBB, Block * reverseBB, MDiffeGradientUtilsRever
197212 if (predecessor != *(oBB->getPredecessors ().begin ())){
198213 blocks.push_back (predecessorRevMode);
199214 indices.push_back (APInt (32 , i++));
200- arguments.push_back (ValueRange ( operands) );
215+ arguments.push_back (operands);
201216 }
202217 else {
203218 defaultBlock = predecessorRevMode;
204- defaultArguments = ValueRange ( operands) ;
219+ defaultArguments = operands;
205220 }
206221 }
// The location of the last operation of oBB is reused for every operation created below.
222+ Location loc = oBB->rbegin ()->getLoc ();
207223 // Remove Dependency to CF dialect
208224 if (std::next (oBB->getPredecessors ().begin ()) == oBB->getPredecessors ().end ()){
209225 // If there is only one block we can directly create a branch for simplicity sake
210- revBuilder.create <cf::BranchOp>(gutils-> getNewFromOriginal (&*(oBB-> rbegin ()))-> getLoc () , defaultBlock, defaultArguments);
226+ revBuilder.create <cf::BranchOp>(loc , defaultBlock, defaultArguments);
211227 }
212228 else {
213229 Value cache = gutils->insertInit (gutils->getIndexCacheType ());
214- Value flag = revBuilder.create <enzyme::PopOp>(oBB-> rbegin ()-> getLoc () , gutils->getIndexType (), cache);
230+ Value flag = revBuilder.create <enzyme::PopOp>(loc , gutils->getIndexType (), cache);
215231
216- revBuilder.create <cf::SwitchOp>(oBB-> rbegin ()-> getLoc () , flag, defaultBlock, defaultArguments, ArrayRef<APInt>(indices), ArrayRef<Block *>(blocks), ArrayRef<ValueRange>(arguments));
232+ revBuilder.create <cf::SwitchOp>(loc , flag, defaultBlock, defaultArguments, ArrayRef<APInt>(indices), ArrayRef<Block *>(blocks), ArrayRef<ValueRange>(arguments));
217233
234+ Value origin = newBB->addArgument (gutils->getIndexType (), loc);
235+
236+ OpBuilder newBuilder (newBB, newBB->begin ());
237+ newBuilder.create <enzyme::PushOp>(loc, cache, origin);
238+
// j numbers the predecessors 0, 1, ... in iteration order of oBB->getPredecessors().
218239 int j = 0 ;
219240 for (Block * predecessor : oBB->getPredecessors ()){
220241 Block * newPredecessor = gutils->getNewFromOriginal (predecessor);
221- OpBuilder predecessorBuilder (newPredecessor, std::prev (newPredecessor->end ()));
222242
223- Value indicator = predecessorBuilder.create <arith::ConstantIntOp>(oBB->rbegin ()->getLoc (), j++, 32 );
224- predecessorBuilder.create <enzyme::PushOp>(oBB->rbegin ()->getLoc (), cache, indicator);
243+ OpBuilder predecessorBuilder (newPredecessor, std::prev (newPredecessor->end ()));
244+ Value indicator = predecessorBuilder.create <arith::ConstantIntOp>(loc, j++, 32 );
245+
246+ Operation * terminator = newPredecessor->getTerminator ();
247+ if (auto binst = dyn_cast<BranchOpInterface>(terminator)) {
// This i is a new loop-local variable; it shadows the case counter declared above.
248+ for (int i = 0 ; i < terminator->getNumSuccessors (); i++){
249+ if (terminator->getSuccessor (i) == newBB){
250+ SuccessorOperands sOps = binst.getSuccessorOperands (i);
251+ sOps .append (indicator);
252+ }
253+ }
254+ }
255+ else {
256+ llvm_unreachable (" invalid terminator" );
257+ }
225258 }
226259 }
227260 }
228261}
229262
230- void initializeShadowValues (SmallVector<mlir::Block*>& dominatorToposortBlocks, MDiffeGradientUtilsReverse * gutils){
263+ void MEnzymeLogic:: initializeShadowValues (SmallVector<mlir::Block*>& dominatorToposortBlocks, MGradientUtilsReverse * gutils){
231264 for (auto it = dominatorToposortBlocks.begin (); it != dominatorToposortBlocks.end (); ++it){
232265 Block * oBB = *it;
233266
@@ -245,19 +278,10 @@ void initializeShadowValues(SmallVector<mlir::Block*>& dominatorToposortBlocks,
245278 }
246279}
247280
// differentiate: (new revision) drives the reverse pass over one region.
// (Diff excerpt: the lines marked - below are the removed head of the previous
// CreateReverseDiff; two lines of the loop body are hidden behind the @@ marker.)
//   gutils            - reverse-mode utilities shared by all helper calls
//   oldRegion         - region whose blocks are ordered and visited
//   newRegion         - passed to gutils->createReverseModeBlocks only
//   parentRegion      - forwarded to createReverseModeBlocks and handleReturns
//   buildFuncRetrunOp - forwarded to handlePredecessors
// Steps visible here: createReverseModeBlocks, getDominatorToposort on oldRegion,
// initializeShadowValues, then a walk over the ordered blocks from last to first that calls
// mapInvertArguments, handleReturns, visitChildren and handlePredecessors for each block.
281+ void MEnzymeLogic::differentiate (MGradientUtilsReverse * gutils, Region & oldRegion, Region & newRegion, bool parentRegion, buildReturnFunction buildFuncRetrunOp){
282+ gutils->createReverseModeBlocks (oldRegion, newRegion, parentRegion);
248283
249- FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateReverseDiff (FunctionOpInterface fn, DIFFE_TYPE retType, std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector<bool > volatile_args, void *augmented, SymbolTableCollection &symbolTable) {
250-
251- if (fn.getFunctionBody ().empty ()) {
252- llvm::errs () << fn << " \n " ;
253- llvm_unreachable (" Differentiating empty function" );
254- }
255-
256- ReturnType returnValue = ReturnType::Tape;
257- MDiffeGradientUtilsReverse * gutils = MDiffeGradientUtilsReverse::CreateFromClone (*this , mode, width, fn, TA, type_args, retType, /* diffeReturnArg*/ true , constants, returnValue, addedType, symbolTable);
258-
259- SmallVector<mlir::Block*> dominatorToposortBlocks = getDominatorToposort (gutils);
260-
284+ SmallVector<mlir::Block*> dominatorToposortBlocks = getDominatorToposort (gutils, oldRegion);
261285 initializeShadowValues (dominatorToposortBlocks, gutils);
262286
// Reverse iteration: blocks are processed in the opposite order of the toposort.
263287 for (auto it = dominatorToposortBlocks.rbegin (); it != dominatorToposortBlocks.rend (); ++it){
@@ -266,19 +290,38 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateReverseDiff(FunctionOpInte
266290 Block * reverseBB = gutils->mapReverseModeBlocks .lookupOrNull (oBB);
267291
268292 mapInvertArguments (oBB, reverseBB, gutils);
269-
270- handleReturns (oBB, newBB, reverseBB, gutils);
271-
293+ handleReturns (oBB, newBB, reverseBB, gutils, parentRegion);
272294 visitChildren (oBB, reverseBB, gutils);
273-
274- handlePredecessors (oBB, reverseBB, gutils);
295+ handlePredecessors (oBB, newBB, reverseBB, gutils, buildFuncRetrunOp);
296+ }
297+ }
298+
// CreateReverseDiff: (new revision) builds the reverse-mode function for fn and returns it.
// Steps in this body:
//   1. An empty function body is reported on llvm::errs() and hits llvm_unreachable.
//   2. gutils is created with MGradientUtilsReverse::CreateFromClone, using
//      ReturnType::Tape and diffeReturnArg = true.
//   3. The bodies of gutils->oldFunc and gutils->newFunc are passed to differentiate with
//      parentRegion = true and a callback that creates a func::ReturnOp.
//   4. gutils->newFunc is copied into nf, gutils is deleted, and nf is returned.
// The parameters returnUsed, freeMemory, volatile_args and augmented are not referenced
// in this body.
299+ FunctionOpInterface MEnzymeLogic::CreateReverseDiff (FunctionOpInterface fn, DIFFE_TYPE retType, std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector<bool > volatile_args, void *augmented, SymbolTableCollection &symbolTable) {
300+
301+ if (fn.getFunctionBody ().empty ()) {
302+ llvm::errs () << fn << " \n " ;
303+ llvm_unreachable (" Differentiating empty function" );
275304 }
276305
306+ ReturnType returnValue = ReturnType::Tape;
307+ MGradientUtilsReverse * gutils = MGradientUtilsReverse::CreateFromClone (*this , mode, width, fn, TA, type_args, retType, /* diffeReturnArg*/ true , constants, returnValue, addedType, symbolTable);
308+
309+ Region & oldRegion = gutils->oldFunc .getFunctionBody ();
310+ Region & newRegion = gutils->newFunc .getFunctionBody ();
311+
// Callback handed down to handlePredecessors for the block without predecessors.
312+ buildReturnFunction buildFuncRetrunOp = [](OpBuilder& builder, Location loc, SmallVector<Value> retargs){
313+ builder.create <func::ReturnOp>(loc, retargs);
314+ return ;
315+ };
316+
317+ differentiate (gutils, oldRegion, newRegion, true , buildFuncRetrunOp);
318+
// nf is taken before gutils is deleted below.
277319 auto nf = gutils->newFunc ;
278320
279321 // llvm::errs() << "nf\n";
280322 // nf.dump();
281323 // llvm::errs() << "nf end\n";
324+
282325 delete gutils;
283326 return nf;
284327}
0 commit comments