Skip to content

Commit 3afb3b4

Browse files
authored
Lifetime (EnzymeAD#2402)
* Fix for new lifetime intrinsic * bazver * fix * fmt * fix * fmt * fmt * fmt * Don't use out of bounds comparisons * fmt
1 parent 965e491 commit 3afb3b4

File tree

9 files changed

+189
-28
lines changed

9 files changed

+189
-28
lines changed

enzyme/.bazelversion

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
6.5.0
1+
7.4.1

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ const StringSet<> KnownInactiveFunctions = {
301301
"__ubsan_handle_pointer_overflow",
302302
"__ubsan_handle_type_mismatch_v1",
303303
"__ubsan_vptr_type_cache",
304+
"llvm.enzyme.lifetime_start",
305+
"llvm.enzyme.lifetime_end",
304306
};
305307

306308
const std::set<Intrinsic::ID> KnownInactiveIntrinsics = {

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2988,6 +2988,18 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
29882988
break;
29892989
if (auto MCI = dyn_cast<ConstantInt>(MS.getOperand(2))) {
29902990
if (auto II = dyn_cast<IntrinsicInst>(cur)) {
2991+
if (II->getCalledFunction()->getName() ==
2992+
"llvm.enzyme.lifetime_start") {
2993+
if (getBaseObject(II->getOperand(1)) == root) {
2994+
if (auto CI2 =
2995+
dyn_cast<ConstantInt>(II->getOperand(0))) {
2996+
if (MCI->getValue().ule(CI2->getValue()))
2997+
break;
2998+
}
2999+
}
3000+
cur = cur->getPrevNode();
3001+
continue;
3002+
}
29913003
// If the start of the lifetime for more memory than being
29923004
// memset, its valid.
29933005
if (II->getIntrinsicID() == Intrinsic::lifetime_start) {
@@ -3709,7 +3721,8 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
37093721
return;
37103722
}
37113723
if (II.getIntrinsicID() == Intrinsic::stackrestore ||
3712-
II.getIntrinsicID() == Intrinsic::lifetime_end) {
3724+
II.getIntrinsicID() == Intrinsic::lifetime_end ||
3725+
II.getCalledFunction()->getName() == "llvm.enzyme.lifetime_end") {
37133726
eraseIfUnused(II, /*erase*/ true, /*check*/ false);
37143727
return;
37153728
}
@@ -6068,17 +6081,33 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
60686081
void visitCallInst(llvm::CallInst &call) {
60696082
using namespace llvm;
60706083

6084+
StringRef funcName = getFuncNameFromCall(&call);
6085+
60716086
// When compiling Enzyme against standard LLVM, and not Intel's
60726087
// modified version of LLVM, the intrinsic `llvm.intel.subscript` is
60736088
// not fully understood by LLVM. One of the results of this is that the
60746089
// visitor dispatches to visitCallInst, rather than visitIntrinsicInst, when
60756090
// presented with the intrinsic - hence why we are handling it here.
6076-
if (startsWith(getFuncNameFromCall(&call), ("llvm.intel.subscript"))) {
6091+
if (startsWith(funcName, ("llvm.intel.subscript"))) {
60776092
assert(isa<IntrinsicInst>(call));
60786093
visitIntrinsicInst(cast<IntrinsicInst>(call));
60796094
return;
60806095
}
60816096

6097+
if (funcName == "llvm.enzyme.lifetime_start") {
6098+
visitIntrinsicInst(cast<IntrinsicInst>(call));
6099+
return;
6100+
}
6101+
if (funcName == "llvm.enzyme.lifetime_end") {
6102+
SmallVector<Value *, 2> orig_ops(call.getNumOperands());
6103+
for (unsigned i = 0; i < call.getNumOperands(); ++i) {
6104+
orig_ops[i] = call.getOperand(i);
6105+
}
6106+
handleAdjointForIntrinsic(Intrinsic::lifetime_end, call, orig_ops);
6107+
eraseIfUnused(call);
6108+
return;
6109+
}
6110+
60826111
CallInst *const newCall = cast<CallInst>(gutils->getNewFromOriginal(&call));
60836112
IRBuilder<> BuilderZ(newCall);
60846113
BuilderZ.setFastMathFlags(getFast());
@@ -6107,7 +6136,6 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
61076136
: overwritten_args_map.find(&call)->second.second;
61086137

61096138
auto called = getFunctionFromCall(&call);
6110-
StringRef funcName = getFuncNameFromCall(&call);
61116139

61126140
bool subretused = false;
61136141
bool shadowReturnUsed = false;

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,12 @@ void calculateUnusedValuesInFunction(
870870
},
871871
[&](const Instruction *inst) {
872872
if (auto II = dyn_cast<IntrinsicInst>(inst)) {
873+
if (II->getCalledFunction()->getName() ==
874+
"llvm.enzyme.lifetime_start" ||
875+
II->getCalledFunction()->getName() ==
876+
"llvm.enzyme.lifetime_end") {
877+
return UseReq::Cached;
878+
}
873879
if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
874880
II->getIntrinsicID() == Intrinsic::lifetime_end ||
875881
II->getIntrinsicID() == Intrinsic::stacksave ||
@@ -6636,7 +6642,9 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) {
66366642
"__assertfail",
66376643
"__kmpc_global_thread_num",
66386644
"nlopt_force_stop",
6639-
"cudaRuntimeGetVersion"
6645+
"cudaRuntimeGetVersion",
6646+
"llvm.enzyme.lifetime_start",
6647+
"llvm.enzyme.lifetime_end",
66406648
};
66416649
// clang-format on
66426650

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,10 +534,64 @@ UpgradeAllocasToMallocs(Function *NewF, DerivativeMode mode,
534534
}
535535
}
536536

537+
#if LLVM_VERSION_MAJOR >= 22
538+
Function *start_lifetime = nullptr;
539+
Function *end_lifetime = nullptr;
540+
#endif
541+
537542
for (auto AI : ToConvert) {
538543
std::string nam = AI->getName().str();
539544
AI->setName("");
540545

546+
#if LLVM_VERSION_MAJOR >= 22
547+
for (auto U : llvm::make_early_inc_range(AI->users())) {
548+
if (auto II = dyn_cast<IntrinsicInst>(U)) {
549+
if (II->getIntrinsicID() == Intrinsic::lifetime_start) {
550+
if (!start_lifetime) {
551+
start_lifetime = cast<Function>(
552+
NewF->getParent()
553+
->getOrInsertFunction(
554+
"llvm.enzyme.lifetime_start",
555+
FunctionType::get(Type::getVoidTy(NewF->getContext()),
556+
{}, true))
557+
.getCallee());
558+
}
559+
IRBuilder<> B(II);
560+
SmallVector<Value *, 2> args(II->arg_size());
561+
for (unsigned i = 0; i < II->arg_size(); ++i) {
562+
args[i] = II->getArgOperand(i);
563+
}
564+
auto newI = B.CreateCall(start_lifetime, args);
565+
newI->takeName(II);
566+
newI->setDebugLoc(II->getDebugLoc());
567+
II->eraseFromParent();
568+
continue;
569+
}
570+
if (II->getIntrinsicID() == Intrinsic::lifetime_end) {
571+
if (!end_lifetime) {
572+
end_lifetime = cast<Function>(
573+
NewF->getParent()
574+
->getOrInsertFunction(
575+
"llvm.enzyme.lifetime_end",
576+
FunctionType::get(Type::getVoidTy(NewF->getContext()),
577+
{}, true))
578+
.getCallee());
579+
}
580+
IRBuilder<> B(II);
581+
SmallVector<Value *, 2> args(II->arg_size());
582+
for (unsigned i = 0; i < II->arg_size(); ++i) {
583+
args[i] = II->getArgOperand(i);
584+
}
585+
auto newI = B.CreateCall(end_lifetime, args);
586+
newI->takeName(II);
587+
newI->setDebugLoc(II->getDebugLoc());
588+
II->eraseFromParent();
589+
continue;
590+
}
591+
}
592+
}
593+
#endif
594+
541595
// Ensure we insert the malloc after the allocas
542596
Instruction *insertBefore = AI;
543597
while (isa<AllocaInst>(insertBefore->getNextNode())) {
@@ -884,6 +938,45 @@ void PreProcessCache::LowerAllocAddr(Function *NewF) {
884938
#endif
885939
RecursivelyReplaceAddressSpace(T, AIV, /*legal*/ true);
886940
}
941+
942+
#if LLVM_VERSION_MAJOR >= 22
943+
{
944+
auto start_lifetime =
945+
NewF->getParent()->getFunction("llvm.enzyme.lifetime_start");
946+
auto end_lifetime =
947+
NewF->getParent()->getFunction("llvm.enzyme.lifetime_end");
948+
949+
SmallVector<CallInst *, 1> Todo;
950+
for (auto &BB : *NewF) {
951+
for (auto &I : BB) {
952+
if (auto CB = dyn_cast<CallInst>(&I)) {
953+
if (!CB->getCalledFunction())
954+
continue;
955+
if (CB->getCalledFunction() == start_lifetime ||
956+
CB->getCalledFunction() == end_lifetime) {
957+
Todo.push_back(CB);
958+
}
959+
}
960+
}
961+
}
962+
963+
for (auto CB : Todo) {
964+
if (!isa<AllocaInst>(CB->getArgOperand(1))) {
965+
CB->eraseFromParent();
966+
continue;
967+
}
968+
IRBuilder<> B(CB);
969+
if (CB->getCalledFunction() == start_lifetime) {
970+
B.CreateLifetimeStart(CB->getArgOperand(1),
971+
cast<ConstantInt>(CB->getArgOperand(0)));
972+
} else {
973+
B.CreateLifetimeEnd(CB->getArgOperand(1),
974+
cast<ConstantInt>(CB->getArgOperand(0)));
975+
}
976+
CB->eraseFromParent();
977+
}
978+
}
979+
#endif
887980
}
888981

889982
/// Calls to realloc with an appropriate implementation
@@ -7300,6 +7393,9 @@ Constraints::InnerTy Constraints::make_compare(const SCEV *v, bool isEqual,
73007393
ConstraintContext ctx2(ctx.SE, ctx.loopToSolve, noassumption, ctx.DT);
73017394
for (auto I : ctx.Assumptions) {
73027395
bool legal = true;
7396+
if (I->getParent()->getParent() !=
7397+
ctx.loopToSolve->getHeader()->getParent())
7398+
continue;
73037399
auto parsedCond = getSparseConditions(legal, I->getOperand(0),
73047400
Constraints::none(), nullptr, ctx2);
73057401
bool dominates = ctx.DT.dominates(I, ctx.loopToSolve->getHeader());

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8954,32 +8954,38 @@ void GradientUtils::computeForwardingProperties(Instruction *V) {
89548954
storingOps.insert(store);
89558955
}
89568956
} else if (auto II = dyn_cast<IntrinsicInst>(cur)) {
8957-
switch (II->getIntrinsicID()) {
8958-
case Intrinsic::lifetime_start:
8957+
if (II->getCalledFunction()->getName() == "llvm.enzyme.lifetime_start") {
89598958
LifetimeStarts.insert(II);
8960-
break;
8961-
case Intrinsic::dbg_declare:
8962-
case Intrinsic::dbg_value:
8963-
case Intrinsic::dbg_label:
8959+
} else if (II->getCalledFunction()->getName() ==
8960+
"llvm.enzyme.lifetime_end") {
8961+
} else {
8962+
switch (II->getIntrinsicID()) {
8963+
case Intrinsic::lifetime_start:
8964+
LifetimeStarts.insert(II);
8965+
break;
8966+
case Intrinsic::dbg_declare:
8967+
case Intrinsic::dbg_value:
8968+
case Intrinsic::dbg_label:
89648969
#if LLVM_VERSION_MAJOR <= 16
8965-
case llvm::Intrinsic::dbg_addr:
8970+
case llvm::Intrinsic::dbg_addr:
89668971
#endif
8967-
case Intrinsic::lifetime_end:
8968-
break;
8969-
case Intrinsic::memset: {
8970-
stores.insert(II);
8971-
storingOps.insert(II);
8972-
break;
8973-
}
8974-
// TODO memtransfer(cpy/move)
8975-
case Intrinsic::memcpy:
8976-
case Intrinsic::memmove:
8977-
default:
8978-
promotable = false;
8979-
shadowpromotable = false;
8980-
EmitWarning("NotPromotable", *cur, " Could not promote allocation ", *V,
8981-
" due to unknown intrinsic ", *cur);
8982-
break;
8972+
case Intrinsic::lifetime_end:
8973+
break;
8974+
case Intrinsic::memset: {
8975+
stores.insert(II);
8976+
storingOps.insert(II);
8977+
break;
8978+
}
8979+
// TODO memtransfer(cpy/move)
8980+
case Intrinsic::memcpy:
8981+
case Intrinsic::memmove:
8982+
default:
8983+
promotable = false;
8984+
shadowpromotable = false;
8985+
EmitWarning("NotPromotable", *cur, " Could not promote allocation ",
8986+
*V, " due to unknown intrinsic ", *cur);
8987+
break;
8988+
}
89838989
}
89848990
} else if (auto CI = dyn_cast<CallInst>(cur)) {
89858991
StringRef funcName = getFuncNameFromCall(CI);

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,16 @@ extern const llvm::StringMap<llvm::Intrinsic::ID> LIBM_FUNCTIONS;
6161
static inline bool isMemFreeLibMFunction(llvm::StringRef str,
6262
llvm::Intrinsic::ID *ID = nullptr) {
6363
llvm::StringRef ogstr = str;
64+
if (ID) {
65+
if (str == "llvm.enzyme.lifetime_start") {
66+
*ID = llvm::Intrinsic::lifetime_start;
67+
return false;
68+
}
69+
if (str == "llvm.enzyme.lifetime_end") {
70+
*ID = llvm::Intrinsic::lifetime_end;
71+
return false;
72+
}
73+
}
6474
if (startsWith(str, "__") && endsWith(str, "_finite")) {
6575
str = str.substr(2, str.size() - 2 - 7);
6676
} else if (startsWith(str, "__fd_") && endsWith(str, "_1")) {

enzyme/Enzyme/Utils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2801,6 +2801,9 @@ getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, size_t valSz,
28012801
}
28022802

28032803
if (auto II = dyn_cast<IntrinsicInst>(U)) {
2804+
if (II->getCalledFunction()->getName() == "llvm.enzyme.lifetime_start" ||
2805+
II->getCalledFunction()->getName() == "llvm.enzyme.lifetime_end")
2806+
continue;
28042807
if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
28052808
II->getIntrinsicID() == Intrinsic::lifetime_end)
28062809
continue;

enzyme/Enzyme/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,10 @@ static inline llvm::Type *IntToFloatTy(llvm::Type *T) {
644644
static inline bool isDebugFunction(llvm::Function *called) {
645645
if (!called)
646646
return false;
647+
if (called->getName() == "llvm.enzyme.lifetime_start" ||
648+
called->getName() == "llvm.enzyme.lifetime_end") {
649+
return true;
650+
}
647651
switch (called->getIntrinsicID()) {
648652
case llvm::Intrinsic::dbg_declare:
649653
case llvm::Intrinsic::dbg_value:
@@ -1729,6 +1733,10 @@ static inline bool isNoAlias(const llvm::Value *val) {
17291733
static inline bool isNoEscapingAllocation(const llvm::Function *F) {
17301734
if (F->hasFnAttribute("enzyme_no_escaping_allocation"))
17311735
return true;
1736+
if (F->getName() == "llvm.enzyme.lifetime_start" ||
1737+
F->getName() == "llvm.enzyme.lifetime_end") {
1738+
return true;
1739+
}
17321740
using namespace llvm;
17331741
switch (F->getIntrinsicID()) {
17341742
case Intrinsic::memset:

0 commit comments

Comments
 (0)