Skip to content

Commit 8c0a909

Browse files
Merge pull request #7 from umatin/ReverseMode
2 parents 1e2a473 + 7de1d73 commit 8c0a909

File tree

5 files changed

+224
-159
lines changed

5 files changed

+224
-159
lines changed

enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "Interfaces/AutoDiffOpInterface.h"
1616
#include "Interfaces/AutoDiffTypeInterface.h"
1717
#include "Interfaces/GradientUtils.h"
18+
#include "Interfaces/GradientUtilsReverse.h"
1819
#include "mlir/Dialect/SCF/IR/SCF.h"
1920
#include "mlir/IR/DialectRegistry.h"
2021
#include "mlir/Support/LogicalResult.h"
@@ -94,11 +95,72 @@ struct ForOpInterface
9495
return success();
9596
}
9697
};
98+
99+
// External model attaching ReverseAutoDiffOpInterface to scf::ForOp.
// It provides three hooks: createReverseModeAdjoint (emits a new scf.for whose
// region is filled by the reverse-mode differentiation logic), cacheValues
// (caches the loop bounds and step) and createShadowValues (currently a no-op).
struct ForOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel<ForOpInterfaceReverse, scf::ForOp> {
100+
// Emits the adjoint loop for `op` at `builder`'s insertion point.
// `caches` is read at indices 0, 1 and 2, which is the order in which
// cacheValues below pushes lower bound, upper bound and step.
void createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector<Value> caches) const {
101+
auto forOp = cast<scf::ForOp>(op);
102+
// NOTE(review): newForOp is not referenced again in this method.
auto newForOp = cast<scf::ForOp>(gutils->getNewFromOriginal(op));
103+
104+
// Collect one value per loop result whose type implements
// AutoDiffTypeInterface: the result's invert pointer when gutils has one,
// otherwise a null value created through the type interface. Results of
// other types contribute nothing.
SmallVector<Value> nArgs;
105+
for (Value v : forOp.getResults()){
106+
if (auto iface = v.getType().dyn_cast<AutoDiffTypeInterface>()){
107+
if (gutils->hasInvertPointer(v)){
108+
nArgs.push_back(gutils->invertPointerM(v, builder));
109+
}
110+
else{
111+
nArgs.push_back(iface.createNullValue(builder, v.getLoc()));
112+
}
113+
}
114+
}
115+
// The new loop takes its lower bound, upper bound and step from values
// popped off the three caches, and nArgs as its remaining operands.
116+
auto repFor = builder.create<scf::ForOp>(forOp.getLoc(), gutils->popCache(caches[0], builder), gutils->popCache(caches[1], builder), gutils->popCache(caches[2], builder), nArgs); // TODO
117+
// Erase the first block of the newly created loop region before the region
// is populated by differentiate() below (presumably the body block created
// by the builder -- TODO confirm).
repFor.getRegion().begin()->erase();
118+
119+
// Callback handed to differentiate(); it creates an scf::YieldOp from the
// values it is given.
buildReturnFunction buildFuncRetrunOp = [](OpBuilder& builder, Location loc, SmallVector<Value> retargs){
120+
builder.create<scf::YieldOp>(loc, retargs);
121+
return ;
122+
};
123+
124+
// Differentiate the original loop region into the new loop's region; the
// `false` argument is the parentRegion flag.
gutils->Logic.differentiate(gutils, forOp.getRegion(), repFor.getRegion(), false, buildFuncRetrunOp);
125+
126+
// Insert the index which is carried by the scf for op.
127+
// The MLIR context is taken from the first operation in
// gutils->initializationBlock.
Type indexType = mlir::IndexType::get(gutils->initializationBlock->begin()->getContext());
128+
repFor.getRegion().insertArgument((unsigned)0, indexType, forOp.getLoc());
129+
130+
// TODO Can we do reverse iteration???
131+
}
132+
133+
// Caches the loop's lower bound, upper bound and step (in that order), each
// looked up through gutils->getNewFromOriginal and pushed with a builder
// constructed from the op that getNewFromOriginal(op) returns.
// Returns the three cache values in the same order.
SmallVector<Value> cacheValues(Operation *op, MGradientUtilsReverse *gutils) const {
134+
auto forOp = cast<scf::ForOp>(op);
135+
136+
Operation * newOp = gutils->getNewFromOriginal(op);
137+
OpBuilder cacheBuilder(newOp);
138+
SmallVector<Value> caches;
139+
140+
Value cacheLB = gutils->initAndPushCache(gutils->getNewFromOriginal(forOp.getLowerBound()), cacheBuilder);
141+
caches.push_back(cacheLB);
142+
143+
Value cacheUB = gutils->initAndPushCache(gutils->getNewFromOriginal(forOp.getUpperBound()), cacheBuilder);
144+
caches.push_back(cacheUB);
145+
146+
Value cacheStep = gutils->initAndPushCache(gutils->getNewFromOriginal(forOp.getStep()), cacheBuilder);
147+
caches.push_back(cacheStep);
148+
149+
return caches;
150+
}
151+
152+
// Currently only casts `op` to scf::ForOp; nothing else is done here.
// NOTE(review): forOp is unused.
void createShadowValues(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const {
153+
auto forOp = cast<scf::ForOp>(op);
154+
}
155+
};
156+
97157
} // namespace
98158

99159
// Adds an extension to `registry` that attaches both ForOpInterface and
// ForOpInterfaceReverse to scf::ForOp in the given MLIRContext. The extension
// callback is parameterized on scf::SCFDialect.
void mlir::enzyme::registerSCFDialectAutoDiffInterface(
100160
DialectRegistry &registry) {
101161
registry.addExtension(+[](MLIRContext *context, scf::SCFDialect *) {
102162
scf::ForOp::attachInterface<ForOpInterface>(*context);
163+
164+
// Second external model on the same op: the reverse-mode interface.
scf::ForOp::attachInterface<ForOpInterfaceReverse>(*context);
103165
});
104166
}

enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@
66
// TODO: no relative includes.
77
#include "../../EnzymeLogic.h"
88

9+
10+
911
namespace mlir {
1012
namespace enzyme {
1113

14+
// Function-pointer type for a callback that receives a builder, a location and
// a list of values and returns nothing. MEnzymeLogic::differentiate takes one
// as its last parameter.
typedef void (*buildReturnFunction) (OpBuilder&, Location, SmallVector<mlir::Value>);
15+
16+
class MGradientUtilsReverse;
1217

1318
class MFnTypeInfo {
1419
public:
@@ -100,6 +105,15 @@ class MEnzymeLogic {
100105
std::vector<bool> volatile_args, void *augmented);
101106

102107
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, DIFFE_TYPE retType, std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented, SymbolTableCollection &symbolTable);
108+
void initializeShadowValues(SmallVector<mlir::Block*>& dominatorToposortBlocks, MGradientUtilsReverse * gutils);
109+
void handlePredecessors(Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, void (*buildRetrunOp) (OpBuilder&, Location, SmallVector<mlir::Value>));
110+
void visitChildren(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils);
111+
void visitChild(Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils);
112+
bool visitChildCustom(Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils);
113+
void handleReturns(Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, bool parentRegion);
114+
void mapInvertArguments(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils);
115+
SmallVector<mlir::Block*> getDominatorToposort(MGradientUtilsReverse *gutils, Region& region);
116+
void differentiate(MGradientUtilsReverse * gutils, Region & oldRegion, Region & newRegion, bool parentRegion, buildReturnFunction buildFuncRetrunOp);
103117
};
104118

105119
} // Namespace enzyme

enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp

Lines changed: 94 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#include "Interfaces/GradientUtils.h"
2-
#include "Interfaces/GradientUtilsReverse.h"
31
#include "Dialect/Ops.h"
42
#include "Interfaces/AutoDiffOpInterface.h"
53
#include "Interfaces/AutoDiffTypeInterface.h"
@@ -17,21 +15,22 @@
1715
#include "llvm/ADT/BreadthFirstIterator.h"
1816
#include "mlir/IR/Dominance.h"
1917

20-
#include "GradientUtils.h"
2118
#include "EnzymeLogic.h"
19+
#include "Interfaces/GradientUtils.h"
20+
#include "Interfaces/GradientUtilsReverse.h"
2221

2322
using namespace mlir;
2423
using namespace mlir::enzyme;
2524

26-
SmallVector<mlir::Block*> getDominatorToposort(MGradientUtilsReverse *gutils){
25+
SmallVector<mlir::Block*> MEnzymeLogic::getDominatorToposort(MGradientUtilsReverse *gutils, Region& region){
2726
SmallVector<mlir::Block*> dominatorToposortBlocks;
28-
if (gutils->oldFunc.getFunctionBody().hasOneBlock()){
29-
dominatorToposortBlocks.push_back(&*(gutils->oldFunc.getFunctionBody().begin()));
27+
if (region.hasOneBlock()){
28+
dominatorToposortBlocks.push_back(&*(region.begin()));
3029
}
3130
else{
3231
auto dInfo = mlir::detail::DominanceInfoBase<false>(nullptr);
3332
llvm::DominatorTreeBase<Block, false> & dt = dInfo.getDomTree(&(gutils->oldFunc.getFunctionBody()));
34-
auto root = dt.getNode(&*(gutils->oldFunc.getFunctionBody().begin()));
33+
auto root = dt.getNode(&*(region.begin()));
3534

3635
for(llvm::DomTreeNodeBase<mlir::Block> * node : llvm::breadth_first(root)){
3736
dominatorToposortBlocks.push_back(node->getBlock());
@@ -41,28 +40,43 @@ SmallVector<mlir::Block*> getDominatorToposort(MGradientUtilsReverse *gutils){
4140
return dominatorToposortBlocks;
4241
}
4342

44-
void mapInvertArguments(Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
43+
// For each entry i recorded in gutils->mapBlockArguments[oBB], registers
// reverseBB's i-th block argument as the invert pointer of the entry's
// `.second` value.
void MEnzymeLogic::mapInvertArguments(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils){
4544
for (int i = 0; i < (int)gutils->mapBlockArguments[oBB].size(); i++){
4645
auto x = gutils->mapBlockArguments[oBB][i];
4746
// The builder is positioned at the start of reverseBB.
OpBuilder builder(reverseBB, reverseBB->begin());
4847
gutils->mapInvertPointer(x.second, reverseBB->getArgument(i), builder);
4948
}
5049
}
5150

52-
void handleReturns(Block * oBB, Block * newBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
51+
// Handles a block that has no successors; blocks with successors are left
// untouched.
//  - parentRegion == true: the terminator of newBB is erased, operand 0 of the
//    original terminator is mapped to the last argument of gutils->newFunc, and
//    a cf::BranchOp from newBB to reverseBB is created and recorded in
//    gutils->originalToNewFnOps for the original terminator.
//  - parentRegion == false: each operand of the original terminator whose type
//    implements AutoDiffTypeInterface is mapped, in operand order, to the next
//    argument of reverseBB.
void MEnzymeLogic::handleReturns(Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, bool parentRegion){
5352
if (oBB->getNumSuccessors() == 0){
54-
Operation * returnStatement = newBB->getTerminator();
55-
gutils->erase(returnStatement);
53+
if (parentRegion){
54+
Operation * returnStatement = newBB->getTerminator();
55+
gutils->erase(returnStatement);
5656

57-
OpBuilder forwardToBackwardBuilder(newBB, newBB->end());
58-
gutils->mapInvertPointer(oBB->getTerminator()->getOperand(0), gutils->newFunc.getArgument(gutils->newFunc.getNumArguments() - 1), forwardToBackwardBuilder); //TODO handle multiple return values
59-
Operation * newBranchOp = forwardToBackwardBuilder.create<cf::BranchOp>(oBB->getTerminator()->getLoc(), reverseBB);
60-
61-
gutils->originalToNewFnOps[oBB->getTerminator()] = newBranchOp;
57+
OpBuilder forwardToBackwardBuilder(newBB, newBB->end());
58+
gutils->mapInvertPointer(oBB->getTerminator()->getOperand(0), gutils->newFunc.getArgument(gutils->newFunc.getNumArguments() - 1), forwardToBackwardBuilder); //TODO handle multiple return values
59+
Operation * newBranchOp = forwardToBackwardBuilder.create<cf::BranchOp>(oBB->getTerminator()->getLoc(), reverseBB);
60+
61+
gutils->originalToNewFnOps[oBB->getTerminator()] = newBranchOp;
62+
}
63+
else{
64+
Operation * terminator = oBB->getTerminator();
65+
OpBuilder builder(reverseBB, reverseBB->begin());
66+
67+
// i counts only the operands that were mapped, so it indexes reverseBB's
// arguments independently of the operand position.
int i = 0;
68+
for(OpOperand & operand : terminator->getOpOperands()){
69+
Value val = operand.get();
70+
if (auto iface = val.getType().dyn_cast<AutoDiffTypeInterface>()) {
71+
gutils->mapInvertPointer(val, reverseBB->getArgument(i), builder);
72+
i++;
73+
}
74+
}
75+
}
6276
}
6377
}
6478

65-
bool visitChildCustom(Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse * gutils){
79+
bool MEnzymeLogic::visitChildCustom(Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils){
6680
std::string nameDiffe = "diffe_" + op->getName().getDialectNamespace().str() + "_" + op->getName().stripDialect().str();
6781
std::string nameStore = "store_" + op->getName().getDialectNamespace().str() + "_" + op->getName().stripDialect().str();
6882

@@ -128,9 +142,9 @@ bool visitChildCustom(Operation * op, OpBuilder &builder, MDiffeGradientUtilsRev
128142
/*
129143
Create reverse mode adjoint for an operation.
130144
*/
131-
void visitChild(Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse * gutils){
145+
void MEnzymeLogic::visitChild(Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils){
132146
if (auto ifaceOp = dyn_cast<ReverseAutoDiffOpInterface>(op)) {
133-
ValueRange caches = ifaceOp.cacheValues(gutils);
147+
SmallVector<Value> caches = ifaceOp.cacheValues(gutils);
134148
ifaceOp.createReverseModeAdjoint(builder, gutils, caches);
135149

136150
for (int indexResult = 0; indexResult < (int)op->getNumResults(); indexResult++){
@@ -140,7 +154,7 @@ void visitChild(Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse *
140154
}
141155
}
142156

143-
void visitChildren(Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
157+
void MEnzymeLogic::visitChildren(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils){
144158
OpBuilder revBuilder(reverseBB, reverseBB->end());
145159
if (!oBB->empty()){
146160
auto first = oBB->rbegin();
@@ -155,21 +169,22 @@ void visitChildren(Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse *
155169
}
156170
}
157171

158-
void handlePredecessors(Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
172+
void MEnzymeLogic::handlePredecessors(Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, buildReturnFunction buildRetrunOp){
159173
OpBuilder revBuilder(reverseBB, reverseBB->end());
160174
if (oBB->hasNoPredecessors()){
161-
SmallVector<mlir::Value, 2> retargs;
175+
SmallVector<mlir::Value> retargs;
162176
for (Value attribute : gutils->oldFunc.getFunctionBody().getArguments()) {
163177
Value attributeGradient = gutils->invertPointerM(attribute, revBuilder);
164178
retargs.push_back(attributeGradient);
165179
}
166-
revBuilder.create<func::ReturnOp>(oBB->rbegin()->getLoc(), retargs);
180+
buildRetrunOp(revBuilder, oBB->rbegin()->getLoc(), retargs);
181+
//revBuilder.create<func::ReturnOp>(oBB->rbegin()->getLoc(), retargs);
167182
}
168183
else {
169184
SmallVector<Block *> blocks;
170185
SmallVector<APInt> indices;
171186
SmallVector<ValueRange> arguments;
172-
ValueRange defaultArguments;
187+
SmallVector<Value> defaultArguments;
173188
Block * defaultBlock;
174189
int i = 1;
175190
for (Block * predecessor : oBB->getPredecessors()){
@@ -188,7 +203,7 @@ void handlePredecessors(Block * oBB, Block * reverseBB, MDiffeGradientUtilsRever
188203
operands.push_back(nullValue);
189204
}
190205
else{
191-
llvm_unreachable("non canonial null value found");
206+
llvm_unreachable("no canonial null value found");
192207
}
193208
}
194209
}
@@ -197,37 +212,55 @@ void handlePredecessors(Block * oBB, Block * reverseBB, MDiffeGradientUtilsRever
197212
if (predecessor != *(oBB->getPredecessors().begin())){
198213
blocks.push_back(predecessorRevMode);
199214
indices.push_back(APInt(32, i++));
200-
arguments.push_back(ValueRange(operands));
215+
arguments.push_back(operands);
201216
}
202217
else{
203218
defaultBlock = predecessorRevMode;
204-
defaultArguments = ValueRange(operands);
219+
defaultArguments = operands;
205220
}
206221
}
222+
Location loc = oBB->rbegin()->getLoc();
207223
//Remove Dependency to CF dialect
208224
if (std::next(oBB->getPredecessors().begin()) == oBB->getPredecessors().end()){
209225
//If there is only one predecessor block we can directly create a branch, for simplicity's sake
210-
revBuilder.create<cf::BranchOp>(gutils->getNewFromOriginal(&*(oBB->rbegin()))->getLoc(), defaultBlock, defaultArguments);
226+
revBuilder.create<cf::BranchOp>(loc, defaultBlock, defaultArguments);
211227
}
212228
else{
213229
Value cache = gutils->insertInit(gutils->getIndexCacheType());
214-
Value flag = revBuilder.create<enzyme::PopOp>(oBB->rbegin()->getLoc(), gutils->getIndexType(), cache);
230+
Value flag = revBuilder.create<enzyme::PopOp>(loc, gutils->getIndexType(), cache);
215231

216-
revBuilder.create<cf::SwitchOp>(oBB->rbegin()->getLoc(), flag, defaultBlock, defaultArguments, ArrayRef<APInt>(indices), ArrayRef<Block *>(blocks), ArrayRef<ValueRange>(arguments));
232+
revBuilder.create<cf::SwitchOp>(loc, flag, defaultBlock, defaultArguments, ArrayRef<APInt>(indices), ArrayRef<Block *>(blocks), ArrayRef<ValueRange>(arguments));
217233

234+
Value origin = newBB->addArgument(gutils->getIndexType(), loc);
235+
236+
OpBuilder newBuilder(newBB, newBB->begin());
237+
newBuilder.create<enzyme::PushOp>(loc, cache, origin);
238+
218239
int j = 0;
219240
for (Block * predecessor : oBB->getPredecessors()){
220241
Block * newPredecessor = gutils->getNewFromOriginal(predecessor);
221-
OpBuilder predecessorBuilder(newPredecessor, std::prev(newPredecessor->end()));
222242

223-
Value indicator = predecessorBuilder.create<arith::ConstantIntOp>(oBB->rbegin()->getLoc(), j++, 32);
224-
predecessorBuilder.create<enzyme::PushOp>(oBB->rbegin()->getLoc(), cache, indicator);
243+
OpBuilder predecessorBuilder(newPredecessor, std::prev(newPredecessor->end()));
244+
Value indicator = predecessorBuilder.create<arith::ConstantIntOp>(loc, j++, 32);
245+
246+
Operation * terminator = newPredecessor->getTerminator();
247+
if (auto binst = dyn_cast<BranchOpInterface>(terminator)) {
248+
for (int i = 0; i < terminator->getNumSuccessors(); i++){
249+
if (terminator->getSuccessor(i) == newBB){
250+
SuccessorOperands sOps = binst.getSuccessorOperands(i);
251+
sOps.append(indicator);
252+
}
253+
}
254+
}
255+
else{
256+
llvm_unreachable("invalid terminator");
257+
}
225258
}
226259
}
227260
}
228261
}
229262

230-
void initializeShadowValues(SmallVector<mlir::Block*>& dominatorToposortBlocks, MDiffeGradientUtilsReverse * gutils){
263+
void MEnzymeLogic::initializeShadowValues(SmallVector<mlir::Block*>& dominatorToposortBlocks, MGradientUtilsReverse * gutils){
231264
for (auto it = dominatorToposortBlocks.begin(); it != dominatorToposortBlocks.end(); ++it){
232265
Block * oBB = *it;
233266

@@ -245,19 +278,10 @@ void initializeShadowValues(SmallVector<mlir::Block*>& dominatorToposortBlocks,
245278
}
246279
}
247280

281+
void MEnzymeLogic::differentiate(MGradientUtilsReverse * gutils, Region & oldRegion, Region & newRegion, bool parentRegion, buildReturnFunction buildFuncRetrunOp){
282+
gutils->createReverseModeBlocks(oldRegion, newRegion, parentRegion);
248283

249-
FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateReverseDiff(FunctionOpInterface fn, DIFFE_TYPE retType, std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented, SymbolTableCollection &symbolTable) {
250-
251-
if (fn.getFunctionBody().empty()) {
252-
llvm::errs() << fn << "\n";
253-
llvm_unreachable("Differentiating empty function");
254-
}
255-
256-
ReturnType returnValue = ReturnType::Tape;
257-
MDiffeGradientUtilsReverse * gutils = MDiffeGradientUtilsReverse::CreateFromClone(*this, mode, width, fn, TA, type_args, retType, /*diffeReturnArg*/ true, constants, returnValue, addedType, symbolTable);
258-
259-
SmallVector<mlir::Block*> dominatorToposortBlocks = getDominatorToposort(gutils);
260-
284+
SmallVector<mlir::Block*> dominatorToposortBlocks = getDominatorToposort(gutils, oldRegion);
261285
initializeShadowValues(dominatorToposortBlocks, gutils);
262286

263287
for (auto it = dominatorToposortBlocks.rbegin(); it != dominatorToposortBlocks.rend(); ++it){
@@ -266,19 +290,38 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateReverseDiff(FunctionOpInte
266290
Block * reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(oBB);
267291

268292
mapInvertArguments(oBB, reverseBB, gutils);
269-
270-
handleReturns(oBB, newBB, reverseBB, gutils);
271-
293+
handleReturns(oBB, newBB, reverseBB, gutils, parentRegion);
272294
visitChildren(oBB, reverseBB, gutils);
273-
274-
handlePredecessors(oBB, reverseBB, gutils);
295+
handlePredecessors(oBB, newBB, reverseBB, gutils, buildFuncRetrunOp);
296+
}
297+
}
298+
299+
// Builds the reverse-mode derivative of `fn` and returns the function held in
// gutils->newFunc. Aborts via llvm_unreachable when `fn` has an empty body.
// The gradient utilities object is created with CreateFromClone, used to
// differentiate the old function body into the new one, and deleted before
// returning.
FunctionOpInterface MEnzymeLogic::CreateReverseDiff(FunctionOpInterface fn, DIFFE_TYPE retType, std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented, SymbolTableCollection &symbolTable) {
300+
301+
if (fn.getFunctionBody().empty()) {
302+
llvm::errs() << fn << "\n";
303+
llvm_unreachable("Differentiating empty function");
275304
}
276305

306+
ReturnType returnValue = ReturnType::Tape;
307+
MGradientUtilsReverse * gutils = MGradientUtilsReverse::CreateFromClone(*this, mode, width, fn, TA, type_args, retType, /*diffeReturnArg*/ true, constants, returnValue, addedType, symbolTable);
308+
309+
Region & oldRegion = gutils->oldFunc.getFunctionBody();
310+
Region & newRegion = gutils->newFunc.getFunctionBody();
311+
312+
// Callback handed to differentiate(); it creates a func::ReturnOp from the
// values it is given.
buildReturnFunction buildFuncRetrunOp = [](OpBuilder& builder, Location loc, SmallVector<Value> retargs){
313+
builder.create<func::ReturnOp>(loc, retargs);
314+
return ;
315+
};
316+
317+
// `true` marks this as the parent (function-level) region.
differentiate(gutils, oldRegion, newRegion, true, buildFuncRetrunOp);
318+
277319
// Copy the new function handle out before gutils is deleted below.
auto nf = gutils->newFunc;
278320

279321
//llvm::errs() << "nf\n";
280322
//nf.dump();
281323
//llvm::errs() << "nf end\n";
324+
282325
delete gutils;
283326
return nf;
284327
}

0 commit comments

Comments
 (0)