Skip to content

Commit ae6bee0

Browse files
committed
add more caching layers to speed compilation
1 parent a8bdf3b commit ae6bee0

File tree

7 files changed

+146
-72
lines changed

7 files changed

+146
-72
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,13 +1593,8 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
15931593
//}
15941594
Builder2.setFastMathFlags(getFast());
15951595

1596-
std::map<Value*,Value*> alreadyLoaded;
1597-
15981596
std::function<Value*(Value*)> lookup = [&](Value* val) -> Value* {
1599-
if (alreadyLoaded.find(val) != alreadyLoaded.end()) {
1600-
return alreadyLoaded[val];
1601-
}
1602-
return alreadyLoaded[val] = gutils->lookupM(val, Builder2);
1597+
return gutils->lookupM(val, Builder2);
16031598
};
16041599

16051600
auto diffe = [&Builder2,&gutils](Value* val) -> Value* {

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,7 @@ void removeRedundantIVs(const Loop* L, BasicBlock* Header, BasicBlock* Preheader
635635
assert(cmp->getOperand(0) == increment);
636636

637637
auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L);
638+
llvm::errs() << "coing to think about " << *cmp << "\n";
638639
if (cmp->isUnsigned() || (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv)) ) {
639640

640641
// valid replacements (since unsigned comparison and i starts at 0 counting up)
@@ -795,6 +796,13 @@ Value* GradientUtils::lookupM(Value* val, IRBuilder<>& BuilderM) {
795796
val = inst = fixLCSSA(inst, BuilderM);
796797

797798
assert(!this->isOriginalBlock(*BuilderM.GetInsertBlock()));
799+
800+
static std::map<std::pair<Value*, BasicBlock*>, Value*> cache;
801+
auto idx = std::make_pair(val, BuilderM.GetInsertBlock());
802+
if (cache.find(idx) != cache.end()) {
803+
return cache[idx];
804+
}
805+
798806
LoopContext lc;
799807
bool inLoop = getContext(inst->getParent(), lc);
800808

@@ -826,6 +834,7 @@ Value* GradientUtils::lookupM(Value* val, IRBuilder<>& BuilderM) {
826834
assert(scopeMap[inst]);
827835
Value* result = lookupValueFromCache(BuilderM, inst->getParent(), scopeMap[inst]);
828836
assert(result->getType() == inst->getType());
837+
cache[idx] = result;
829838
return result;
830839
}
831840

enzyme/Enzyme/GradientUtils.h

Lines changed: 120 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -717,63 +717,131 @@ class GradientUtils {
717717
}
718718

719719
Value* unwrapM(Value* val, IRBuilder<>& BuilderM, const ValueToValueMapTy& available, bool lookupIfAble) {
720-
assert(val);
720+
assert(val);
721+
722+
static std::map<std::pair<Value*, BasicBlock*>, Value*> cache;
723+
auto cidx = std::make_pair(val, BuilderM.GetInsertBlock());
724+
if (cache.find(cidx) != cache.end()) {
725+
return cache[cidx];
726+
}
727+
721728
if (available.count(val)) {
722729
return available.lookup(val);
723730
}
731+
732+
if (auto inst = dyn_cast<Instruction>(val)) {
733+
if (isOriginalBlock(*BuilderM.GetInsertBlock())) {
734+
if (BuilderM.GetInsertBlock()->size() && BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) {
735+
if (DT.dominates(inst, &*BuilderM.GetInsertPoint())) {
736+
//llvm::errs() << "allowed " << *inst << "from domination\n";
737+
return inst;
738+
}
739+
} else {
740+
if (DT.dominates(inst, BuilderM.GetInsertBlock())) {
741+
//llvm::errs() << "allowed " << *inst << "from block domination\n";
742+
return inst;
743+
}
744+
}
745+
}
746+
}
724747

725748
if (isa<Argument>(val) || isa<Constant>(val)) {
749+
cache[std::make_pair(val, BuilderM.GetInsertBlock())] = val;
726750
return val;
727751
} else if (isa<AllocaInst>(val)) {
752+
cache[std::make_pair(val, BuilderM.GetInsertBlock())] = val;
728753
return val;
729754
} else if (auto op = dyn_cast<CastInst>(val)) {
730755
auto op0 = unwrapM(op->getOperand(0), BuilderM, available, lookupIfAble);
731756
if (op0 == nullptr) goto endCheck;
732-
return BuilderM.CreateCast(op->getOpcode(), op0, op->getDestTy(), op->getName()+"_unwrap");
757+
auto toreturn = BuilderM.CreateCast(op->getOpcode(), op0, op->getDestTy(), op->getName()+"_unwrap");
758+
if (cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) {
759+
cache[cidx] = toreturn;
760+
}
761+
return toreturn;
733762
} else if (auto op = dyn_cast<ExtractValueInst>(val)) {
734763
auto op0 = unwrapM(op->getAggregateOperand(), BuilderM, available, lookupIfAble);
735764
if (op0 == nullptr) goto endCheck;
736-
return BuilderM.CreateExtractValue(op0, op->getIndices(), op->getName()+"_unwrap");
765+
auto toreturn = BuilderM.CreateExtractValue(op0, op->getIndices(), op->getName()+"_unwrap");
766+
if (cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) {
767+
cache[cidx] = toreturn;
768+
}
769+
return toreturn;
737770
} else if (auto op = dyn_cast<BinaryOperator>(val)) {
738771
auto op0 = unwrapM(op->getOperand(0), BuilderM, available, lookupIfAble);
739772
if (op0 == nullptr) goto endCheck;
740773
auto op1 = unwrapM(op->getOperand(1), BuilderM, available, lookupIfAble);
741774
if (op1 == nullptr) goto endCheck;
742-
return BuilderM.CreateBinOp(op->getOpcode(), op0, op1);
775+
auto toreturn = BuilderM.CreateBinOp(op->getOpcode(), op0, op1);
776+
if (
777+
(cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) &&
778+
(cache.find(std::make_pair((Value*)op->getOperand(1), BuilderM.GetInsertBlock())) != cache.end()) ) {
779+
cache[cidx] = toreturn;
780+
}
781+
return toreturn;
743782
} else if (auto op = dyn_cast<ICmpInst>(val)) {
744783
auto op0 = unwrapM(op->getOperand(0), BuilderM, available, lookupIfAble);
745784
if (op0 == nullptr) goto endCheck;
746785
auto op1 = unwrapM(op->getOperand(1), BuilderM, available, lookupIfAble);
747786
if (op1 == nullptr) goto endCheck;
748-
return BuilderM.CreateICmp(op->getPredicate(), op0, op1);
787+
auto toreturn = BuilderM.CreateICmp(op->getPredicate(), op0, op1);
788+
if (
789+
(cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) &&
790+
(cache.find(std::make_pair((Value*)op->getOperand(1), BuilderM.GetInsertBlock())) != cache.end()) ) {
791+
cache[cidx] = toreturn;
792+
}
793+
return toreturn;
749794
} else if (auto op = dyn_cast<FCmpInst>(val)) {
750795
auto op0 = unwrapM(op->getOperand(0), BuilderM, available, lookupIfAble);
751796
if (op0 == nullptr) goto endCheck;
752797
auto op1 = unwrapM(op->getOperand(1), BuilderM, available, lookupIfAble);
753798
if (op1 == nullptr) goto endCheck;
754-
return BuilderM.CreateFCmp(op->getPredicate(), op0, op1);
799+
auto toreturn = BuilderM.CreateFCmp(op->getPredicate(), op0, op1);
800+
if (
801+
(cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) &&
802+
(cache.find(std::make_pair((Value*)op->getOperand(1), BuilderM.GetInsertBlock())) != cache.end()) ) {
803+
cache[cidx] = toreturn;
804+
}
805+
return toreturn;
755806
} else if (auto op = dyn_cast<SelectInst>(val)) {
756807
auto op0 = unwrapM(op->getOperand(0), BuilderM, available, lookupIfAble);
757808
if (op0 == nullptr) goto endCheck;
758809
auto op1 = unwrapM(op->getOperand(1), BuilderM, available, lookupIfAble);
759810
if (op1 == nullptr) goto endCheck;
760811
auto op2 = unwrapM(op->getOperand(2), BuilderM, available, lookupIfAble);
761812
if (op2 == nullptr) goto endCheck;
762-
return BuilderM.CreateSelect(op0, op1, op2);
813+
auto toreturn = BuilderM.CreateSelect(op0, op1, op2);
814+
if (
815+
(cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) &&
816+
(cache.find(std::make_pair((Value*)op->getOperand(1), BuilderM.GetInsertBlock())) != cache.end()) &&
817+
(cache.find(std::make_pair((Value*)op->getOperand(2), BuilderM.GetInsertBlock())) != cache.end()) ) {
818+
cache[cidx] = toreturn;
819+
}
820+
return toreturn;
763821
} else if (auto inst = dyn_cast<GetElementPtrInst>(val)) {
764822
auto ptr = unwrapM(inst->getPointerOperand(), BuilderM, available, lookupIfAble);
765823
if (ptr == nullptr) goto endCheck;
824+
bool cached = cache.find(std::make_pair(inst->getPointerOperand(), BuilderM.GetInsertBlock())) != cache.end();
766825
SmallVector<Value*,4> ind;
767826
for(auto& a : inst->indices()) {
768827
auto op = unwrapM(a, BuilderM,available, lookupIfAble);
769828
if (op == nullptr) goto endCheck;
829+
cached &= cache.find(std::make_pair((Value*)a, BuilderM.GetInsertBlock())) != cache.end();
770830
ind.push_back(op);
771831
}
772-
return BuilderM.CreateGEP(ptr, ind);
832+
auto toreturn = BuilderM.CreateGEP(ptr, ind, inst->getName() + "_unwrap");
833+
if (cached) {
834+
cache[cidx] = toreturn;
835+
}
836+
return toreturn;
773837
} else if (auto load = dyn_cast<LoadInst>(val)) {
774838
Value* idx = unwrapM(load->getOperand(0), BuilderM, available, lookupIfAble);
775839
if (idx == nullptr) goto endCheck;
776-
return BuilderM.CreateLoad(idx);
840+
auto toreturn = BuilderM.CreateLoad(idx);
841+
if (cache.find(std::make_pair((Value*)load->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) {
842+
cache[cidx] = toreturn;
843+
}
844+
return toreturn;
777845
} else if (auto op = dyn_cast<IntrinsicInst>(val)) {
778846
switch(op->getIntrinsicID()) {
779847
case Intrinsic::sin: {
@@ -839,7 +907,6 @@ class GradientUtils {
839907
if (!inLoop) {
840908
return entryBuilder.CreateAlloca(T, nullptr, name+"_cache");
841909
} else {
842-
Value* size = nullptr;
843910

844911
BasicBlock* outermostPreheader = nullptr;
845912

@@ -853,38 +920,45 @@ class GradientUtils {
853920

854921
IRBuilder <> allocationBuilder(&outermostPreheader->back());
855922

856-
for(LoopContext idx = lc; ; getContext(idx.parent->getHeader(), idx) ) {
857-
//TODO handle allocations for dynamic loops
858-
if (idx.dynamic && idx.parent != nullptr) {
859-
assert(idx.var);
860-
assert(idx.var->getParent());
861-
assert(idx.var->getParent()->getParent());
862-
llvm::errs() << *idx.var->getParent()->getParent() << "\n"
863-
<< "idx.var=" <<*idx.var << "\n"
864-
<< "idx.limit=" <<*idx.limit << "\n";
865-
llvm::errs() << "cannot handle non-outermost dynamic loop\n";
866-
assert(0 && "cannot handle non-outermost dynamic loop");
867-
}
868-
Value* ns = nullptr;
869-
Type* intT = idx.dynamic ? cast<PointerType>(idx.limit->getType())->getElementType() : idx.limit->getType();
870-
if (idx.dynamic) {
871-
ns = ConstantInt::get(intT, 1);
872-
} else {
873-
Value* limitm1 = nullptr;
874-
ValueToValueMapTy emptyMap;
875-
limitm1 = unwrapM(idx.limit, allocationBuilder, emptyMap, /*lookupIfAble*/false);
876-
if (limitm1 == nullptr) {
877-
assert(outermostPreheader);
878-
assert(outermostPreheader->getParent());
879-
llvm::errs() << *outermostPreheader->getParent() << "\n";
880-
llvm::errs() << "needed value " << *idx.limit << " at " << allocationBuilder.GetInsertBlock()->getName() << "\n";
923+
Value* size = nullptr;
924+
static std::map<BasicBlock*, Value*> sizecache;
925+
if (sizecache.find(lc.header) != sizecache.end()) {
926+
size = sizecache[lc.header];
927+
} else {
928+
for(LoopContext idx = lc; ; getContext(idx.parent->getHeader(), idx) ) {
929+
//TODO handle allocations for dynamic loops
930+
if (idx.dynamic && idx.parent != nullptr) {
931+
assert(idx.var);
932+
assert(idx.var->getParent());
933+
assert(idx.var->getParent()->getParent());
934+
llvm::errs() << *idx.var->getParent()->getParent() << "\n"
935+
<< "idx.var=" <<*idx.var << "\n"
936+
<< "idx.limit=" <<*idx.limit << "\n";
937+
llvm::errs() << "cannot handle non-outermost dynamic loop\n";
938+
assert(0 && "cannot handle non-outermost dynamic loop");
939+
}
940+
Value* ns = nullptr;
941+
Type* intT = idx.dynamic ? cast<PointerType>(idx.limit->getType())->getElementType() : idx.limit->getType();
942+
if (idx.dynamic) {
943+
ns = ConstantInt::get(intT, 1);
944+
} else {
945+
Value* limitm1 = nullptr;
946+
ValueToValueMapTy emptyMap;
947+
limitm1 = unwrapM(idx.limit, allocationBuilder, emptyMap, /*lookupIfAble*/false);
948+
if (limitm1 == nullptr) {
949+
assert(outermostPreheader);
950+
assert(outermostPreheader->getParent());
951+
llvm::errs() << *outermostPreheader->getParent() << "\n";
952+
llvm::errs() << "needed value " << *idx.limit << " at " << allocationBuilder.GetInsertBlock()->getName() << "\n";
953+
}
954+
assert(limitm1);
955+
ns = allocationBuilder.CreateNUWAdd(limitm1, ConstantInt::get(intT, 1));
956+
}
957+
if (size == nullptr) size = ns;
958+
else size = allocationBuilder.CreateNUWMul(size, ns);
959+
if (idx.parent == nullptr) break;
881960
}
882-
assert(limitm1);
883-
ns = allocationBuilder.CreateNUWAdd(limitm1, ConstantInt::get(intT, 1));
884-
}
885-
if (size == nullptr) size = ns;
886-
else size = allocationBuilder.CreateNUWMul(size, ns);
887-
if (idx.parent == nullptr) break;
961+
sizecache[lc.header] = size;
888962
}
889963

890964
auto firstallocation = CallInst::CreateMalloc(
@@ -955,6 +1029,7 @@ class GradientUtils {
9551029
limits.push_back(lim);
9561030
}
9571031

1032+
/*
9581033
Value* idx = nullptr;
9591034
for(unsigned i=0; i<indices.size(); i++) {
9601035
if (i == 0) {
@@ -963,20 +1038,18 @@ class GradientUtils {
9631038
auto mul = v.CreateNUWMul(indices[i], limits[i-1]);
9641039
idx = v.CreateNUWAdd(idx, mul);
9651040
}
966-
}
1041+
}*/
9671042

9681043
if (dynamicPHI != nullptr) {
9691044
Type *BPTy = Type::getInt8PtrTy(v.GetInsertBlock()->getContext());
9701045
auto realloc = newFunc->getParent()->getOrInsertFunction("realloc", BPTy, BPTy, size->getType());
9711046
Value* allocation = v.CreateLoad(holderAlloc);
972-
auto foo = v.CreateNUWAdd(dynamicPHI, ConstantInt::get(dynamicPHI->getType(), 1));
1047+
Value* foo = v.CreateNUWAdd(dynamicPHI, ConstantInt::get(dynamicPHI->getType(), 1));
1048+
Value* realloc_size = v.CreateNUWMul(size, foo);
9731049
Value* idxs[2] = {
9741050
v.CreatePointerCast(allocation, BPTy),
9751051
v.CreateNUWMul(
976-
ConstantInt::get(size->getType(), newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(T)/8),
977-
v.CreateNUWMul(
978-
size, foo
979-
)
1052+
ConstantInt::get(size->getType(), newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(T)/8), realloc_size
9801053
)
9811054
};
9821055

enzyme/test/Enzyme/cppllist.ll

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,14 @@ attributes #8 = { builtin nounwind }
206206
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %n, %[[invertdelete]] ], [ %[[isub:.+]], %invertfor.body.i ]
207207
; CHECK-NEXT: %[[isub]] = add i64 %[[antivar]], -1
208208
; CHECK-NEXT: %[[gepiv:.+]] = getelementptr i8*, i8** %"call'mi_malloccache.i", i64 %[[antivar]]
209-
; CHECK-NEXT: %[[bcast:.+]] = bitcast i8** %[[gepiv]] to double**
210-
; CHECK-NEXT: %[[metaload:.+]] = load double*, double** %[[bcast]]
211-
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[metaload]]
209+
; CHECK-NEXT: %[[metaload:.+]] = load i8*, i8** %[[gepiv]]
210+
; CHECK-NEXT: %[[bcast:.+]] = bitcast i8* %[[metaload]] to double*
211+
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[bcast]]
212212
; this store is optional and could get removed by DCE
213-
; CHECK-NEXT: store double 0.000000e+00, double* %[[metaload]]
213+
; CHECK-NEXT: store double 0.000000e+00, double* %[[bcast]]
214214
; CHECK-NEXT: %[[xadd]] = fadd fast double %"x'de.0.i", %[[load]]
215215
; this reload really should be eliminated
216-
; CHECK-NEXT: %[[recallpload2free:.+]] = load i8*, i8** %[[gepiv]]
217-
; CHECK-NEXT: call void @_ZdlPv(i8* nonnull %[[recallpload2free]]) #5
216+
; CHECK-NEXT: call void @_ZdlPv(i8* nonnull %[[metaload]]) #5
218217
; CHECK-NEXT: %[[heregep:.+]] = getelementptr i8*, i8** %call_malloccache.i, i64 %[[antivar]]
219218
; CHECK-NEXT: %[[callload2free:.+]] = load i8*, i8** %[[heregep]]
220219
; CHECK-NEXT: call void @_ZdlPv(i8* %[[callload2free]]) #5

enzyme/test/Enzyme/initializemany.ll

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,12 @@ attributes #4 = { nounwind }
181181
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %wide.trip.count, %entry ], [ %[[sub:.+]], %invertfor.body ]
182182
; CHECK-NEXT: %[[sub]] = add i64 %[[antivar]], -1
183183
; CHECK-NEXT: %[[geper:.+]] = getelementptr i8*, i8** %0, i64 %[[sub]]
184-
; CHECK-NEXT: %[[bc:.+]] = bitcast i8** %[[geper]] to double**
185-
; CHECK-NEXT: %[[metaload:.+]] = load double*, double** %[[bc]], align 8
186-
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[metaload]], align 8
187-
; CHECK-NEXT: store double 0.000000e+00, double* %[[metaload]], align 8
184+
; CHECK-NEXT: %[[metaload:.+]] = load i8*, i8** %[[geper]], align 8
185+
; CHECK-NEXT: %[[bc:.+]] = bitcast i8* %[[metaload]] to double*
186+
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[bc]], align 8
187+
; CHECK-NEXT: store double 0.000000e+00, double* %[[bc]], align 8
188188
; CHECK-NEXT: %[[added]] = fadd fast double %"x'de.0", %[[load]]
189-
; CHECK-NEXT: %[[tofree:.+]] = load i8*, i8** %[[geper]], align 8
190-
; CHECK-NEXT: tail call void @free(i8* nonnull %[[tofree]])
189+
; CHECK-NEXT: tail call void @free(i8* nonnull %[[metaload]])
191190
; CHECK-NEXT: %[[lcmp:.+]] = icmp eq i64 %[[sub]], 0
192191
; CHECK-NEXT: br i1 %[[lcmp]], label %invertentry, label %invertfor.body
193192
; CHECK-NEXT: }

enzyme/test/Enzyme/llist.ll

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,13 @@ attributes #4 = { nounwind }
126126
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %n, %invertfor.cond.cleanup.i ], [ %[[sub:.+]], %invertfor.body.i ]
127127
; CHECK-NEXT: %[[sub]] = add i64 %[[antivar]], -1
128128
; CHECK-NEXT: %[[gep:.+]] = getelementptr i8*, i8** %"call'mi_malloccache.i", i64 %[[antivar]]
129-
; CHECK-NEXT: %[[ccast:.+]] = bitcast i8** %[[gep]] to double**
130-
; CHECK-NEXT: %[[loadcache:.+]] = load double*, double** %[[ccast]]
131-
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[loadcache]]
129+
; CHECK-NEXT: %[[loadcache:.+]] = load i8*, i8** %[[gep]]
130+
; CHECK-NEXT: %[[ccast:.+]] = bitcast i8* %[[loadcache]] to double*
131+
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[ccast]]
132132
; this store is optional and could get removed by DCE
133-
; CHECK-NEXT: store double 0.000000e+00, double* %[[loadcache]]
133+
; CHECK-NEXT: store double 0.000000e+00, double* %[[ccast]]
134134
; CHECK-NEXT: %[[add]] = fadd fast double %"x'de.0.i", %[[load]]
135-
; CHECK-NEXT: %[[prefree2:.+]] = load i8*, i8** %[[gep]]
136-
; CHECK-NEXT: call void @free(i8* nonnull %[[prefree2]]) #4
135+
; CHECK-NEXT: call void @free(i8* nonnull %[[loadcache]]) #4
137136
; CHECK-NEXT: %[[gepcall:.+]] = getelementptr i8*, i8** %call_malloccache.i, i64 %[[antivar]]
138137
; CHECK-NEXT: %[[loadprefree:.+]] = load i8*, i8** %[[gepcall]]
139138
; CHECK-NEXT: call void @free(i8* %[[loadprefree]]) #4

0 commit comments

Comments
 (0)