@@ -10739,115 +10739,110 @@ class AdjointGenerator
10739
10739
}
10740
10740
}
10741
10741
10742
- // Don't erase any store that needs to be preserved for a
10743
- // rematerialization
10744
- {
10745
- auto found = gutils->rematerializableAllocations.find(orig);
10746
- if (found != gutils->rematerializableAllocations.end()) {
10747
- // If rematerializing (e.g. needed in reverse, but not needing
10748
- // the whole allocation):
10749
- if (primalNeededInReverse && !cacheWholeAllocation) {
10750
- // if rematerialize, don't ever cache and downgrade to stack
10751
- // allocation where possible.
10752
- if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
10753
- if (Mode == DerivativeMode::ReverseModeGradient &&
10754
- found->second.LI) {
10755
- gutils->rematerializedPrimalOrShadowAllocations.push_back(
10756
- newCall);
10757
- } else {
10758
- IRBuilder<> B(newCall);
10759
-
10760
- Value *Size;
10761
- if (funcName == "malloc")
10762
- Size = orig->getArgOperand(0);
10763
- else if (funcName == "julia.gc_alloc_obj" ||
10764
- funcName == "jl_gc_alloc_typed" ||
10765
- funcName == "ijl_gc_alloc_typed")
10766
- Size = orig->getArgOperand(1);
10767
- else
10768
- llvm_unreachable("Unknown allocation to upgrade");
10769
- Size = gutils->getNewFromOriginal(Size);
10770
-
10771
- if (auto CI = dyn_cast<ConstantInt>(Size)) {
10772
- B.SetInsertPoint(gutils->inversionAllocs);
10773
- }
10742
+ std::function<void(MDNode *)> restoreFromStack = [&](MDNode *MD) {
10743
+ IRBuilder<> B(newCall);
10744
+ Value *Size;
10745
+ if (funcName == "malloc")
10746
+ Size = orig->getArgOperand(0);
10747
+ else if (funcName == "julia.gc_alloc_obj" ||
10748
+ funcName == "jl_gc_alloc_typed" ||
10749
+ funcName == "ijl_gc_alloc_typed")
10750
+ Size = orig->getArgOperand(1);
10751
+ else
10752
+ llvm_unreachable("Unknown allocation to upgrade");
10753
+ Size = gutils->getNewFromOriginal(Size);
10774
10754
10775
- Type *elTy = Type::getInt8Ty(orig->getContext());
10776
- Instruction *I = nullptr;
10755
+ if (auto CI = dyn_cast<ConstantInt>(Size)) {
10756
+ B.SetInsertPoint(gutils->inversionAllocs);
10757
+ }
10758
+ Type *elTy = Type::getInt8Ty(orig->getContext());
10759
+ Instruction *I = nullptr;
10777
10760
#if LLVM_VERSION_MAJOR >= 15
10778
- if (orig->getContext().supportsTypedPointers()) {
10779
- #endif
10780
- for (auto U : orig->users()) {
10781
- if (hasMetadata(cast<Instruction>(U), "enzyme_caststack")) {
10782
- elTy = U->getType()->getPointerElementType();
10783
- Value *tsize = ConstantInt::get(
10784
- Size->getType(), (gutils->newFunc->getParent()
10785
- ->getDataLayout()
10786
- .getTypeAllocSizeInBits(elTy) +
10787
- 7) /
10788
- 8);
10789
- Size = B.CreateUDiv(Size, tsize, "", /*exact*/ true);
10790
- I = gutils->getNewFromOriginal(cast<Instruction>(U));
10791
- break;
10792
- }
10793
- }
10761
+ if (orig->getContext().supportsTypedPointers()) {
10762
+ #endif
10763
+ for (auto U : orig->users()) {
10764
+ if (hasMetadata(cast<Instruction>(U), "enzyme_caststack")) {
10765
+ elTy = U->getType()->getPointerElementType();
10766
+ Value *tsize = ConstantInt::get(
10767
+ Size->getType(), (gutils->newFunc->getParent()
10768
+ ->getDataLayout()
10769
+ .getTypeAllocSizeInBits(elTy) +
10770
+ 7) /
10771
+ 8);
10772
+ Size = B.CreateUDiv(Size, tsize, "", /*exact*/ true);
10773
+ I = gutils->getNewFromOriginal(cast<Instruction>(U));
10774
+ break;
10775
+ }
10776
+ }
10794
10777
#if LLVM_VERSION_MAJOR >= 15
10795
- }
10778
+ }
10796
10779
#endif
10797
-
10798
- Value *replacement = B.CreateAlloca(elTy, Size);
10799
- if (I)
10800
- replacement->takeName(I);
10801
- else
10802
- replacement->takeName(newCall);
10803
-
10804
- auto Alignment =
10805
- cast<ConstantInt>(
10806
- cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10807
- ->getLimitedValue();
10808
- // Don't set zero alignment
10809
- if (Alignment) {
10780
+ Value *replacement = B.CreateAlloca(elTy, Size);
10781
+ if (I)
10782
+ replacement->takeName(I);
10783
+ else
10784
+ replacement->takeName(newCall);
10785
+ auto Alignment =
10786
+ cast<ConstantInt>(
10787
+ cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10788
+ ->getLimitedValue();
10789
+ // Don't set zero alignment
10790
+ if (Alignment) {
10810
10791
#if LLVM_VERSION_MAJOR >= 10
10811
- cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
10792
+ cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
10812
10793
#else
10813
- cast<AllocaInst>(replacement)->setAlignment(Alignment);
10794
+ cast<AllocaInst>(replacement)->setAlignment(Alignment);
10814
10795
#endif
10815
- }
10796
+ }
10816
10797
#if LLVM_VERSION_MAJOR >= 15
10817
- if (orig->getContext().supportsTypedPointers()) {
10798
+ if (orig->getContext().supportsTypedPointers()) {
10818
10799
#endif
10819
- if (orig->getType()->getPointerElementType() != elTy)
10820
- replacement = B.CreatePointerCast(
10821
- replacement,
10822
- PointerType::getUnqual(
10823
- orig->getType()->getPointerElementType()));
10800
+ if (orig->getType()->getPointerElementType() != elTy)
10801
+ replacement = B.CreatePointerCast(
10802
+ replacement, PointerType::getUnqual(
10803
+ orig->getType()->getPointerElementType()));
10824
10804
10825
10805
#if LLVM_VERSION_MAJOR >= 15
10826
- }
10806
+ }
10827
10807
#endif
10808
+ if (int AS = cast<PointerType>(orig->getType())->getAddressSpace()) {
10828
10809
10829
- if (int AS =
10830
- cast<PointerType>(orig->getType())->getAddressSpace()) {
10831
-
10832
- llvm::PointerType *PT;
10810
+ llvm::PointerType *PT;
10833
10811
#if LLVM_VERSION_MAJOR >= 15
10834
- if (orig->getContext().supportsTypedPointers()) {
10812
+ if (orig->getContext().supportsTypedPointers()) {
10835
10813
#endif
10836
- PT = PointerType::get(
10837
- orig->getType()->getPointerElementType(), AS);
10814
+ PT = PointerType::get(orig->getType()->getPointerElementType(), AS);
10838
10815
#if LLVM_VERSION_MAJOR >= 15
10839
- } else {
10840
- PT = PointerType::get(orig->getContext(), AS);
10841
- }
10816
+ } else {
10817
+ PT = PointerType::get(orig->getContext(), AS);
10818
+ }
10842
10819
#endif
10843
- replacement = B.CreateAddrSpaceCast(replacement, PT);
10844
- cast<Instruction>(replacement)
10845
- ->setMetadata("enzyme_backstack",
10846
- MDNode::get(replacement->getContext(), {}));
10847
- }
10820
+ replacement = B.CreateAddrSpaceCast(replacement, PT);
10821
+ cast<Instruction>(replacement)
10822
+ ->setMetadata("enzyme_backstack",
10823
+ MDNode::get(replacement->getContext(), {}));
10824
+ }
10825
+ gutils->replaceAWithB(newCall, replacement);
10826
+ gutils->erase(newCall);
10827
+ };
10848
10828
10849
- gutils->replaceAWithB(newCall, replacement);
10850
- gutils->erase(newCall);
10829
+ // Don't erase any store that needs to be preserved for a
10830
+ // rematerialization
10831
+ {
10832
+ auto found = gutils->rematerializableAllocations.find(orig);
10833
+ if (found != gutils->rematerializableAllocations.end()) {
10834
+ // If rematerializing (e.g. needed in reverse, but not needing
10835
+ // the whole allocation):
10836
+ if (primalNeededInReverse && !cacheWholeAllocation) {
10837
+ // if rematerialize, don't ever cache and downgrade to stack
10838
+ // allocation where possible.
10839
+ if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
10840
+ if (Mode == DerivativeMode::ReverseModeGradient &&
10841
+ found->second.LI) {
10842
+ gutils->rematerializedPrimalOrShadowAllocations.push_back(
10843
+ newCall);
10844
+ } else {
10845
+ restoreFromStack(MD);
10851
10846
}
10852
10847
return;
10853
10848
}
@@ -10896,97 +10891,7 @@ class AdjointGenerator
10896
10891
if (Mode == DerivativeMode::ReverseModeGradient)
10897
10892
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
10898
10893
else if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
10899
- IRBuilder<> B(newCall);
10900
-
10901
- Value *Size;
10902
- if (funcName == "malloc")
10903
- Size = orig->getArgOperand(0);
10904
- else if (funcName == "julia.gc_alloc_obj" ||
10905
- funcName == "jl_gc_alloc_typed" ||
10906
- funcName == "ijl_gc_alloc_typed")
10907
- Size = orig->getArgOperand(1);
10908
- else
10909
- llvm_unreachable("Unknown allocation to upgrade");
10910
- Size = gutils->getNewFromOriginal(Size);
10911
-
10912
- if (auto CI = dyn_cast<ConstantInt>(Size)) {
10913
- B.SetInsertPoint(gutils->inversionAllocs);
10914
- }
10915
-
10916
- Type *elTy = Type::getInt8Ty(orig->getContext());
10917
- Instruction *I = nullptr;
10918
- #if LLVM_VERSION_MAJOR >= 15
10919
- if (orig->getContext().supportsTypedPointers()) {
10920
- #endif
10921
- for (auto U : orig->users()) {
10922
- if (hasMetadata(cast<Instruction>(U), "enzyme_caststack")) {
10923
- elTy = U->getType()->getPointerElementType();
10924
- Value *tsize = ConstantInt::get(
10925
- Size->getType(), (gutils->newFunc->getParent()
10926
- ->getDataLayout()
10927
- .getTypeAllocSizeInBits(elTy) +
10928
- 7) /
10929
- 8);
10930
- Size = B.CreateUDiv(Size, tsize, "", /*exact*/ true);
10931
- I = gutils->getNewFromOriginal(cast<Instruction>(U));
10932
- break;
10933
- }
10934
- }
10935
- #if LLVM_VERSION_MAJOR >= 15
10936
- }
10937
- #endif
10938
-
10939
- Value *replacement = B.CreateAlloca(elTy, Size);
10940
- if (I)
10941
- replacement->takeName(I);
10942
- else
10943
- replacement->takeName(newCall);
10944
- auto Alignment =
10945
- cast<ConstantInt>(
10946
- cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10947
- ->getLimitedValue();
10948
- // Don't set zero alignment
10949
- if (Alignment) {
10950
- #if LLVM_VERSION_MAJOR >= 10
10951
- cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
10952
- #else
10953
- cast<AllocaInst>(replacement)->setAlignment(Alignment);
10954
- #endif
10955
- }
10956
-
10957
- #if LLVM_VERSION_MAJOR >= 15
10958
- if (orig->getContext().supportsTypedPointers()) {
10959
- #endif
10960
- if (orig->getType()->getPointerElementType() != elTy)
10961
- replacement = B.CreatePointerCast(
10962
- replacement,
10963
- PointerType::getUnqual(
10964
- orig->getType()->getPointerElementType()));
10965
-
10966
- #if LLVM_VERSION_MAJOR >= 15
10967
- }
10968
- #endif
10969
- if (int AS =
10970
- cast<PointerType>(orig->getType())->getAddressSpace()) {
10971
- llvm::PointerType *PT;
10972
- #if LLVM_VERSION_MAJOR >= 15
10973
- if (orig->getContext().supportsTypedPointers()) {
10974
- #endif
10975
- PT = PointerType::get(
10976
- orig->getType()->getPointerElementType(), AS);
10977
- #if LLVM_VERSION_MAJOR >= 15
10978
- } else {
10979
- PT = PointerType::get(orig->getContext(), AS);
10980
- }
10981
- #endif
10982
- replacement = B.CreateAddrSpaceCast(replacement, PT);
10983
- cast<Instruction>(replacement)
10984
- ->setMetadata("enzyme_backstack",
10985
- MDNode::get(replacement->getContext(), {}));
10986
- }
10987
-
10988
- gutils->replaceAWithB(newCall, replacement);
10989
- gutils->erase(newCall);
10894
+ restoreFromStack(MD);
10990
10895
}
10991
10896
return;
10992
10897
}
@@ -11004,92 +10909,7 @@ class AdjointGenerator
11004
10909
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
11005
10910
} else {
11006
10911
if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
11007
- IRBuilder<> B(newCall);
11008
- Value *Size;
11009
- if (funcName == "malloc")
11010
- Size = orig->getArgOperand(0);
11011
- else if (funcName == "julia.gc_alloc_obj" ||
11012
- funcName == "jl_gc_alloc_typed" ||
11013
- funcName == "ijl_gc_alloc_typed")
11014
- Size = orig->getArgOperand(1);
11015
- else
11016
- llvm_unreachable("Unknown allocation to upgrade");
11017
- Size = gutils->getNewFromOriginal(Size);
11018
-
11019
- if (auto CI = dyn_cast<ConstantInt>(Size)) {
11020
- B.SetInsertPoint(gutils->inversionAllocs);
11021
- }
11022
- Type *elTy = Type::getInt8Ty(orig->getContext());
11023
- Instruction *I = nullptr;
11024
- #if LLVM_VERSION_MAJOR >= 15
11025
- if (orig->getContext().supportsTypedPointers()) {
11026
- #endif
11027
- for (auto U : orig->users()) {
11028
- if (hasMetadata(cast<Instruction>(U), "enzyme_caststack")) {
11029
- elTy = U->getType()->getPointerElementType();
11030
- Value *tsize = ConstantInt::get(
11031
- Size->getType(), (gutils->newFunc->getParent()
11032
- ->getDataLayout()
11033
- .getTypeAllocSizeInBits(elTy) +
11034
- 7) /
11035
- 8);
11036
- Size = B.CreateUDiv(Size, tsize, "", /*exact*/ true);
11037
- I = gutils->getNewFromOriginal(cast<Instruction>(U));
11038
- break;
11039
- }
11040
- }
11041
- #if LLVM_VERSION_MAJOR >= 15
11042
- }
11043
- #endif
11044
- Value *replacement = B.CreateAlloca(elTy, Size);
11045
- if (I)
11046
- replacement->takeName(I);
11047
- else
11048
- replacement->takeName(newCall);
11049
- auto Alignment =
11050
- cast<ConstantInt>(
11051
- cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
11052
- ->getLimitedValue();
11053
- // Don't set zero alignment
11054
- if (Alignment) {
11055
- #if LLVM_VERSION_MAJOR >= 10
11056
- cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
11057
- #else
11058
- cast<AllocaInst>(replacement)->setAlignment(Alignment);
11059
- #endif
11060
- }
11061
- #if LLVM_VERSION_MAJOR >= 15
11062
- if (orig->getContext().supportsTypedPointers()) {
11063
- #endif
11064
- if (orig->getType()->getPointerElementType() != elTy)
11065
- replacement = B.CreatePointerCast(
11066
- replacement, PointerType::getUnqual(
11067
- orig->getType()->getPointerElementType()));
11068
-
11069
- #if LLVM_VERSION_MAJOR >= 15
11070
- }
11071
- #endif
11072
- if (int AS =
11073
- cast<PointerType>(orig->getType())->getAddressSpace()) {
11074
-
11075
- llvm::PointerType *PT;
11076
- #if LLVM_VERSION_MAJOR >= 15
11077
- if (orig->getContext().supportsTypedPointers()) {
11078
- #endif
11079
- PT = PointerType::get(orig->getType()->getPointerElementType(),
11080
- AS);
11081
- #if LLVM_VERSION_MAJOR >= 15
11082
- } else {
11083
- PT = PointerType::get(orig->getContext(), AS);
11084
- }
11085
- #endif
11086
- replacement = B.CreateAddrSpaceCast(replacement, PT);
11087
- cast<Instruction>(replacement)
11088
- ->setMetadata("enzyme_backstack",
11089
- MDNode::get(replacement->getContext(), {}));
11090
- }
11091
- gutils->replaceAWithB(newCall, replacement);
11092
- gutils->erase(newCall);
10912
+ restoreFromStack(MD);
11093
10913
}
11094
10914
}
11095
10915
return;
0 commit comments