@@ -191,8 +191,13 @@ class AdjointGenerator
191191 AL =
192192 AL.addParamAttribute (DT->getContext (), 1 , Attribute::AttrKind::NonNull);
193193#if LLVM_VERSION_MAJOR >= 14
194+ #if LLVM_VERSION_MAJOR >= 16
195+ AL = AL.addAttributeAtIndex (DT->getContext (), AttributeList::FunctionIndex,
196+ Attribute::AttrKind::Memory);
197+ #else
194198 AL = AL.addAttributeAtIndex (DT->getContext (), AttributeList::FunctionIndex,
195199 Attribute::AttrKind::ArgMemOnly);
200+ #endif
196201 AL = AL.addAttributeAtIndex (DT->getContext (), AttributeList::FunctionIndex,
197202 Attribute::AttrKind::NoUnwind);
198203 AL = AL.addAttributeAtIndex (DT->getContext (), AttributeList::FunctionIndex,
@@ -5373,7 +5378,7 @@ class AdjointGenerator
53735378 nextTypeInfo.Return = TR.query (&call);
53745379 }
53755380
5376- // llvm::Optional <std::map<std::pair<Instruction*, std::string>, unsigned>>
5381+ // std::optional <std::map<std::pair<Instruction*, std::string>, unsigned>>
53775382 // sub_index_map;
53785383 // Optional<int> tapeIdx;
53795384 // Optional<int> returnIdx;
@@ -8501,23 +8506,34 @@ class AdjointGenerator
85018506
85028507 args.push_back (gutils->invertPointerM (call.getArgOperand (i), Builder2));
85038508 }
8504-
8509+ #if LLVM_VERSION_MAJOR >= 16
8510+ std::optional<int > tapeIdx;
8511+ #else
85058512 Optional<int > tapeIdx;
8513+ #endif
85068514 if (subdata) {
85078515 auto found = subdata->returns .find (AugmentedStruct::Tape);
85088516 if (found != subdata->returns .end ()) {
85098517 tapeIdx = found->second ;
85108518 }
85118519 }
85128520 Value *tape = nullptr ;
8521+ #if LLVM_VERSION_MAJOR >= 16
8522+ if (tapeIdx.has_value ()) {
8523+ #else
85138524 if (tapeIdx.hasValue ()) {
8525+ #endif
85148526
85158527 FunctionType *FT = subdata->fn ->getFunctionType ();
85168528
85178529 tape = BuilderZ.CreatePHI (
85188530 (tapeIdx == -1 ) ? FT->getReturnType ()
85198531 : cast<StructType>(FT->getReturnType ())
8532+ #if LLVM_VERSION_MAJOR >= 16
8533+ ->getElementType (tapeIdx.value ()),
8534+ #else
85208535 ->getElementType (tapeIdx.getValue ()),
8536+ #endif
85218537 1 , " tapeArg" );
85228538
85238539 assert (!tape->getType ()->isEmptyTy ());
@@ -8861,12 +8877,17 @@ class AdjointGenerator
88618877 CallInst *augmentcall = nullptr ;
88628878 Value *cachereplace = nullptr ;
88638879
8864- // llvm::Optional <std::map<std::pair<Instruction*, std::string>,
8880+ // std::optional <std::map<std::pair<Instruction*, std::string>,
88658881 // unsigned>> sub_index_map;
8882+ #if LLVM_VERSION_MAJOR >= 16
8883+ std::optional<int > tapeIdx;
8884+ std::optional<int > returnIdx;
8885+ std::optional<int > differetIdx;
8886+ #else
88668887 Optional<int > tapeIdx;
88678888 Optional<int > returnIdx;
88688889 Optional<int > differetIdx;
8869-
8890+ # endif
88708891 if (modifyPrimal) {
88718892
88728893 Value *newcalled = nullptr ;
@@ -9057,11 +9078,20 @@ class AdjointGenerator
90579078 if (!augmentcall->getType ()->isVoidTy ())
90589079 augmentcall->setName (call.getName () + " _augmented" );
90599080
9081+ #if LLVM_VERSION_MAJOR >= 16
9082+ if (tapeIdx.has_value ()) {
9083+ tape = (tapeIdx.value () == -1 )
9084+ #else
90609085 if (tapeIdx.hasValue ()) {
90619086 tape = (tapeIdx.getValue () == -1 )
9087+ #endif
90629088 ? augmentcall
90639089 : BuilderZ.CreateExtractValue (
9090+ #if LLVM_VERSION_MAJOR >= 16
9091+ augmentcall, {(unsigned )tapeIdx.value ()},
9092+ #else
90649093 augmentcall, {(unsigned )tapeIdx.getValue ()},
9094+ #endif
90659095 " subcache" );
90669096 if (tape->getType ()->isEmptyTy ()) {
90679097 auto tt = tape->getType ();
@@ -9078,10 +9108,17 @@ class AdjointGenerator
90789108 Value *dcall = nullptr ;
90799109 assert (returnIdx);
90809110 assert (augmentcall);
9111+ #if LLVM_VERSION_MAJOR >= 16
9112+ dcall = (returnIdx.value () < 0 )
9113+ ? augmentcall
9114+ : BuilderZ.CreateExtractValue (
9115+ augmentcall, {(unsigned )returnIdx.value ()});
9116+ #else
90819117 dcall = (returnIdx.getValue () < 0 )
90829118 ? augmentcall
90839119 : BuilderZ.CreateExtractValue (
90849120 augmentcall, {(unsigned )returnIdx.getValue ()});
9121+ #endif
90859122 gutils->originalToNewFn [&call] = dcall;
90869123 gutils->newToOriginalFn .erase (newCall);
90879124 gutils->newToOriginalFn [dcall] = &call;
@@ -9149,12 +9186,21 @@ class AdjointGenerator
91499186 // assert(!tape);
91509187 // assert(subdata);
91519188 if (!tape) {
9189+ #if LLVM_VERSION_MAJOR >= 16
9190+ assert (tapeIdx.has_value ());
9191+ tape = BuilderZ.CreatePHI (
9192+ (tapeIdx == -1 ) ? FT->getReturnType ()
9193+ : cast<StructType>(FT->getReturnType ())
9194+ ->getElementType (tapeIdx.value ()),
9195+ 1 , " tapeArg" );
9196+ #else
91529197 assert (tapeIdx.hasValue ());
91539198 tape = BuilderZ.CreatePHI (
91549199 (tapeIdx == -1 ) ? FT->getReturnType ()
91559200 : cast<StructType>(FT->getReturnType ())
91569201 ->getElementType (tapeIdx.getValue ()),
91579202 1 , " tapeArg" );
9203+ #endif
91589204 }
91599205 tape = gutils->cacheForReverse (BuilderZ, tape,
91609206 getIndex (&call, CacheType::Tape));
@@ -9206,11 +9252,19 @@ class AdjointGenerator
92069252 Value *newip = nullptr ;
92079253 if (Mode == DerivativeMode::ReverseModeCombined ||
92089254 Mode == DerivativeMode::ReverseModePrimal) {
9255+ #if LLVM_VERSION_MAJOR >= 16
9256+ newip = (differetIdx.value () < 0 )
9257+ ? augmentcall
9258+ : BuilderZ.CreateExtractValue (
9259+ augmentcall, {(unsigned )differetIdx.value ()},
9260+ call.getName () + " 'ac" );
9261+ #else
92099262 newip = (differetIdx.getValue () < 0 )
92109263 ? augmentcall
92119264 : BuilderZ.CreateExtractValue (
92129265 augmentcall, {(unsigned )differetIdx.getValue ()},
92139266 call.getName () + " 'ac" );
9267+ #endif
92149268 assert (newip->getType () == call.getType ());
92159269 placeholder->replaceAllUsesWith (newip);
92169270 if (placeholder == &*BuilderZ.GetInsertPoint ()) {
@@ -12950,7 +13004,11 @@ class AdjointGenerator
1295013004 /* tryLegalRecompute*/ false );
1295113005 auto freeCall = cast<CallInst>(
1295213006 CallInst::CreateFree (tofree, Builder2.GetInsertBlock ()));
13007+ #if LLVM_VERSION_MAJOR >= 16
13008+ freeCall->insertInto (Builder2.GetInsertBlock (), Builder2.GetInsertBlock ()->end ());
13009+ #else
1295313010 Builder2.GetInsertBlock ()->getInstList ().push_back (freeCall);
13011+ #endif
1295413012 }
1295513013 }
1295613014 }
@@ -12989,7 +13047,11 @@ class AdjointGenerator
1298913047 gutils->lookupM (load, Builder2, ValueToValueMapTy (),
1299013048 /* tryLegal*/ false ),
1299113049 Builder2.GetInsertBlock ()));
13050+ #if LLVM_VERSION_MAJOR >= 16
13051+ freeCall->insertInto (Builder2.GetInsertBlock (), Builder2.GetInsertBlock ()->end ());
13052+ #else
1299213053 Builder2.GetInsertBlock ()->getInstList ().push_back (freeCall);
13054+ #endif
1299313055 }
1299413056
1299513057 return ;
0 commit comments