Skip to content

Commit 7de1d73

Browse files
committed
fixed a bug with self loops
1 parent 6757793 commit 7de1d73

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class MEnzymeLogic {
106106

107107
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, DIFFE_TYPE retType, std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented, SymbolTableCollection &symbolTable);
108108
void initializeShadowValues(SmallVector<mlir::Block*>& dominatorToposortBlocks, MGradientUtilsReverse * gutils);
109-
void handlePredecessors(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils, void (*buildRetrunOp) (OpBuilder&, Location, SmallVector<mlir::Value>));
109+
void handlePredecessors(Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, void (*buildRetrunOp) (OpBuilder&, Location, SmallVector<mlir::Value>));
110110
void visitChildren(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils);
111111
void visitChild(Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils);
112112
bool visitChildCustom(Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils);

enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ void MEnzymeLogic::visitChildren(Block * oBB, Block * reverseBB, MGradientUtilsR
169169
}
170170
}
171171

172-
void MEnzymeLogic::handlePredecessors(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils, buildReturnFunction buildRetrunOp){
172+
void MEnzymeLogic::handlePredecessors(Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, buildReturnFunction buildRetrunOp){
173173
OpBuilder revBuilder(reverseBB, reverseBB->end());
174174
if (oBB->hasNoPredecessors()){
175175
SmallVector<mlir::Value> retargs;
@@ -203,7 +203,7 @@ void MEnzymeLogic::handlePredecessors(Block * oBB, Block * reverseBB, MGradientU
203203
operands.push_back(nullValue);
204204
}
205205
else{
206-
llvm_unreachable("non canonial null value found");
206+
llvm_unreachable("no canonial null value found");
207207
}
208208
}
209209
}
@@ -219,24 +219,42 @@ void MEnzymeLogic::handlePredecessors(Block * oBB, Block * reverseBB, MGradientU
219219
defaultArguments = operands;
220220
}
221221
}
222+
Location loc = oBB->rbegin()->getLoc();
222223
//Remove Dependency to CF dialect
223224
if (std::next(oBB->getPredecessors().begin()) == oBB->getPredecessors().end()){
224225
//If there is only one block we can directly create a branch for simplicity sake
225-
revBuilder.create<cf::BranchOp>(gutils->getNewFromOriginal(&*(oBB->rbegin()))->getLoc(), defaultBlock, defaultArguments);
226+
revBuilder.create<cf::BranchOp>(loc, defaultBlock, defaultArguments);
226227
}
227228
else{
228229
Value cache = gutils->insertInit(gutils->getIndexCacheType());
229-
Value flag = revBuilder.create<enzyme::PopOp>(oBB->rbegin()->getLoc(), gutils->getIndexType(), cache);
230+
Value flag = revBuilder.create<enzyme::PopOp>(loc, gutils->getIndexType(), cache);
230231

231-
revBuilder.create<cf::SwitchOp>(oBB->rbegin()->getLoc(), flag, defaultBlock, defaultArguments, ArrayRef<APInt>(indices), ArrayRef<Block *>(blocks), ArrayRef<ValueRange>(arguments));
232+
revBuilder.create<cf::SwitchOp>(loc, flag, defaultBlock, defaultArguments, ArrayRef<APInt>(indices), ArrayRef<Block *>(blocks), ArrayRef<ValueRange>(arguments));
232233

234+
Value origin = newBB->addArgument(gutils->getIndexType(), loc);
235+
236+
OpBuilder newBuilder(newBB, newBB->begin());
237+
newBuilder.create<enzyme::PushOp>(loc, cache, origin);
238+
233239
int j = 0;
234240
for (Block * predecessor : oBB->getPredecessors()){
235241
Block * newPredecessor = gutils->getNewFromOriginal(predecessor);
236-
OpBuilder predecessorBuilder(newPredecessor, std::prev(newPredecessor->end()));
237242

238-
Value indicator = predecessorBuilder.create<arith::ConstantIntOp>(oBB->rbegin()->getLoc(), j++, 32);
239-
predecessorBuilder.create<enzyme::PushOp>(oBB->rbegin()->getLoc(), cache, indicator);
243+
OpBuilder predecessorBuilder(newPredecessor, std::prev(newPredecessor->end()));
244+
Value indicator = predecessorBuilder.create<arith::ConstantIntOp>(loc, j++, 32);
245+
246+
Operation * terminator = newPredecessor->getTerminator();
247+
if (auto binst = dyn_cast<BranchOpInterface>(terminator)) {
248+
for (int i = 0; i < terminator->getNumSuccessors(); i++){
249+
if (terminator->getSuccessor(i) == newBB){
250+
SuccessorOperands sOps = binst.getSuccessorOperands(i);
251+
sOps.append(indicator);
252+
}
253+
}
254+
}
255+
else{
256+
llvm_unreachable("invalid terminator");
257+
}
240258
}
241259
}
242260
}
@@ -274,7 +292,7 @@ void MEnzymeLogic::differentiate(MGradientUtilsReverse * gutils, Region & oldReg
274292
mapInvertArguments(oBB, reverseBB, gutils);
275293
handleReturns(oBB, newBB, reverseBB, gutils, parentRegion);
276294
visitChildren(oBB, reverseBB, gutils);
277-
handlePredecessors(oBB, reverseBB, gutils, buildFuncRetrunOp);
295+
handlePredecessors(oBB, newBB, reverseBB, gutils, buildFuncRetrunOp);
278296
}
279297
}
280298

@@ -300,9 +318,9 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(FunctionOpInterface fn, DIFF
300318

301319
auto nf = gutils->newFunc;
302320

303-
llvm::errs() << "nf\n";
304-
nf.dump();
305-
llvm::errs() << "nf end\n";
321+
//llvm::errs() << "nf\n";
322+
//nf.dump();
323+
//llvm::errs() << "nf end\n";
306324

307325
delete gutils;
308326
return nf;

0 commit comments

Comments
 (0)