Skip to content

Commit 6798c48

Browse files
committed
implement call inst
1 parent b2de42d commit 6798c48

File tree

1 file changed

+122
-8
lines changed

1 file changed

+122
-8
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 122 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6535,10 +6535,130 @@ class AdjointGenerator
65356535
return;
65366536
}
65376537

6538-
bool modifyPrimal = shouldAugmentCall(orig, gutils, TR);
6539-
65406538
bool foreignFunction = called == nullptr || called->empty();
65416539

6540+
FnTypeInfo nextTypeInfo(called);
6541+
6542+
if (called) {
6543+
nextTypeInfo = TR.getCallInfo(*orig, *called);
6544+
}
6545+
6546+
if (Mode == DerivativeMode::ForwardMode) {
6547+
IRBuilder<> Builder2(&call);
6548+
getForwardBuilder(Builder2);
6549+
6550+
bool retUsed = subretused;
6551+
6552+
SmallVector<Value *, 8> args;
6553+
SmallVector<Value *, 8> pre_args;
6554+
std::vector<DIFFE_TYPE> argsInverted;
6555+
std::vector<Instruction *> postCreate;
6556+
std::vector<Instruction *> userReplace;
6557+
std::map<int, Type *> preByVal;
6558+
std::map<int, Type *> gradByVal;
6559+
6560+
for (unsigned i = 0; i < orig->getNumArgOperands(); ++i) {
6561+
6562+
auto argi = gutils->getNewFromOriginal(orig->getArgOperand(i));
6563+
6564+
#if LLVM_VERSION_MAJOR >= 9
6565+
if (orig->isByValArgument(i)) {
6566+
preByVal[pre_args.size()] = orig->getParamByValType(i);
6567+
}
6568+
#endif
6569+
6570+
pre_args.push_back(argi);
6571+
6572+
#if LLVM_VERSION_MAJOR >= 9
6573+
if (orig->isByValArgument(i)) {
6574+
gradByVal[args.size()] = orig->getParamByValType(i);
6575+
}
6576+
#endif
6577+
args.push_back(lookup(argi, Builder2));
6578+
6579+
if (gutils->isConstantValue(orig->getArgOperand(i)) &&
6580+
!foreignFunction) {
6581+
argsInverted.push_back(DIFFE_TYPE::CONSTANT);
6582+
continue;
6583+
}
6584+
6585+
auto argType = argi->getType();
6586+
6587+
if (!argType->isFPOrFPVectorTy() &&
6588+
(TR.query(orig->getArgOperand(i)).Inner0().isPossiblePointer() ||
6589+
foreignFunction)) {
6590+
DIFFE_TYPE ty = DIFFE_TYPE::DUP_ARG;
6591+
if (argType->isPointerTy()) {
6592+
#if LLVM_VERSION_MAJOR >= 12
6593+
auto at = getUnderlyingObject(orig->getArgOperand(i), 100);
6594+
#else
6595+
auto at = GetUnderlyingObject(
6596+
orig->getArgOperand(i),
6597+
gutils->oldFunc->getParent()->getDataLayout(), 100);
6598+
#endif
6599+
if (auto arg = dyn_cast<Argument>(at)) {
6600+
if (constant_args[arg->getArgNo()] == DIFFE_TYPE::DUP_NONEED) {
6601+
ty = DIFFE_TYPE::DUP_NONEED;
6602+
}
6603+
}
6604+
}
6605+
argsInverted.push_back(ty);
6606+
6607+
if (Mode != DerivativeMode::ReverseModePrimal) {
6608+
IRBuilder<> Builder2(call.getParent());
6609+
getReverseBuilder(Builder2);
6610+
args.push_back(
6611+
gutils->invertPointerM(orig->getArgOperand(i), Builder2));
6612+
}
6613+
pre_args.push_back(
6614+
gutils->invertPointerM(orig->getArgOperand(i), BuilderZ));
6615+
6616+
// Note sometimes whattype mistakenly says something should be
6617+
// constant [because composed of integer pointers alone]
6618+
assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG ||
6619+
whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
6620+
} else {
6621+
if (foreignFunction)
6622+
assert(!argType->isIntOrIntVectorTy());
6623+
6624+
args.push_back(diffe(orig->getArgOperand(i), Builder2));
6625+
pre_args.push_back(diffe(orig->getArgOperand(i), BuilderZ));
6626+
6627+
argsInverted.push_back(DIFFE_TYPE::DUP_ARG);
6628+
}
6629+
}
6630+
6631+
auto newcalled = gutils->Logic.CreatePrimalAndGradient(
6632+
cast<Function>(called), subretType, argsInverted, gutils->TLI,
6633+
TR.analyzer.interprocedural, /*returnValue*/ retUsed,
6634+
/*subdretptr*/ false, DerivativeMode::ForwardMode, nullptr,
6635+
nextTypeInfo, uncacheable_args, nullptr,
6636+
/*AtomicAdd*/ gutils->AtomicAdd);
6637+
6638+
assert(newcalled);
6639+
FunctionType *FT = cast<FunctionType>(
6640+
cast<PointerType>(newcalled->getType())->getElementType());
6641+
6642+
CallInst *diffes = Builder2.CreateCall(FT, newcalled, args);
6643+
diffes->setCallingConv(orig->getCallingConv());
6644+
diffes->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc()));
6645+
#if LLVM_VERSION_MAJOR >= 9
6646+
for (auto pair : gradByVal) {
6647+
diffes->addParamAttr(
6648+
pair.first,
6649+
Attribute::getWithByValType(diffes->getContext(), pair.second));
6650+
}
6651+
#endif
6652+
6653+
unsigned structidx = retUsed ? 1 : 0;
6654+
Value *diffe = Builder2.CreateExtractValue(diffes, {structidx});
6655+
setDiffe(&call, diffe, Builder2);
6656+
6657+
return;
6658+
}
6659+
6660+
bool modifyPrimal = shouldAugmentCall(orig, gutils, TR);
6661+
65426662
SmallVector<Value *, 8> args;
65436663
SmallVector<Value *, 8> pre_args;
65446664
std::vector<DIFFE_TYPE> argsInverted;
@@ -6644,12 +6764,6 @@ class AdjointGenerator
66446764
CallInst *augmentcall = nullptr;
66456765
Value *cachereplace = nullptr;
66466766

6647-
FnTypeInfo nextTypeInfo(called);
6648-
6649-
if (called) {
6650-
nextTypeInfo = TR.getCallInfo(*orig, *called);
6651-
}
6652-
66536767
// llvm::Optional<std::map<std::pair<Instruction*, std::string>,
66546768
// unsigned>> sub_index_map;
66556769
Optional<int> tapeIdx;

0 commit comments

Comments
 (0)