@@ -169,7 +169,7 @@ void MEnzymeLogic::visitChildren(Block * oBB, Block * reverseBB, MGradientUtilsR
169169 }
170170}
171171
172- void MEnzymeLogic::handlePredecessors (Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils, buildReturnFunction buildRetrunOp){
172+ void MEnzymeLogic::handlePredecessors (Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, buildReturnFunction buildRetrunOp){
173173 OpBuilder revBuilder (reverseBB, reverseBB->end ());
174174 if (oBB->hasNoPredecessors ()){
175175 SmallVector<mlir::Value> retargs;
@@ -203,7 +203,7 @@ void MEnzymeLogic::handlePredecessors(Block * oBB, Block * reverseBB, MGradientU
203203 operands.push_back (nullValue);
204204 }
205205 else {
206- llvm_unreachable (" non canonial null value found" );
206+ llvm_unreachable (" no canonial null value found" );
207207 }
208208 }
209209 }
@@ -219,24 +219,42 @@ void MEnzymeLogic::handlePredecessors(Block * oBB, Block * reverseBB, MGradientU
219219 defaultArguments = operands;
220220 }
221221 }
222+ Location loc = oBB->rbegin ()->getLoc ();
222223 // Remove Dependency to CF dialect
223224 if (std::next (oBB->getPredecessors ().begin ()) == oBB->getPredecessors ().end ()){
224225 // If there is only one block we can directly create a branch for simplicity sake
225- revBuilder.create <cf::BranchOp>(gutils-> getNewFromOriginal (&*(oBB-> rbegin ()))-> getLoc () , defaultBlock, defaultArguments);
226+ revBuilder.create <cf::BranchOp>(loc , defaultBlock, defaultArguments);
226227 }
227228 else {
228229 Value cache = gutils->insertInit (gutils->getIndexCacheType ());
229- Value flag = revBuilder.create <enzyme::PopOp>(oBB-> rbegin ()-> getLoc () , gutils->getIndexType (), cache);
230+ Value flag = revBuilder.create <enzyme::PopOp>(loc , gutils->getIndexType (), cache);
230231
231- revBuilder.create <cf::SwitchOp>(oBB-> rbegin ()-> getLoc () , flag, defaultBlock, defaultArguments, ArrayRef<APInt>(indices), ArrayRef<Block *>(blocks), ArrayRef<ValueRange>(arguments));
232+ revBuilder.create <cf::SwitchOp>(loc , flag, defaultBlock, defaultArguments, ArrayRef<APInt>(indices), ArrayRef<Block *>(blocks), ArrayRef<ValueRange>(arguments));
232233
234+ Value origin = newBB->addArgument (gutils->getIndexType (), loc);
235+
236+ OpBuilder newBuilder (newBB, newBB->begin ());
237+ newBuilder.create <enzyme::PushOp>(loc, cache, origin);
238+
233239 int j = 0 ;
234240 for (Block * predecessor : oBB->getPredecessors ()){
235241 Block * newPredecessor = gutils->getNewFromOriginal (predecessor);
236- OpBuilder predecessorBuilder (newPredecessor, std::prev (newPredecessor->end ()));
237242
238- Value indicator = predecessorBuilder.create <arith::ConstantIntOp>(oBB->rbegin ()->getLoc (), j++, 32 );
239- predecessorBuilder.create <enzyme::PushOp>(oBB->rbegin ()->getLoc (), cache, indicator);
243+ OpBuilder predecessorBuilder (newPredecessor, std::prev (newPredecessor->end ()));
244+ Value indicator = predecessorBuilder.create <arith::ConstantIntOp>(loc, j++, 32 );
245+
246+ Operation * terminator = newPredecessor->getTerminator ();
247+ if (auto binst = dyn_cast<BranchOpInterface>(terminator)) {
248+ for (int i = 0 ; i < terminator->getNumSuccessors (); i++){
249+ if (terminator->getSuccessor (i) == newBB){
250+ SuccessorOperands sOps = binst.getSuccessorOperands (i);
251+ sOps .append (indicator);
252+ }
253+ }
254+ }
255+ else {
256+ llvm_unreachable (" invalid terminator" );
257+ }
240258 }
241259 }
242260 }
@@ -274,7 +292,7 @@ void MEnzymeLogic::differentiate(MGradientUtilsReverse * gutils, Region & oldReg
274292 mapInvertArguments (oBB, reverseBB, gutils);
275293 handleReturns (oBB, newBB, reverseBB, gutils, parentRegion);
276294 visitChildren (oBB, reverseBB, gutils);
277- handlePredecessors (oBB, reverseBB, gutils, buildFuncRetrunOp);
295+ handlePredecessors (oBB, newBB, reverseBB, gutils, buildFuncRetrunOp);
278296 }
279297}
280298
@@ -300,9 +318,9 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(FunctionOpInterface fn, DIFF
300318
301319 auto nf = gutils->newFunc ;
302320
303- llvm::errs () << " nf\n " ;
304- nf.dump ();
305- llvm::errs () << " nf end\n " ;
321+ // llvm::errs() << "nf\n";
322+ // nf.dump();
323+ // llvm::errs() << "nf end\n";
306324
307325 delete gutils;
308326 return nf;
0 commit comments