Skip to content

Commit 6757793

Browse files
committed
add implementatin for scf for
1 parent ec1c9c2 commit 6757793

File tree

5 files changed

+202
-155
lines changed

5 files changed

+202
-155
lines changed

enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "Interfaces/AutoDiffOpInterface.h"
1616
#include "Interfaces/AutoDiffTypeInterface.h"
1717
#include "Interfaces/GradientUtils.h"
18+
#include "Interfaces/GradientUtilsReverse.h"
1819
#include "mlir/Dialect/SCF/IR/SCF.h"
1920
#include "mlir/IR/DialectRegistry.h"
2021
#include "mlir/Support/LogicalResult.h"
@@ -94,11 +95,72 @@ struct ForOpInterface
9495
return success();
9596
}
9697
};
98+
99+
struct ForOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel<ForOpInterfaceReverse, scf::ForOp> {
100+
void createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector<Value> caches) const {
101+
auto forOp = cast<scf::ForOp>(op);
102+
auto newForOp = cast<scf::ForOp>(gutils->getNewFromOriginal(op));
103+
104+
SmallVector<Value> nArgs;
105+
for (Value v : forOp.getResults()){
106+
if (auto iface = v.getType().dyn_cast<AutoDiffTypeInterface>()){
107+
if (gutils->hasInvertPointer(v)){
108+
nArgs.push_back(gutils->invertPointerM(v, builder));
109+
}
110+
else{
111+
nArgs.push_back(iface.createNullValue(builder, v.getLoc()));
112+
}
113+
}
114+
}
115+
116+
auto repFor = builder.create<scf::ForOp>(forOp.getLoc(), gutils->popCache(caches[0], builder), gutils->popCache(caches[1], builder), gutils->popCache(caches[2], builder), nArgs); // TODO
117+
repFor.getRegion().begin()->erase();
118+
119+
buildReturnFunction buildFuncRetrunOp = [](OpBuilder& builder, Location loc, SmallVector<Value> retargs){
120+
builder.create<scf::YieldOp>(loc, retargs);
121+
return ;
122+
};
123+
124+
gutils->Logic.differentiate(gutils, forOp.getRegion(), repFor.getRegion(), false, buildFuncRetrunOp);
125+
126+
// Insert the index which is carried by the scf for op.
127+
Type indexType = mlir::IndexType::get(gutils->initializationBlock->begin()->getContext());
128+
repFor.getRegion().insertArgument((unsigned)0, indexType, forOp.getLoc());
129+
130+
// TODO Can we do reverse iteration???
131+
}
132+
133+
SmallVector<Value> cacheValues(Operation *op, MGradientUtilsReverse *gutils) const {
134+
auto forOp = cast<scf::ForOp>(op);
135+
136+
Operation * newOp = gutils->getNewFromOriginal(op);
137+
OpBuilder cacheBuilder(newOp);
138+
SmallVector<Value> caches;
139+
140+
Value cacheLB = gutils->initAndPushCache(gutils->getNewFromOriginal(forOp.getLowerBound()), cacheBuilder);
141+
caches.push_back(cacheLB);
142+
143+
Value cacheUB = gutils->initAndPushCache(gutils->getNewFromOriginal(forOp.getUpperBound()), cacheBuilder);
144+
caches.push_back(cacheUB);
145+
146+
Value cacheStep = gutils->initAndPushCache(gutils->getNewFromOriginal(forOp.getStep()), cacheBuilder);
147+
caches.push_back(cacheStep);
148+
149+
return caches;
150+
}
151+
152+
void createShadowValues(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const {
153+
auto forOp = cast<scf::ForOp>(op);
154+
}
155+
};
156+
97157
} // namespace
98158

99159
void mlir::enzyme::registerSCFDialectAutoDiffInterface(
100160
DialectRegistry &registry) {
101161
registry.addExtension(+[](MLIRContext *context, scf::SCFDialect *) {
102162
scf::ForOp::attachInterface<ForOpInterface>(*context);
163+
164+
scf::ForOp::attachInterface<ForOpInterfaceReverse>(*context);
103165
});
104166
}

enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@
66
// TODO: no relative includes.
77
#include "../../EnzymeLogic.h"
88

9+
10+
911
namespace mlir {
1012
namespace enzyme {
1113

14+
typedef void (*buildReturnFunction) (OpBuilder&, Location, SmallVector<mlir::Value>);
15+
16+
class MGradientUtilsReverse;
1217

1318
class MFnTypeInfo {
1419
public:
@@ -100,6 +105,15 @@ class MEnzymeLogic {
100105
std::vector<bool> volatile_args, void *augmented);
101106

102107
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, DIFFE_TYPE retType, std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented, SymbolTableCollection &symbolTable);
108+
void initializeShadowValues(SmallVector<mlir::Block*>& dominatorToposortBlocks, MGradientUtilsReverse * gutils);
109+
void handlePredecessors(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils, void (*buildRetrunOp) (OpBuilder&, Location, SmallVector<mlir::Value>));
110+
void visitChildren(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils);
111+
void visitChild(Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils);
112+
bool visitChildCustom(Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils);
113+
void handleReturns(Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, bool parentRegion);
114+
void mapInvertArguments(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils);
115+
SmallVector<mlir::Block*> getDominatorToposort(MGradientUtilsReverse *gutils, Region& region);
116+
void differentiate(MGradientUtilsReverse * gutils, Region & oldRegion, Region & newRegion, bool parentRegion, buildReturnFunction buildFuncRetrunOp);
103117
};
104118

105119
} // Namespace enzyme

enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp

Lines changed: 72 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#include "Interfaces/GradientUtils.h"
2-
#include "Interfaces/GradientUtilsReverse.h"
31
#include "Dialect/Ops.h"
42
#include "Interfaces/AutoDiffOpInterface.h"
53
#include "Interfaces/AutoDiffTypeInterface.h"
@@ -17,21 +15,22 @@
1715
#include "llvm/ADT/BreadthFirstIterator.h"
1816
#include "mlir/IR/Dominance.h"
1917

20-
#include "GradientUtils.h"
2118
#include "EnzymeLogic.h"
19+
#include "Interfaces/GradientUtils.h"
20+
#include "Interfaces/GradientUtilsReverse.h"
2221

2322
using namespace mlir;
2423
using namespace mlir::enzyme;
2524

26-
SmallVector<mlir::Block*> getDominatorToposort(MGradientUtilsReverse *gutils){
25+
SmallVector<mlir::Block*> MEnzymeLogic::getDominatorToposort(MGradientUtilsReverse *gutils, Region& region){
2726
SmallVector<mlir::Block*> dominatorToposortBlocks;
28-
if (gutils->oldFunc.getFunctionBody().hasOneBlock()){
29-
dominatorToposortBlocks.push_back(&*(gutils->oldFunc.getFunctionBody().begin()));
27+
if (region.hasOneBlock()){
28+
dominatorToposortBlocks.push_back(&*(region.begin()));
3029
}
3130
else{
3231
auto dInfo = mlir::detail::DominanceInfoBase<false>(nullptr);
3332
llvm::DominatorTreeBase<Block, false> & dt = dInfo.getDomTree(&(gutils->oldFunc.getFunctionBody()));
34-
auto root = dt.getNode(&*(gutils->oldFunc.getFunctionBody().begin()));
33+
auto root = dt.getNode(&*(region.begin()));
3534

3635
for(llvm::DomTreeNodeBase<mlir::Block> * node : llvm::breadth_first(root)){
3736
dominatorToposortBlocks.push_back(node->getBlock());
@@ -41,28 +40,43 @@ SmallVector<mlir::Block*> getDominatorToposort(MGradientUtilsReverse *gutils){
4140
return dominatorToposortBlocks;
4241
}
4342

44-
void mapInvertArguments(Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
43+
void MEnzymeLogic::mapInvertArguments(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils){
4544
for (int i = 0; i < (int)gutils->mapBlockArguments[oBB].size(); i++){
4645
auto x = gutils->mapBlockArguments[oBB][i];
4746
OpBuilder builder(reverseBB, reverseBB->begin());
4847
gutils->mapInvertPointer(x.second, reverseBB->getArgument(i), builder);
4948
}
5049
}
5150

52-
void handleReturns(Block * oBB, Block * newBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
51+
void MEnzymeLogic::handleReturns(Block * oBB, Block * newBB, Block * reverseBB, MGradientUtilsReverse * gutils, bool parentRegion){
5352
if (oBB->getNumSuccessors() == 0){
54-
Operation * returnStatement = newBB->getTerminator();
55-
gutils->erase(returnStatement);
53+
if (parentRegion){
54+
Operation * returnStatement = newBB->getTerminator();
55+
gutils->erase(returnStatement);
5656

57-
OpBuilder forwardToBackwardBuilder(newBB, newBB->end());
58-
gutils->mapInvertPointer(oBB->getTerminator()->getOperand(0), gutils->newFunc.getArgument(gutils->newFunc.getNumArguments() - 1), forwardToBackwardBuilder); //TODO handle multiple return values
59-
Operation * newBranchOp = forwardToBackwardBuilder.create<cf::BranchOp>(oBB->getTerminator()->getLoc(), reverseBB);
60-
61-
gutils->originalToNewFnOps[oBB->getTerminator()] = newBranchOp;
57+
OpBuilder forwardToBackwardBuilder(newBB, newBB->end());
58+
gutils->mapInvertPointer(oBB->getTerminator()->getOperand(0), gutils->newFunc.getArgument(gutils->newFunc.getNumArguments() - 1), forwardToBackwardBuilder); //TODO handle multiple return values
59+
Operation * newBranchOp = forwardToBackwardBuilder.create<cf::BranchOp>(oBB->getTerminator()->getLoc(), reverseBB);
60+
61+
gutils->originalToNewFnOps[oBB->getTerminator()] = newBranchOp;
62+
}
63+
else{
64+
Operation * terminator = oBB->getTerminator();
65+
OpBuilder builder(reverseBB, reverseBB->begin());
66+
67+
int i = 0;
68+
for(OpOperand & operand : terminator->getOpOperands()){
69+
Value val = operand.get();
70+
if (auto iface = val.getType().dyn_cast<AutoDiffTypeInterface>()) {
71+
gutils->mapInvertPointer(val, reverseBB->getArgument(i), builder);
72+
i++;
73+
}
74+
}
75+
}
6276
}
6377
}
6478

65-
bool visitChildCustom(Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse * gutils){
79+
bool MEnzymeLogic::visitChildCustom(Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils){
6680
std::string nameDiffe = "diffe_" + op->getName().getDialectNamespace().str() + "_" + op->getName().stripDialect().str();
6781
std::string nameStore = "store_" + op->getName().getDialectNamespace().str() + "_" + op->getName().stripDialect().str();
6882

@@ -128,9 +142,9 @@ bool visitChildCustom(Operation * op, OpBuilder &builder, MDiffeGradientUtilsRev
128142
/*
129143
Create reverse mode adjoint for an operation.
130144
*/
131-
void visitChild(Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse * gutils){
145+
void MEnzymeLogic::visitChild(Operation * op, OpBuilder &builder, MGradientUtilsReverse * gutils){
132146
if (auto ifaceOp = dyn_cast<ReverseAutoDiffOpInterface>(op)) {
133-
ValueRange caches = ifaceOp.cacheValues(gutils);
147+
SmallVector<Value> caches = ifaceOp.cacheValues(gutils);
134148
ifaceOp.createReverseModeAdjoint(builder, gutils, caches);
135149

136150
for (int indexResult = 0; indexResult < (int)op->getNumResults(); indexResult++){
@@ -140,7 +154,7 @@ void visitChild(Operation * op, OpBuilder &builder, MDiffeGradientUtilsReverse *
140154
}
141155
}
142156

143-
void visitChildren(Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
157+
void MEnzymeLogic::visitChildren(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils){
144158
OpBuilder revBuilder(reverseBB, reverseBB->end());
145159
if (!oBB->empty()){
146160
auto first = oBB->rbegin();
@@ -155,21 +169,22 @@ void visitChildren(Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse *
155169
}
156170
}
157171

158-
void handlePredecessors(Block * oBB, Block * reverseBB, MDiffeGradientUtilsReverse * gutils){
172+
void MEnzymeLogic::handlePredecessors(Block * oBB, Block * reverseBB, MGradientUtilsReverse * gutils, buildReturnFunction buildRetrunOp){
159173
OpBuilder revBuilder(reverseBB, reverseBB->end());
160174
if (oBB->hasNoPredecessors()){
161-
SmallVector<mlir::Value, 2> retargs;
175+
SmallVector<mlir::Value> retargs;
162176
for (Value attribute : gutils->oldFunc.getFunctionBody().getArguments()) {
163177
Value attributeGradient = gutils->invertPointerM(attribute, revBuilder);
164178
retargs.push_back(attributeGradient);
165179
}
166-
revBuilder.create<func::ReturnOp>(oBB->rbegin()->getLoc(), retargs);
180+
buildRetrunOp(revBuilder, oBB->rbegin()->getLoc(), retargs);
181+
//revBuilder.create<func::ReturnOp>(oBB->rbegin()->getLoc(), retargs);
167182
}
168183
else {
169184
SmallVector<Block *> blocks;
170185
SmallVector<APInt> indices;
171186
SmallVector<ValueRange> arguments;
172-
ValueRange defaultArguments;
187+
SmallVector<Value> defaultArguments;
173188
Block * defaultBlock;
174189
int i = 1;
175190
for (Block * predecessor : oBB->getPredecessors()){
@@ -197,11 +212,11 @@ void handlePredecessors(Block * oBB, Block * reverseBB, MDiffeGradientUtilsRever
197212
if (predecessor != *(oBB->getPredecessors().begin())){
198213
blocks.push_back(predecessorRevMode);
199214
indices.push_back(APInt(32, i++));
200-
arguments.push_back(ValueRange(operands));
215+
arguments.push_back(operands);
201216
}
202217
else{
203218
defaultBlock = predecessorRevMode;
204-
defaultArguments = ValueRange(operands);
219+
defaultArguments = operands;
205220
}
206221
}
207222
//Remove Dependency to CF dialect
@@ -227,7 +242,7 @@ void handlePredecessors(Block * oBB, Block * reverseBB, MDiffeGradientUtilsRever
227242
}
228243
}
229244

230-
void initializeShadowValues(SmallVector<mlir::Block*>& dominatorToposortBlocks, MDiffeGradientUtilsReverse * gutils){
245+
void MEnzymeLogic::initializeShadowValues(SmallVector<mlir::Block*>& dominatorToposortBlocks, MGradientUtilsReverse * gutils){
231246
for (auto it = dominatorToposortBlocks.begin(); it != dominatorToposortBlocks.end(); ++it){
232247
Block * oBB = *it;
233248

@@ -245,19 +260,10 @@ void initializeShadowValues(SmallVector<mlir::Block*>& dominatorToposortBlocks,
245260
}
246261
}
247262

263+
void MEnzymeLogic::differentiate(MGradientUtilsReverse * gutils, Region & oldRegion, Region & newRegion, bool parentRegion, buildReturnFunction buildFuncRetrunOp){
264+
gutils->createReverseModeBlocks(oldRegion, newRegion, parentRegion);
248265

249-
FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateReverseDiff(FunctionOpInterface fn, DIFFE_TYPE retType, std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented, SymbolTableCollection &symbolTable) {
250-
251-
if (fn.getFunctionBody().empty()) {
252-
llvm::errs() << fn << "\n";
253-
llvm_unreachable("Differentiating empty function");
254-
}
255-
256-
ReturnType returnValue = ReturnType::Tape;
257-
MDiffeGradientUtilsReverse * gutils = MDiffeGradientUtilsReverse::CreateFromClone(*this, mode, width, fn, TA, type_args, retType, /*diffeReturnArg*/ true, constants, returnValue, addedType, symbolTable);
258-
259-
SmallVector<mlir::Block*> dominatorToposortBlocks = getDominatorToposort(gutils);
260-
266+
SmallVector<mlir::Block*> dominatorToposortBlocks = getDominatorToposort(gutils, oldRegion);
261267
initializeShadowValues(dominatorToposortBlocks, gutils);
262268

263269
for (auto it = dominatorToposortBlocks.rbegin(); it != dominatorToposortBlocks.rend(); ++it){
@@ -266,19 +272,38 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateReverseDiff(FunctionOpInte
266272
Block * reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(oBB);
267273

268274
mapInvertArguments(oBB, reverseBB, gutils);
269-
270-
handleReturns(oBB, newBB, reverseBB, gutils);
271-
275+
handleReturns(oBB, newBB, reverseBB, gutils, parentRegion);
272276
visitChildren(oBB, reverseBB, gutils);
273-
274-
handlePredecessors(oBB, reverseBB, gutils);
277+
handlePredecessors(oBB, reverseBB, gutils, buildFuncRetrunOp);
278+
}
279+
}
280+
281+
FunctionOpInterface MEnzymeLogic::CreateReverseDiff(FunctionOpInterface fn, DIFFE_TYPE retType, std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented, SymbolTableCollection &symbolTable) {
282+
283+
if (fn.getFunctionBody().empty()) {
284+
llvm::errs() << fn << "\n";
285+
llvm_unreachable("Differentiating empty function");
275286
}
276287

288+
ReturnType returnValue = ReturnType::Tape;
289+
MGradientUtilsReverse * gutils = MGradientUtilsReverse::CreateFromClone(*this, mode, width, fn, TA, type_args, retType, /*diffeReturnArg*/ true, constants, returnValue, addedType, symbolTable);
290+
291+
Region & oldRegion = gutils->oldFunc.getFunctionBody();
292+
Region & newRegion = gutils->newFunc.getFunctionBody();
293+
294+
buildReturnFunction buildFuncRetrunOp = [](OpBuilder& builder, Location loc, SmallVector<Value> retargs){
295+
builder.create<func::ReturnOp>(loc, retargs);
296+
return ;
297+
};
298+
299+
differentiate(gutils, oldRegion, newRegion, true, buildFuncRetrunOp);
300+
277301
auto nf = gutils->newFunc;
278302

279-
//llvm::errs() << "nf\n";
280-
//nf.dump();
281-
//llvm::errs() << "nf end\n";
303+
llvm::errs() << "nf\n";
304+
nf.dump();
305+
llvm::errs() << "nf end\n";
306+
282307
delete gutils;
283308
return nf;
284309
}

0 commit comments

Comments
 (0)