@@ -1657,6 +1657,33 @@ LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
16571657 return RetVal;
16581658}
16591659
1660+ static bool shouldConvertToIndirectCall (bool IsVarArg, unsigned ParamCount,
1661+ NVPTXTargetLowering::ArgListTy &Args,
1662+ const CallBase *CB,
1663+ GlobalAddressSDNode *Func) {
1664+ if (!Func)
1665+ return false ;
1666+ auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal ());
1667+ if (!CalleeFunc)
1668+ return false ;
1669+
1670+ auto ActualReturnType = CalleeFunc->getReturnType ();
1671+ if (CB->getType () != ActualReturnType)
1672+ return true ;
1673+
1674+ if (IsVarArg)
1675+ return false ;
1676+
1677+ auto ActualNumParams = CalleeFunc->getFunctionType ()->getNumParams ();
1678+ if (ParamCount != ActualNumParams)
1679+ return true ;
1680+ for (const Argument &I : CalleeFunc->args ())
1681+ if (I.getType () != Args[I.getArgNo ()].Ty )
1682+ return true ;
1683+
1684+ return false ;
1685+ }
1686+
16601687SDValue NVPTXTargetLowering::LowerCall (TargetLowering::CallLoweringInfo &CLI,
16611688 SmallVectorImpl<SDValue> &InVals) const {
16621689
@@ -1971,10 +1998,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
19711998 VADeclareParam->getVTList (), DeclareParamOps);
19721999 }
19732000
2001+ // If the param count, type of any param, or return type of the callsite
2002+ // mismatches with that of the function signature, convert the callsite to an
2003+ // indirect call.
2004+ bool ConvertToIndirectCall =
2005+ shouldConvertToIndirectCall (CLI.IsVarArg , ParamCount, Args, CB, Func);
2006+
19742007 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
19752008 // between them we must rely on the call site value which is valid for
19762009 // indirect calls but is always null for libcalls.
1977- bool isIndirectCall = !Func && CB;
2010+ bool isIndirectCall = ( !Func && CB) || ConvertToIndirectCall ;
19782011
19792012 if (isa<ExternalSymbolSDNode>(Callee)) {
19802013 Function* CalleeFunc = nullptr ;
@@ -2026,6 +2059,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
20262059 Chain = DAG.getNode (Opcode, dl, PrintCallVTs, PrintCallOps);
20272060 InGlue = Chain.getValue (1 );
20282061
2062+ if (ConvertToIndirectCall) {
2063+ // Copy the function ptr to a ptx register and use the register to call the
2064+ // function.
2065+ EVT DestVT = Callee.getValueType ();
2066+ MachineRegisterInfo &RegInfo = DAG.getMachineFunction ().getRegInfo ();
2067+ const TargetLowering &TLI = DAG.getTargetLoweringInfo ();
2068+ unsigned DestReg =
2069+ RegInfo.createVirtualRegister (TLI.getRegClassFor (DestVT.getSimpleVT ()));
2070+ auto RegCopy = DAG.getCopyToReg (DAG.getEntryNode (), dl, DestReg, Callee);
2071+ Callee = DAG.getCopyFromReg (RegCopy, dl, DestReg, DestVT);
2072+ }
2073+
20292074 // Ops to print out the function name
20302075 SDVTList CallVoidVTs = DAG.getVTList (MVT::Other, MVT::Glue);
20312076 SDValue CallVoidOps[] = { Chain, Callee, InGlue };
0 commit comments