Skip to content

Commit b954bed

Browse files
stepasitewsmoses
authored andcommitted
compiles with llvm-16 (old PM is gone so I could not run opt) (#1)
* port to llvm-16 (not backward-compatible with older llvms), WIP * wip, but probably not correct on may places * rule out old PM on llvm 16 and above * #if LLVM_VERSION_MAJOR >= 16, WIP * #if LLVM_VERSION_MAJOR >= 16, WIP * #if LLVM_VERSION_MAJOR >= 16, WIP * #if LLVM_VERSION_MAJOR >= 16, WIP * #if LLVM_VERSION_MAJOR >= 16, WIP * #if LLVM_VERSION_MAJOR >= 16
1 parent 194875c commit b954bed

File tree

13 files changed

+328
-24
lines changed

13 files changed

+328
-24
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,13 @@ class AdjointGenerator
191191
AL =
192192
AL.addParamAttribute(DT->getContext(), 1, Attribute::AttrKind::NonNull);
193193
#if LLVM_VERSION_MAJOR >= 14
194+
#if LLVM_VERSION_MAJOR >= 16
195+
AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex,
196+
Attribute::AttrKind::Memory);
197+
#else
194198
AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex,
195199
Attribute::AttrKind::ArgMemOnly);
200+
#endif
196201
AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex,
197202
Attribute::AttrKind::NoUnwind);
198203
AL = AL.addAttributeAtIndex(DT->getContext(), AttributeList::FunctionIndex,
@@ -5373,7 +5378,7 @@ class AdjointGenerator
53735378
nextTypeInfo.Return = TR.query(&call);
53745379
}
53755380

5376-
// llvm::Optional<std::map<std::pair<Instruction*, std::string>, unsigned>>
5381+
// std::optional<std::map<std::pair<Instruction*, std::string>, unsigned>>
53775382
// sub_index_map;
53785383
// Optional<int> tapeIdx;
53795384
// Optional<int> returnIdx;
@@ -8501,23 +8506,34 @@ class AdjointGenerator
85018506

85028507
args.push_back(gutils->invertPointerM(call.getArgOperand(i), Builder2));
85038508
}
8504-
8509+
#if LLVM_VERSION_MAJOR >= 16
8510+
std::optional<int> tapeIdx;
8511+
#else
85058512
Optional<int> tapeIdx;
8513+
#endif
85068514
if (subdata) {
85078515
auto found = subdata->returns.find(AugmentedStruct::Tape);
85088516
if (found != subdata->returns.end()) {
85098517
tapeIdx = found->second;
85108518
}
85118519
}
85128520
Value *tape = nullptr;
8521+
#if LLVM_VERSION_MAJOR >= 16
8522+
if (tapeIdx.has_value()) {
8523+
#else
85138524
if (tapeIdx.hasValue()) {
8525+
#endif
85148526

85158527
FunctionType *FT = subdata->fn->getFunctionType();
85168528

85178529
tape = BuilderZ.CreatePHI(
85188530
(tapeIdx == -1) ? FT->getReturnType()
85198531
: cast<StructType>(FT->getReturnType())
8532+
#if LLVM_VERSION_MAJOR >= 16
8533+
->getElementType(tapeIdx.value()),
8534+
#else
85208535
->getElementType(tapeIdx.getValue()),
8536+
#endif
85218537
1, "tapeArg");
85228538

85238539
assert(!tape->getType()->isEmptyTy());
@@ -8861,12 +8877,17 @@ class AdjointGenerator
88618877
CallInst *augmentcall = nullptr;
88628878
Value *cachereplace = nullptr;
88638879

8864-
// llvm::Optional<std::map<std::pair<Instruction*, std::string>,
8880+
// std::optional<std::map<std::pair<Instruction*, std::string>,
88658881
// unsigned>> sub_index_map;
8882+
#if LLVM_VERSION_MAJOR >= 16
8883+
std::optional<int> tapeIdx;
8884+
std::optional<int> returnIdx;
8885+
std::optional<int> differetIdx;
8886+
#else
88668887
Optional<int> tapeIdx;
88678888
Optional<int> returnIdx;
88688889
Optional<int> differetIdx;
8869-
8890+
#endif
88708891
if (modifyPrimal) {
88718892

88728893
Value *newcalled = nullptr;
@@ -9057,11 +9078,20 @@ class AdjointGenerator
90579078
if (!augmentcall->getType()->isVoidTy())
90589079
augmentcall->setName(call.getName() + "_augmented");
90599080

9081+
#if LLVM_VERSION_MAJOR >= 16
9082+
if (tapeIdx.has_value()) {
9083+
tape = (tapeIdx.value() == -1)
9084+
#else
90609085
if (tapeIdx.hasValue()) {
90619086
tape = (tapeIdx.getValue() == -1)
9087+
#endif
90629088
? augmentcall
90639089
: BuilderZ.CreateExtractValue(
9090+
#if LLVM_VERSION_MAJOR >= 16
9091+
augmentcall, {(unsigned)tapeIdx.value()},
9092+
#else
90649093
augmentcall, {(unsigned)tapeIdx.getValue()},
9094+
#endif
90659095
"subcache");
90669096
if (tape->getType()->isEmptyTy()) {
90679097
auto tt = tape->getType();
@@ -9078,10 +9108,17 @@ class AdjointGenerator
90789108
Value *dcall = nullptr;
90799109
assert(returnIdx);
90809110
assert(augmentcall);
9111+
#if LLVM_VERSION_MAJOR >= 16
9112+
dcall = (returnIdx.value() < 0)
9113+
? augmentcall
9114+
: BuilderZ.CreateExtractValue(
9115+
augmentcall, {(unsigned)returnIdx.value()});
9116+
#else
90819117
dcall = (returnIdx.getValue() < 0)
90829118
? augmentcall
90839119
: BuilderZ.CreateExtractValue(
90849120
augmentcall, {(unsigned)returnIdx.getValue()});
9121+
#endif
90859122
gutils->originalToNewFn[&call] = dcall;
90869123
gutils->newToOriginalFn.erase(newCall);
90879124
gutils->newToOriginalFn[dcall] = &call;
@@ -9149,12 +9186,21 @@ class AdjointGenerator
91499186
// assert(!tape);
91509187
// assert(subdata);
91519188
if (!tape) {
9189+
#if LLVM_VERSION_MAJOR >= 16
9190+
assert(tapeIdx.has_value());
9191+
tape = BuilderZ.CreatePHI(
9192+
(tapeIdx == -1) ? FT->getReturnType()
9193+
: cast<StructType>(FT->getReturnType())
9194+
->getElementType(tapeIdx.value()),
9195+
1, "tapeArg");
9196+
#else
91529197
assert(tapeIdx.hasValue());
91539198
tape = BuilderZ.CreatePHI(
91549199
(tapeIdx == -1) ? FT->getReturnType()
91559200
: cast<StructType>(FT->getReturnType())
91569201
->getElementType(tapeIdx.getValue()),
91579202
1, "tapeArg");
9203+
#endif
91589204
}
91599205
tape = gutils->cacheForReverse(BuilderZ, tape,
91609206
getIndex(&call, CacheType::Tape));
@@ -9206,11 +9252,19 @@ class AdjointGenerator
92069252
Value *newip = nullptr;
92079253
if (Mode == DerivativeMode::ReverseModeCombined ||
92089254
Mode == DerivativeMode::ReverseModePrimal) {
9255+
#if LLVM_VERSION_MAJOR >= 16
9256+
newip = (differetIdx.value() < 0)
9257+
? augmentcall
9258+
: BuilderZ.CreateExtractValue(
9259+
augmentcall, {(unsigned)differetIdx.value()},
9260+
call.getName() + "'ac");
9261+
#else
92099262
newip = (differetIdx.getValue() < 0)
92109263
? augmentcall
92119264
: BuilderZ.CreateExtractValue(
92129265
augmentcall, {(unsigned)differetIdx.getValue()},
92139266
call.getName() + "'ac");
9267+
#endif
92149268
assert(newip->getType() == call.getType());
92159269
placeholder->replaceAllUsesWith(newip);
92169270
if (placeholder == &*BuilderZ.GetInsertPoint()) {
@@ -12950,7 +13004,11 @@ class AdjointGenerator
1295013004
/*tryLegalRecompute*/ false);
1295113005
auto freeCall = cast<CallInst>(
1295213006
CallInst::CreateFree(tofree, Builder2.GetInsertBlock()));
13007+
#if LLVM_VERSION_MAJOR >= 16
13008+
freeCall->insertInto(Builder2.GetInsertBlock(), Builder2.GetInsertBlock()->end());
13009+
#else
1295313010
Builder2.GetInsertBlock()->getInstList().push_back(freeCall);
13011+
#endif
1295413012
}
1295513013
}
1295613014
}
@@ -12989,7 +13047,11 @@ class AdjointGenerator
1298913047
gutils->lookupM(load, Builder2, ValueToValueMapTy(),
1299013048
/*tryLegal*/ false),
1299113049
Builder2.GetInsertBlock()));
13050+
#if LLVM_VERSION_MAJOR >= 16
13051+
freeCall->insertInto(Builder2.GetInsertBlock(), Builder2.GetInsertBlock()->end());
13052+
#else
1299213053
Builder2.GetInsertBlock()->getInstList().push_back(freeCall);
13054+
#endif
1299313055
}
1299413056

1299513057
return;

enzyme/Enzyme/Clang/EnzymePassLoader.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
//
2424
//===----------------------------------------------------------------------===//
2525

26+
#include "llvm/Config/llvm-config.h"
27+
28+
#if LLVM_VERSION_MAJOR < 16
29+
2630
#include "llvm/IR/LegacyPassManager.h"
2731
#include "llvm/Transforms/IPO.h"
2832
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
@@ -85,4 +89,7 @@ static void loadLTOPass(const PassManagerBuilder &Builder,
8589
static RegisterStandardPasses
8690
clangtoolLoader_LTO(PassManagerBuilder::EP_FullLinkTimeOptimizationEarly,
8791
loadLTOPass);
88-
#endif
92+
93+
#endif // LLVM_VERSION_MAJOR >= 9
94+
95+
#endif // LLVM_VERSION_MAJOR < 16

enzyme/Enzyme/DiffeGradientUtils.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -814,9 +814,14 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
814814
MaybeAlign alignv = align;
815815
if (alignv) {
816816
if (start != 0) {
817-
assert(alignv.getValue().value() != 0);
818817
// todo make better alignment calculation
818+
#if LLVM_VERSION_MAJOR >= 16
819+
assert(alignv.value().value() != 0);
820+
if (start % alignv.value().value() != 0) {
821+
#else
822+
assert(alignv.getValue().value() != 0);
819823
if (start % alignv.getValue().value() != 0) {
824+
#endif
820825
alignv = Align(1);
821826
}
822827
}
@@ -852,9 +857,14 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
852857
MaybeAlign alignv = align;
853858
if (alignv) {
854859
if (start != 0) {
855-
assert(alignv.getValue().value() != 0);
856860
// todo make better alignment calculation
861+
#if LLVM_VERSION_MAJOR >= 16
862+
assert(alignv.value().value() != 0);
863+
if (start % alignv.value().value() != 0) {
864+
#else
865+
assert(alignv.getValue().value() != 0);
857866
if (start % alignv.getValue().value() != 0) {
867+
#endif
858868
alignv = Align(1);
859869
}
860870
}
@@ -946,7 +956,9 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
946956
st->setDebugLoc(getNewFromOriginal(orig->getDebugLoc()));
947957

948958
if (align) {
949-
#if LLVM_VERSION_MAJOR >= 10
959+
#if LLVM_VERSION_MAJOR >= 16
960+
auto alignv = align ? align.value().value() : 0;
961+
#elif LLVM_VERSION_MAJOR >= 10
950962
auto alignv = align ? align.getValue().value() : 0;
951963
#else
952964
auto alignv = align;

0 commit comments

Comments
 (0)