Skip to content

Commit d5510e3

Browse files
authored
Make single version of primal stack lowering handler (rust-lang#915)
1 parent 235fc1d commit d5510e3

File tree

1 file changed

+88
-268
lines changed

1 file changed

+88
-268
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 88 additions & 268 deletions
Original file line number | Diff line number | Diff line change
@@ -10739,115 +10739,110 @@ class AdjointGenerator
1073910739
}
1074010740
}
1074110741

10742-
// Don't erase any store that needs to be preserved for a
10743-
// rematerialization
10744-
{
10745-
auto found = gutils->rematerializableAllocations.find(orig);
10746-
if (found != gutils->rematerializableAllocations.end()) {
10747-
// If rematerializing (e.g. needed in reverse, but not needing
10748-
// the whole allocation):
10749-
if (primalNeededInReverse && !cacheWholeAllocation) {
10750-
// if rematerialize, don't ever cache and downgrade to stack
10751-
// allocation where possible.
10752-
if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
10753-
if (Mode == DerivativeMode::ReverseModeGradient &&
10754-
found->second.LI) {
10755-
gutils->rematerializedPrimalOrShadowAllocations.push_back(
10756-
newCall);
10757-
} else {
10758-
IRBuilder<> B(newCall);
10759-
10760-
Value *Size;
10761-
if (funcName == "malloc")
10762-
Size = orig->getArgOperand(0);
10763-
else if (funcName == "julia.gc_alloc_obj" ||
10764-
funcName == "jl_gc_alloc_typed" ||
10765-
funcName == "ijl_gc_alloc_typed")
10766-
Size = orig->getArgOperand(1);
10767-
else
10768-
llvm_unreachable("Unknown allocation to upgrade");
10769-
Size = gutils->getNewFromOriginal(Size);
10770-
10771-
if (auto CI = dyn_cast<ConstantInt>(Size)) {
10772-
B.SetInsertPoint(gutils->inversionAllocs);
10773-
}
10742+
std::function<void(MDNode *)> restoreFromStack = [&](MDNode *MD) {
10743+
IRBuilder<> B(newCall);
10744+
Value *Size;
10745+
if (funcName == "malloc")
10746+
Size = orig->getArgOperand(0);
10747+
else if (funcName == "julia.gc_alloc_obj" ||
10748+
funcName == "jl_gc_alloc_typed" ||
10749+
funcName == "ijl_gc_alloc_typed")
10750+
Size = orig->getArgOperand(1);
10751+
else
10752+
llvm_unreachable("Unknown allocation to upgrade");
10753+
Size = gutils->getNewFromOriginal(Size);
1077410754

10775-
Type *elTy = Type::getInt8Ty(orig->getContext());
10776-
Instruction *I = nullptr;
10755+
if (auto CI = dyn_cast<ConstantInt>(Size)) {
10756+
B.SetInsertPoint(gutils->inversionAllocs);
10757+
}
10758+
Type *elTy = Type::getInt8Ty(orig->getContext());
10759+
Instruction *I = nullptr;
1077710760
#if LLVM_VERSION_MAJOR >= 15
10778-
if (orig->getContext().supportsTypedPointers()) {
10779-
#endif
10780-
for (auto U : orig->users()) {
10781-
if (hasMetadata(cast<Instruction>(U), "enzyme_caststack")) {
10782-
elTy = U->getType()->getPointerElementType();
10783-
Value *tsize = ConstantInt::get(
10784-
Size->getType(), (gutils->newFunc->getParent()
10785-
->getDataLayout()
10786-
.getTypeAllocSizeInBits(elTy) +
10787-
7) /
10788-
8);
10789-
Size = B.CreateUDiv(Size, tsize, "", /*exact*/ true);
10790-
I = gutils->getNewFromOriginal(cast<Instruction>(U));
10791-
break;
10792-
}
10793-
}
10761+
if (orig->getContext().supportsTypedPointers()) {
10762+
#endif
10763+
for (auto U : orig->users()) {
10764+
if (hasMetadata(cast<Instruction>(U), "enzyme_caststack")) {
10765+
elTy = U->getType()->getPointerElementType();
10766+
Value *tsize = ConstantInt::get(
10767+
Size->getType(), (gutils->newFunc->getParent()
10768+
->getDataLayout()
10769+
.getTypeAllocSizeInBits(elTy) +
10770+
7) /
10771+
8);
10772+
Size = B.CreateUDiv(Size, tsize, "", /*exact*/ true);
10773+
I = gutils->getNewFromOriginal(cast<Instruction>(U));
10774+
break;
10775+
}
10776+
}
1079410777
#if LLVM_VERSION_MAJOR >= 15
10795-
}
10778+
}
1079610779
#endif
10797-
10798-
Value *replacement = B.CreateAlloca(elTy, Size);
10799-
if (I)
10800-
replacement->takeName(I);
10801-
else
10802-
replacement->takeName(newCall);
10803-
10804-
auto Alignment =
10805-
cast<ConstantInt>(
10806-
cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10807-
->getLimitedValue();
10808-
// Don't set zero alignment
10809-
if (Alignment) {
10780+
Value *replacement = B.CreateAlloca(elTy, Size);
10781+
if (I)
10782+
replacement->takeName(I);
10783+
else
10784+
replacement->takeName(newCall);
10785+
auto Alignment =
10786+
cast<ConstantInt>(
10787+
cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10788+
->getLimitedValue();
10789+
// Don't set zero alignment
10790+
if (Alignment) {
1081010791
#if LLVM_VERSION_MAJOR >= 10
10811-
cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
10792+
cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
1081210793
#else
10813-
cast<AllocaInst>(replacement)->setAlignment(Alignment);
10794+
cast<AllocaInst>(replacement)->setAlignment(Alignment);
1081410795
#endif
10815-
}
10796+
}
1081610797
#if LLVM_VERSION_MAJOR >= 15
10817-
if (orig->getContext().supportsTypedPointers()) {
10798+
if (orig->getContext().supportsTypedPointers()) {
1081810799
#endif
10819-
if (orig->getType()->getPointerElementType() != elTy)
10820-
replacement = B.CreatePointerCast(
10821-
replacement,
10822-
PointerType::getUnqual(
10823-
orig->getType()->getPointerElementType()));
10800+
if (orig->getType()->getPointerElementType() != elTy)
10801+
replacement = B.CreatePointerCast(
10802+
replacement, PointerType::getUnqual(
10803+
orig->getType()->getPointerElementType()));
1082410804

1082510805
#if LLVM_VERSION_MAJOR >= 15
10826-
}
10806+
}
1082710807
#endif
10808+
if (int AS = cast<PointerType>(orig->getType())->getAddressSpace()) {
1082810809

10829-
if (int AS =
10830-
cast<PointerType>(orig->getType())->getAddressSpace()) {
10831-
10832-
llvm::PointerType *PT;
10810+
llvm::PointerType *PT;
1083310811
#if LLVM_VERSION_MAJOR >= 15
10834-
if (orig->getContext().supportsTypedPointers()) {
10812+
if (orig->getContext().supportsTypedPointers()) {
1083510813
#endif
10836-
PT = PointerType::get(
10837-
orig->getType()->getPointerElementType(), AS);
10814+
PT = PointerType::get(orig->getType()->getPointerElementType(), AS);
1083810815
#if LLVM_VERSION_MAJOR >= 15
10839-
} else {
10840-
PT = PointerType::get(orig->getContext(), AS);
10841-
}
10816+
} else {
10817+
PT = PointerType::get(orig->getContext(), AS);
10818+
}
1084210819
#endif
10843-
replacement = B.CreateAddrSpaceCast(replacement, PT);
10844-
cast<Instruction>(replacement)
10845-
->setMetadata("enzyme_backstack",
10846-
MDNode::get(replacement->getContext(), {}));
10847-
}
10820+
replacement = B.CreateAddrSpaceCast(replacement, PT);
10821+
cast<Instruction>(replacement)
10822+
->setMetadata("enzyme_backstack",
10823+
MDNode::get(replacement->getContext(), {}));
10824+
}
10825+
gutils->replaceAWithB(newCall, replacement);
10826+
gutils->erase(newCall);
10827+
};
1084810828

10849-
gutils->replaceAWithB(newCall, replacement);
10850-
gutils->erase(newCall);
10829+
// Don't erase any store that needs to be preserved for a
10830+
// rematerialization
10831+
{
10832+
auto found = gutils->rematerializableAllocations.find(orig);
10833+
if (found != gutils->rematerializableAllocations.end()) {
10834+
// If rematerializing (e.g. needed in reverse, but not needing
10835+
// the whole allocation):
10836+
if (primalNeededInReverse && !cacheWholeAllocation) {
10837+
// if rematerialize, don't ever cache and downgrade to stack
10838+
// allocation where possible.
10839+
if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
10840+
if (Mode == DerivativeMode::ReverseModeGradient &&
10841+
found->second.LI) {
10842+
gutils->rematerializedPrimalOrShadowAllocations.push_back(
10843+
newCall);
10844+
} else {
10845+
restoreFromStack(MD);
1085110846
}
1085210847
return;
1085310848
}
@@ -10896,97 +10891,7 @@ class AdjointGenerator
1089610891
if (Mode == DerivativeMode::ReverseModeGradient)
1089710892
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
1089810893
else if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
10899-
IRBuilder<> B(newCall);
10900-
10901-
Value *Size;
10902-
if (funcName == "malloc")
10903-
Size = orig->getArgOperand(0);
10904-
else if (funcName == "julia.gc_alloc_obj" ||
10905-
funcName == "jl_gc_alloc_typed" ||
10906-
funcName == "ijl_gc_alloc_typed")
10907-
Size = orig->getArgOperand(1);
10908-
else
10909-
llvm_unreachable("Unknown allocation to upgrade");
10910-
Size = gutils->getNewFromOriginal(Size);
10911-
10912-
if (auto CI = dyn_cast<ConstantInt>(Size)) {
10913-
B.SetInsertPoint(gutils->inversionAllocs);
10914-
}
10915-
10916-
Type *elTy = Type::getInt8Ty(orig->getContext());
10917-
Instruction *I = nullptr;
10918-
#if LLVM_VERSION_MAJOR >= 15
10919-
if (orig->getContext().supportsTypedPointers()) {
10920-
#endif
10921-
for (auto U : orig->users()) {
10922-
if (hasMetadata(cast<Instruction>(U), "enzyme_caststack")) {
10923-
elTy = U->getType()->getPointerElementType();
10924-
Value *tsize = ConstantInt::get(
10925-
Size->getType(), (gutils->newFunc->getParent()
10926-
->getDataLayout()
10927-
.getTypeAllocSizeInBits(elTy) +
10928-
7) /
10929-
8);
10930-
Size = B.CreateUDiv(Size, tsize, "", /*exact*/ true);
10931-
I = gutils->getNewFromOriginal(cast<Instruction>(U));
10932-
break;
10933-
}
10934-
}
10935-
#if LLVM_VERSION_MAJOR >= 15
10936-
}
10937-
#endif
10938-
10939-
Value *replacement = B.CreateAlloca(elTy, Size);
10940-
if (I)
10941-
replacement->takeName(I);
10942-
else
10943-
replacement->takeName(newCall);
10944-
auto Alignment =
10945-
cast<ConstantInt>(
10946-
cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10947-
->getLimitedValue();
10948-
// Don't set zero alignment
10949-
if (Alignment) {
10950-
#if LLVM_VERSION_MAJOR >= 10
10951-
cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
10952-
#else
10953-
cast<AllocaInst>(replacement)->setAlignment(Alignment);
10954-
#endif
10955-
}
10956-
10957-
#if LLVM_VERSION_MAJOR >= 15
10958-
if (orig->getContext().supportsTypedPointers()) {
10959-
#endif
10960-
if (orig->getType()->getPointerElementType() != elTy)
10961-
replacement = B.CreatePointerCast(
10962-
replacement,
10963-
PointerType::getUnqual(
10964-
orig->getType()->getPointerElementType()));
10965-
10966-
#if LLVM_VERSION_MAJOR >= 15
10967-
}
10968-
#endif
10969-
if (int AS =
10970-
cast<PointerType>(orig->getType())->getAddressSpace()) {
10971-
llvm::PointerType *PT;
10972-
#if LLVM_VERSION_MAJOR >= 15
10973-
if (orig->getContext().supportsTypedPointers()) {
10974-
#endif
10975-
PT = PointerType::get(
10976-
orig->getType()->getPointerElementType(), AS);
10977-
#if LLVM_VERSION_MAJOR >= 15
10978-
} else {
10979-
PT = PointerType::get(orig->getContext(), AS);
10980-
}
10981-
#endif
10982-
replacement = B.CreateAddrSpaceCast(replacement, PT);
10983-
cast<Instruction>(replacement)
10984-
->setMetadata("enzyme_backstack",
10985-
MDNode::get(replacement->getContext(), {}));
10986-
}
10987-
10988-
gutils->replaceAWithB(newCall, replacement);
10989-
gutils->erase(newCall);
10894+
restoreFromStack(MD);
1099010895
}
1099110896
return;
1099210897
}
@@ -11004,92 +10909,7 @@ class AdjointGenerator
1100410909
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
1100510910
} else {
1100610911
if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
11007-
IRBuilder<> B(newCall);
11008-
Value *Size;
11009-
if (funcName == "malloc")
11010-
Size = orig->getArgOperand(0);
11011-
else if (funcName == "julia.gc_alloc_obj" ||
11012-
funcName == "jl_gc_alloc_typed" ||
11013-
funcName == "ijl_gc_alloc_typed")
11014-
Size = orig->getArgOperand(1);
11015-
else
11016-
llvm_unreachable("Unknown allocation to upgrade");
11017-
Size = gutils->getNewFromOriginal(Size);
11018-
11019-
if (auto CI = dyn_cast<ConstantInt>(Size)) {
11020-
B.SetInsertPoint(gutils->inversionAllocs);
11021-
}
11022-
Type *elTy = Type::getInt8Ty(orig->getContext());
11023-
Instruction *I = nullptr;
11024-
#if LLVM_VERSION_MAJOR >= 15
11025-
if (orig->getContext().supportsTypedPointers()) {
11026-
#endif
11027-
for (auto U : orig->users()) {
11028-
if (hasMetadata(cast<Instruction>(U), "enzyme_caststack")) {
11029-
elTy = U->getType()->getPointerElementType();
11030-
Value *tsize = ConstantInt::get(
11031-
Size->getType(), (gutils->newFunc->getParent()
11032-
->getDataLayout()
11033-
.getTypeAllocSizeInBits(elTy) +
11034-
7) /
11035-
8);
11036-
Size = B.CreateUDiv(Size, tsize, "", /*exact*/ true);
11037-
I = gutils->getNewFromOriginal(cast<Instruction>(U));
11038-
break;
11039-
}
11040-
}
11041-
#if LLVM_VERSION_MAJOR >= 15
11042-
}
11043-
#endif
11044-
Value *replacement = B.CreateAlloca(elTy, Size);
11045-
if (I)
11046-
replacement->takeName(I);
11047-
else
11048-
replacement->takeName(newCall);
11049-
auto Alignment =
11050-
cast<ConstantInt>(
11051-
cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
11052-
->getLimitedValue();
11053-
// Don't set zero alignment
11054-
if (Alignment) {
11055-
#if LLVM_VERSION_MAJOR >= 10
11056-
cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
11057-
#else
11058-
cast<AllocaInst>(replacement)->setAlignment(Alignment);
11059-
#endif
11060-
}
11061-
#if LLVM_VERSION_MAJOR >= 15
11062-
if (orig->getContext().supportsTypedPointers()) {
11063-
#endif
11064-
if (orig->getType()->getPointerElementType() != elTy)
11065-
replacement = B.CreatePointerCast(
11066-
replacement, PointerType::getUnqual(
11067-
orig->getType()->getPointerElementType()));
11068-
11069-
#if LLVM_VERSION_MAJOR >= 15
11070-
}
11071-
#endif
11072-
if (int AS =
11073-
cast<PointerType>(orig->getType())->getAddressSpace()) {
11074-
11075-
llvm::PointerType *PT;
11076-
#if LLVM_VERSION_MAJOR >= 15
11077-
if (orig->getContext().supportsTypedPointers()) {
11078-
#endif
11079-
PT = PointerType::get(orig->getType()->getPointerElementType(),
11080-
AS);
11081-
#if LLVM_VERSION_MAJOR >= 15
11082-
} else {
11083-
PT = PointerType::get(orig->getContext(), AS);
11084-
}
11085-
#endif
11086-
replacement = B.CreateAddrSpaceCast(replacement, PT);
11087-
cast<Instruction>(replacement)
11088-
->setMetadata("enzyme_backstack",
11089-
MDNode::get(replacement->getContext(), {}));
11090-
}
11091-
gutils->replaceAWithB(newCall, replacement);
11092-
gutils->erase(newCall);
10912+
restoreFromStack(MD);
1109310913
}
1109410914
}
1109510915
return;

0 commit comments

Comments
 (0)