@@ -6535,10 +6535,130 @@ class AdjointGenerator
65356535 return ;
65366536 }
65376537
6538- bool modifyPrimal = shouldAugmentCall (orig, gutils, TR);
6539-
65406538 bool foreignFunction = called == nullptr || called->empty ();
65416539
6540+ FnTypeInfo nextTypeInfo (called);
6541+
6542+ if (called) {
6543+ nextTypeInfo = TR.getCallInfo (*orig, *called);
6544+ }
6545+
6546+ if (Mode == DerivativeMode::ForwardMode) {
6547+ IRBuilder<> Builder2 (&call);
6548+ getForwardBuilder (Builder2);
6549+
6550+ bool retUsed = subretused;
6551+
6552+ SmallVector<Value *, 8 > args;
6553+ SmallVector<Value *, 8 > pre_args;
6554+ std::vector<DIFFE_TYPE> argsInverted;
6555+ std::vector<Instruction *> postCreate;
6556+ std::vector<Instruction *> userReplace;
6557+ std::map<int , Type *> preByVal;
6558+ std::map<int , Type *> gradByVal;
6559+
6560+ for (unsigned i = 0 ; i < orig->getNumArgOperands (); ++i) {
6561+
6562+ auto argi = gutils->getNewFromOriginal (orig->getArgOperand (i));
6563+
6564+ #if LLVM_VERSION_MAJOR >= 9
6565+ if (orig->isByValArgument (i)) {
6566+ preByVal[pre_args.size ()] = orig->getParamByValType (i);
6567+ }
6568+ #endif
6569+
6570+ pre_args.push_back (argi);
6571+
6572+ #if LLVM_VERSION_MAJOR >= 9
6573+ if (orig->isByValArgument (i)) {
6574+ gradByVal[args.size ()] = orig->getParamByValType (i);
6575+ }
6576+ #endif
6577+ args.push_back (lookup (argi, Builder2));
6578+
6579+ if (gutils->isConstantValue (orig->getArgOperand (i)) &&
6580+ !foreignFunction) {
6581+ argsInverted.push_back (DIFFE_TYPE::CONSTANT);
6582+ continue ;
6583+ }
6584+
6585+ auto argType = argi->getType ();
6586+
6587+ if (!argType->isFPOrFPVectorTy () &&
6588+ (TR.query (orig->getArgOperand (i)).Inner0 ().isPossiblePointer () ||
6589+ foreignFunction)) {
6590+ DIFFE_TYPE ty = DIFFE_TYPE::DUP_ARG;
6591+ if (argType->isPointerTy ()) {
6592+ #if LLVM_VERSION_MAJOR >= 12
6593+ auto at = getUnderlyingObject (orig->getArgOperand (i), 100 );
6594+ #else
6595+ auto at = GetUnderlyingObject (
6596+ orig->getArgOperand (i),
6597+ gutils->oldFunc ->getParent ()->getDataLayout (), 100 );
6598+ #endif
6599+ if (auto arg = dyn_cast<Argument>(at)) {
6600+ if (constant_args[arg->getArgNo ()] == DIFFE_TYPE::DUP_NONEED) {
6601+ ty = DIFFE_TYPE::DUP_NONEED;
6602+ }
6603+ }
6604+ }
6605+ argsInverted.push_back (ty);
6606+
6607+ if (Mode != DerivativeMode::ReverseModePrimal) {
6608+ IRBuilder<> Builder2 (call.getParent ());
6609+ getReverseBuilder (Builder2);
6610+ args.push_back (
6611+ gutils->invertPointerM (orig->getArgOperand (i), Builder2));
6612+ }
6613+ pre_args.push_back (
6614+ gutils->invertPointerM (orig->getArgOperand (i), BuilderZ));
6615+
6616+ // Note sometimes whattype mistakenly says something should be
6617+ // constant [because composed of integer pointers alone]
6618+ assert (whatType (argType, Mode) == DIFFE_TYPE::DUP_ARG ||
6619+ whatType (argType, Mode) == DIFFE_TYPE::CONSTANT);
6620+ } else {
6621+ if (foreignFunction)
6622+ assert (!argType->isIntOrIntVectorTy ());
6623+
6624+ args.push_back (diffe (orig->getArgOperand (i), Builder2));
6625+ pre_args.push_back (diffe (orig->getArgOperand (i), BuilderZ));
6626+
6627+ argsInverted.push_back (DIFFE_TYPE::DUP_ARG);
6628+ }
6629+ }
6630+
6631+ auto newcalled = gutils->Logic .CreatePrimalAndGradient (
6632+ cast<Function>(called), subretType, argsInverted, gutils->TLI ,
6633+ TR.analyzer .interprocedural , /* returnValue*/ retUsed,
6634+ /* subdretptr*/ false , DerivativeMode::ForwardMode, nullptr ,
6635+ nextTypeInfo, uncacheable_args, nullptr ,
6636+ /* AtomicAdd*/ gutils->AtomicAdd );
6637+
6638+ assert (newcalled);
6639+ FunctionType *FT = cast<FunctionType>(
6640+ cast<PointerType>(newcalled->getType ())->getElementType ());
6641+
6642+ CallInst *diffes = Builder2.CreateCall (FT, newcalled, args);
6643+ diffes->setCallingConv (orig->getCallingConv ());
6644+ diffes->setDebugLoc (gutils->getNewFromOriginal (orig->getDebugLoc ()));
6645+ #if LLVM_VERSION_MAJOR >= 9
6646+ for (auto pair : gradByVal) {
6647+ diffes->addParamAttr (
6648+ pair.first ,
6649+ Attribute::getWithByValType (diffes->getContext (), pair.second ));
6650+ }
6651+ #endif
6652+
6653+ unsigned structidx = retUsed ? 1 : 0 ;
6654+ Value *diffe = Builder2.CreateExtractValue (diffes, {structidx});
6655+ setDiffe (&call, diffe, Builder2);
6656+
6657+ return ;
6658+ }
6659+
6660+ bool modifyPrimal = shouldAugmentCall (orig, gutils, TR);
6661+
65426662 SmallVector<Value *, 8 > args;
65436663 SmallVector<Value *, 8 > pre_args;
65446664 std::vector<DIFFE_TYPE> argsInverted;
@@ -6644,12 +6764,6 @@ class AdjointGenerator
66446764 CallInst *augmentcall = nullptr ;
66456765 Value *cachereplace = nullptr ;
66466766
6647- FnTypeInfo nextTypeInfo (called);
6648-
6649- if (called) {
6650- nextTypeInfo = TR.getCallInfo (*orig, *called);
6651- }
6652-
66536767 // llvm::Optional<std::map<std::pair<Instruction*, std::string>,
66546768 // unsigned>> sub_index_map;
66556769 Optional<int > tapeIdx;
0 commit comments