Skip to content

Commit 7509a6e

Browse files
committed
[CAPI] Handle command line args
1 parent ae17378 commit 7509a6e

File tree

2 files changed

+72
-4
lines changed

2 files changed

+72
-4
lines changed

enzyme/Enzyme/CApi.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,18 @@ FnTypeInfo eunwrap(CFnTypeInfo CTI, llvm::Function *F) {
140140
extern "C" {
141141

142142
void EnzymeSetCLBool(void *ptr, uint8_t val) {
143-
auto &cl = (llvm::cl::opt<bool> &)ptr;
144-
cl.setValue((bool)val);
143+
auto cl = (llvm::cl::opt<bool> *)ptr;
144+
cl->setValue((bool)val);
145+
}
146+
147+
uint8_t EnzymeGetCLBool(void *ptr) {
148+
auto cl = (llvm::cl::opt<bool> *)ptr;
149+
return (uint8_t)(bool)cl->getValue();
145150
}
146151

147152
void EnzymeSetCLInteger(void *ptr, int64_t val) {
148-
auto &cl = (llvm::cl::opt<int> &)ptr;
149-
cl.setValue((int)val);
153+
auto cl = (llvm::cl::opt<int> *)ptr;
154+
cl->setValue((int)val);
150155
}
151156

152157
EnzymeLogicRef CreateEnzymeLogic() {

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,69 @@ Function *PreProcessCache::CloneFunctionWithReturns(
15241524
return NewF;
15251525
}
15261526

1527+
void CoaleseTrivialMallocs(Function &F, DominatorTree &DT) {
1528+
std::map<BasicBlock *, std::vector<std::pair<CallInst *, CallInst *>>>
1529+
LegalMallocs;
1530+
for (BasicBlock &BB : F) {
1531+
for (Instruction &I : BB) {
1532+
if (auto CI = dyn_cast<CallInst>(&I)) {
1533+
if (auto F = CI->getCalledFunction()) {
1534+
if (F->getName() == "malloc") {
1535+
for (auto U : CI->users()) {
1536+
if (auto CI2 = dyn_cast<CallInst>(U)) {
1537+
if (auto F2 = CI2->getCalledFunction()) {
1538+
if (F2->getName() == "free") {
1539+
if (DT.dominates(CI, CI2)) {
1540+
LegalMallocs[&BB].emplace_back(CI, CI2);
1541+
}
1542+
}
1543+
}
1544+
}
1545+
}
1546+
}
1547+
}
1548+
}
1549+
}
1550+
}
1551+
for (auto &pair : LegalMallocs) {
1552+
if (pair.second.size() < 2)
1553+
continue;
1554+
CallInst *First = pair.second[0].first;
1555+
for (auto &z : pair.second) {
1556+
if (!DT.dominates(First, z.first))
1557+
First = z.first;
1558+
}
1559+
bool legal = true;
1560+
for (auto &z : pair.second) {
1561+
if (auto inst = dyn_cast<Instruction>(z.first->getArgOperand(0)))
1562+
if (!DT.dominates(inst, First))
1563+
legal = true;
1564+
}
1565+
if (!legal)
1566+
continue;
1567+
IRBuilder<> B(First);
1568+
Value *Size = First->getArgOperand(0);
1569+
for (auto &z : pair.second) {
1570+
if (z.first == First)
1571+
continue;
1572+
Size = B.CreateAdd(
1573+
B.CreateOr(B.CreateSub(Size, ConstantInt::get(Size->getType(), 1)),
1574+
ConstantInt::get(Size->getType(), 15)),
1575+
ConstantInt::get(Size->getType(), 1));
1576+
z.second->eraseFromParent();
1577+
IRBuilder B2(z.first);
1578+
z.first->replaceAllUsesWith(B2.CreateInBoundsGEP(First, Size));
1579+
Size = B.CreateAdd(Size, z.first->getArgOperand(0));
1580+
z.first->eraseFromParent();
1581+
}
1582+
auto NewMalloc =
1583+
cast<CallInst>(B.CreateCall(First->getCalledFunction(), Size));
1584+
NewMalloc->copyIRFlags(First);
1585+
First->replaceAllUsesWith(NewMalloc);
1586+
First->eraseFromParent();
1587+
}
1588+
}
1589+
15271590
void PreProcessCache::optimizeIntermediate(Function *F) {
15281591
PromotePass().run(*F, FAM);
15291592
#if LLVM_VERSION_MAJOR <= 7

0 commit comments

Comments
 (0)