1- #include " Interfaces/GradientUtils.h"
2- #include " Interfaces/GradientUtilsReverse.h"
31#include " Dialect/Ops.h"
42#include " Interfaces/AutoDiffOpInterface.h"
53#include " Interfaces/AutoDiffTypeInterface.h"
1715#include " llvm/ADT/BreadthFirstIterator.h"
1816#include " mlir/IR/Dominance.h"
1917
20- #include " GradientUtils.h"
2118#include " EnzymeLogic.h"
19+ #include " Interfaces/GradientUtils.h"
20+ #include " Interfaces/GradientUtilsReverse.h"
2221
2322using namespace mlir ;
2423using namespace mlir ::enzyme;
2524
26- SmallVector<mlir::Block*> getDominatorToposort (MGradientUtilsReverse *gutils){
25+ SmallVector<mlir::Block*> MEnzymeLogic:: getDominatorToposort (MGradientUtilsReverse *gutils, Region& region ){
2726 SmallVector<mlir::Block*> dominatorToposortBlocks;
28- if (gutils-> oldFunc . getFunctionBody () .hasOneBlock ()){
29- dominatorToposortBlocks.push_back (&*(gutils-> oldFunc . getFunctionBody () .begin ()));
27+ if (region .hasOneBlock ()){
28+ dominatorToposortBlocks.push_back (&*(region .begin ()));
3029 }
3130 else {
3231 auto dInfo = mlir::detail::DominanceInfoBase<false >(nullptr );
3332 llvm::DominatorTreeBase<Block, false > & dt = dInfo.getDomTree (&(gutils->oldFunc .getFunctionBody ()));
34- auto root = dt.getNode (&*(gutils-> oldFunc . getFunctionBody () .begin ()));
33+ auto root = dt.getNode (&*(region .begin ()));
3534
3635 for (llvm::DomTreeNodeBase<mlir::Block> * node : llvm::breadth_first (root)){
3736 dominatorToposortBlocks.push_back (node->getBlock ());
@@ -41,28 +40,43 @@ SmallVector<mlir::Block*> getDominatorToposort(MGradientUtilsReverse *gutils){
4140 return dominatorToposortBlocks;
4241}
4342
44- void mapInvertArguments (Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
43+ void MEnzymeLogic:: mapInvertArguments (Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils){
4544 for (int i = 0 ; i < (int )gutils->mapBlockArguments [oBB].size (); i++){
4645 auto x = gutils->mapBlockArguments [oBB][i];
4746 OpBuilder builder (reverseBB, reverseBB->begin ());
4847 gutils->mapInvertPointer (x.second , reverseBB->getArgument (i), builder);
4948 }
5049}
5150
52- void handleReturns (Block * oBB, Block * newBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
51+ void MEnzymeLogic:: handleReturns (Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, bool parentRegion ){
5352 if (oBB->getNumSuccessors () == 0 ){
54- Operation * returnStatement = newBB->getTerminator ();
55- gutils->erase (returnStatement);
53+ if (parentRegion){
54+ Operation * returnStatement = newBB->getTerminator ();
55+ gutils->erase (returnStatement);
5656
57- OpBuilder forwardToBackwardBuilder (newBB, newBB->end ());
58- gutils->mapInvertPointer (oBB->getTerminator ()->getOperand (0 ), gutils->newFunc .getArgument (gutils->newFunc .getNumArguments () - 1 ), forwardToBackwardBuilder); // TODO handle multiple return values
59- Operation * newBranchOp = forwardToBackwardBuilder.create <cf::BranchOp>(oBB->getTerminator ()->getLoc (), reverseBB);
60-
61- gutils->originalToNewFnOps [oBB->getTerminator ()] = newBranchOp;
57+ OpBuilder forwardToBackwardBuilder (newBB, newBB->end ());
58+ gutils->mapInvertPointer (oBB->getTerminator ()->getOperand (0 ), gutils->newFunc .getArgument (gutils->newFunc .getNumArguments () - 1 ), forwardToBackwardBuilder); // TODO handle multiple return values
59+ Operation * newBranchOp = forwardToBackwardBuilder.create <cf::BranchOp>(oBB->getTerminator ()->getLoc (), reverseBB);
60+
61+ gutils->originalToNewFnOps [oBB->getTerminator ()] = newBranchOp;
62+ }
63+ else {
64+ Operation * terminator = oBB->getTerminator ();
65+ OpBuilder builder (reverseBB, reverseBB->begin ());
66+
67+ int i = 0 ;
68+ for (OpOperand & operand : terminator->getOpOperands ()){
69+ Value val = operand.get ();
70+ if (auto iface = val.getType ().dyn_cast <AutoDiffTypeInterface>()) {
71+ gutils->mapInvertPointer (val, reverseBB->getArgument (i), builder);
72+ i++;
73+ }
74+ }
75+ }
6276 }
6377}
6478
65- bool visitChildCustom (Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse * gutils){
79+ bool MEnzymeLogic:: visitChildCustom (Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils){
6680 std::string nameDiffe = " diffe_" + op->getName ().getDialectNamespace ().str () + " _" + op->getName ().stripDialect ().str ();
6781 std::string nameStore = " store_" + op->getName ().getDialectNamespace ().str () + " _" + op->getName ().stripDialect ().str ();
6882
@@ -128,9 +142,9 @@ bool visitChildCustom(Operation * op, OpBuilder &builder, MDiffeGradientUtilsRev
128142/*
129143Create reverse mode adjoint for an operation.
130144*/
131- void visitChild (Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse * gutils){
145+ void MEnzymeLogic:: visitChild (Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils){
132146 if (auto ifaceOp = dyn_cast<ReverseAutoDiffOpInterface>(op)) {
133- ValueRange caches = ifaceOp.cacheValues (gutils);
147+ SmallVector<Value> caches = ifaceOp.cacheValues (gutils);
134148 ifaceOp.createReverseModeAdjoint (builder, gutils, caches);
135149
136150 for (int indexResult = 0 ; indexResult < (int )op->getNumResults (); indexResult++){
@@ -140,7 +154,7 @@ void visitChild(Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse *
140154 }
141155}
142156
143- void visitChildren (Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
157+ void MEnzymeLogic:: visitChildren (Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils){
144158 OpBuilder revBuilder (reverseBB, reverseBB->end ());
145159 if (!oBB->empty ()){
146160 auto first = oBB->rbegin ();
@@ -155,21 +169,22 @@ void visitChildren(Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse *
155169 }
156170}
157171
158- void handlePredecessors (Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
172+ void MEnzymeLogic:: handlePredecessors (Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils, buildReturnFunction buildRetrunOp ){
159173 OpBuilder revBuilder (reverseBB, reverseBB->end ());
160174 if (oBB->hasNoPredecessors ()){
161- SmallVector<mlir::Value, 2 > retargs;
175+ SmallVector<mlir::Value> retargs;
162176 for (Value attribute : gutils->oldFunc .getFunctionBody ().getArguments ()) {
163177 Value attributeGradient = gutils->invertPointerM (attribute, revBuilder);
164178 retargs.push_back (attributeGradient);
165179 }
166- revBuilder.create <func::ReturnOp>(oBB->rbegin ()->getLoc (), retargs);
180+ buildRetrunOp (revBuilder, oBB->rbegin ()->getLoc (), retargs);
181+ // revBuilder.create<func::ReturnOp>(oBB->rbegin()->getLoc(), retargs);
167182 }
168183 else {
169184 SmallVector<Block *> blocks;
170185 SmallVector<APInt> indices;
171186 SmallVector<ValueRange> arguments;
172- ValueRange defaultArguments;
187+ SmallVector<Value> defaultArguments;
173188 Block * defaultBlock;
174189 int i = 1 ;
175190 for (Block * predecessor : oBB->getPredecessors ()){
@@ -197,11 +212,11 @@ void handlePredecessors(Block * oBB, Block * reverseBB, MDiffeGradientUtilsRever
197212 if (predecessor != *(oBB->getPredecessors ().begin ())){
198213 blocks.push_back (predecessorRevMode);
199214 indices.push_back (APInt (32 , i++));
200- arguments.push_back (ValueRange ( operands) );
215+ arguments.push_back (operands);
201216 }
202217 else {
203218 defaultBlock = predecessorRevMode;
204- defaultArguments = ValueRange ( operands) ;
219+ defaultArguments = operands;
205220 }
206221 }
207222 // Remove Dependency to CF dialect
@@ -227,7 +242,7 @@ void handlePredecessors(Block * oBB, Block * reverseBB, MDiffeGradientUtilsRever
227242 }
228243}
229244
230- void initializeShadowValues (SmallVector<mlir::Block*>& dominatorToposortBlocks, MDiffeGradientUtilsReverse * gutils){
245+ void MEnzymeLogic:: initializeShadowValues (SmallVector<mlir::Block*>& dominatorToposortBlocks, MGradientUtilsReverse * gutils){
231246 for (auto it = dominatorToposortBlocks.begin (); it != dominatorToposortBlocks.end (); ++it){
232247 Block * oBB = *it;
233248
@@ -245,19 +260,10 @@ void initializeShadowValues(SmallVector<mlir::Block*>& dominatorToposortBlocks,
245260 }
246261}
247262
263+ void MEnzymeLogic::differentiate (MGradientUtilsReverse * gutils, Region & oldRegion, Region & newRegion, bool parentRegion, buildReturnFunction buildFuncRetrunOp){
264+ gutils->createReverseModeBlocks (oldRegion, newRegion, parentRegion);
248265
249- FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateReverseDiff (FunctionOpInterface fn, DIFFE_TYPE retType, std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector<bool > volatile_args, void *augmented, SymbolTableCollection &symbolTable) {
250-
251- if (fn.getFunctionBody ().empty ()) {
252- llvm::errs () << fn << " \n " ;
253- llvm_unreachable (" Differentiating empty function" );
254- }
255-
256- ReturnType returnValue = ReturnType::Tape;
257- MDiffeGradientUtilsReverse * gutils = MDiffeGradientUtilsReverse::CreateFromClone (*this , mode, width, fn, TA, type_args, retType, /* diffeReturnArg*/ true , constants, returnValue, addedType, symbolTable);
258-
259- SmallVector<mlir::Block*> dominatorToposortBlocks = getDominatorToposort (gutils);
260-
266+ SmallVector<mlir::Block*> dominatorToposortBlocks = getDominatorToposort (gutils, oldRegion);
261267 initializeShadowValues (dominatorToposortBlocks, gutils);
262268
263269 for (auto it = dominatorToposortBlocks.rbegin (); it != dominatorToposortBlocks.rend (); ++it){
@@ -266,19 +272,38 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateReverseDiff(FunctionOpInte
266272 Block * reverseBB = gutils->mapReverseModeBlocks .lookupOrNull (oBB);
267273
268274 mapInvertArguments (oBB, reverseBB, gutils);
269-
270- handleReturns (oBB, newBB, reverseBB, gutils);
271-
275+ handleReturns (oBB, newBB, reverseBB, gutils, parentRegion);
272276 visitChildren (oBB, reverseBB, gutils);
273-
274- handlePredecessors (oBB, reverseBB, gutils);
277+ handlePredecessors (oBB, reverseBB, gutils, buildFuncRetrunOp);
278+ }
279+ }
280+
281+ FunctionOpInterface MEnzymeLogic::CreateReverseDiff (FunctionOpInterface fn, DIFFE_TYPE retType, std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector<bool > volatile_args, void *augmented, SymbolTableCollection &symbolTable) {
282+
283+ if (fn.getFunctionBody ().empty ()) {
284+ llvm::errs () << fn << " \n " ;
285+ llvm_unreachable (" Differentiating empty function" );
275286 }
276287
288+ ReturnType returnValue = ReturnType::Tape;
289+ MGradientUtilsReverse * gutils = MGradientUtilsReverse::CreateFromClone (*this , mode, width, fn, TA, type_args, retType, /* diffeReturnArg*/ true , constants, returnValue, addedType, symbolTable);
290+
291+ Region & oldRegion = gutils->oldFunc .getFunctionBody ();
292+ Region & newRegion = gutils->newFunc .getFunctionBody ();
293+
294+ buildReturnFunction buildFuncRetrunOp = [](OpBuilder& builder, Location loc, SmallVector<Value> retargs){
295+ builder.create <func::ReturnOp>(loc, retargs);
296+ return ;
297+ };
298+
299+ differentiate (gutils, oldRegion, newRegion, true , buildFuncRetrunOp);
300+
277301 auto nf = gutils->newFunc ;
278302
279- // llvm::errs() << "nf\n";
280- // nf.dump();
281- // llvm::errs() << "nf end\n";
303+ llvm::errs () << " nf\n " ;
304+ nf.dump ();
305+ llvm::errs () << " nf end\n " ;
306+
282307 delete gutils;
283308 return nf;
284309}
0 commit comments