@@ -813,18 +813,13 @@ static jl_cgval_t emit_llvmcall(jl_codectx_t &ctx, jl_value_t **args, size_t nar
813813    JL_TYPECHK (llvmcall, type, rt);
814814    JL_TYPECHK (llvmcall, type, at);
815815
816-     //  Generate arguments
817-     std::string arguments;
818-     raw_string_ostream argstream (arguments);
819-     jl_svec_t  *tt = ((jl_datatype_t *)at)->parameters ;
820-     jl_value_t  *rtt = rt;
816+     //  Determine argument types
817+     // 
818+     //  Semantics for arguments are as follows:
819+     //  If the argument type is immutable (including bitstype), we pass the loaded llvm value
820+     //  type. Otherwise we pass a pointer to a jl_value_t.
821+     jl_svec_t  *tt = ((jl_datatype_t  *)at)->parameters ;
821822    size_t  nargt = jl_svec_len (tt);
822- 
823-     /* 
824-      * Semantics for arguments are as follows: 
825-      * If the argument type is immutable (including bitstype), we pass the loaded llvm value 
826-      * type. Otherwise we pass a pointer to a jl_value_t. 
827-      */  
828823    SmallVector<llvm::Type*, 0 > argtypes;
829824    SmallVector<Value *, 8 > argvals (nargt);
830825    for  (size_t  i = 0 ; i < nargt; ++i) {
@@ -845,45 +840,87 @@ static jl_cgval_t emit_llvmcall(jl_codectx_t &ctx, jl_value_t **args, size_t nar
845840        argvals[i] = llvm_type_rewrite (ctx, v, t, issigned);
846841    }
847842
843+     //  Determine return type
844+     jl_value_t  *rtt = rt;
848845    bool  retboxed;
849846    Type *rettype = julia_type_to_llvm (ctx, rtt, &retboxed);
850847
851848    //  Make sure to find a unique name
852849    std::string ir_name;
853850    while  (true ) {
854-         raw_string_ostream (ir_name) << (ctx.f ->getName ().str ()) << " u" jl_atomic_fetch_add_relaxed (&globalUniqueGeneratedNames, 1 );
851+         raw_string_ostream (ir_name)
852+             << (ctx.f ->getName ().str ()) << " u" 
853+             << jl_atomic_fetch_add_relaxed (&globalUniqueGeneratedNames, 1 );
855854        if  (jl_Module->getFunction (ir_name) == NULL )
856855            break ;
857856    }
858857
859858    //  generate a temporary module that contains our IR
860859    std::unique_ptr<Module> Mod;
860+     Function *f;
861861    if  (entry == NULL ) {
862862        //  we only have function IR, which we should put in a function
863863
864-         bool  first = true ;
864+         //  stringify arguments
865+         std::string arguments;
866+         raw_string_ostream argstream (arguments);
865867        for  (SmallVector<Type *, 0 >::iterator it = argtypes.begin (); it != argtypes.end (); ++it) {
866-             if  (!first )
868+             if  (it != argtypes. begin () )
867869                argstream << " ," 
868-             else 
869-                 first = false ;
870870            (*it)->print (argstream);
871871            argstream << "  " 
872872        }
873873
874+         //  stringify return type
874875        std::string rstring;
875876        raw_string_ostream rtypename (rstring);
876877        rettype->print (rtypename);
877-         std::map<uint64_t ,std::string> localDecls;
878878
879+         //  generate IR function definition
879880        std::string ir_string;
880881        raw_string_ostream ir_stream (ir_string);
881-         ir_stream << " ; Number of arguments:  "   << nargt  << " \n  " 
882-         <<  " define  " <<rtypename. str ()<< "  @ \" "  << ir_name <<  " \" ( " << argstream.str ()<< " ) {\n " 
883-         << jl_string_data (ir) << " \n }" 
882+         ir_stream << " define  "  << rtypename. str () <<  "  @ \" "   << ir_name  << " \"  ( " 
883+                   <<  argstream.str () <<  " ) {\n " 
884+                    << jl_string_data (ir) << " \n }" 
884885
885886        SMDiagnostic Err = SMDiagnostic ();
886887        Mod = parseAssemblyString (ir_stream.str (), Err, ctx.builder .getContext ());
888+ 
889+         //  backwards compatibility: support for IR with integer pointers
890+         if  (!Mod) {
891+             std::string compat_arguments;
892+             raw_string_ostream compat_argstream (compat_arguments);
893+             for  (size_t  i = 0 ; i < nargt; ++i) {
894+                 if  (i > 0 )
895+                     compat_argstream << " ," 
896+                 jl_value_t  *tti = jl_svecref (tt, i);
897+                 Type *t;
898+                 if  (jl_is_cpointer_type (tti))
899+                     t = ctx.types ().T_size ;
900+                 else 
901+                     t = argtypes[i];
902+                 t->print (compat_argstream);
903+                 compat_argstream << "  " 
904+             }
905+ 
906+             std::string compat_rstring;
907+             raw_string_ostream compat_rtypename (compat_rstring);
908+             if  (jl_is_cpointer_type (rtt))
909+                 ctx.types ().T_size ->print (compat_rtypename);
910+             else 
911+                 rettype->print (compat_rtypename);
912+ 
913+             std::string compat_ir_string;
914+             raw_string_ostream compat_ir_stream (compat_ir_string);
915+             compat_ir_stream << " define " str () << "  @\" " 
916+                              << " \" (" str () << " ) {\n " 
917+                              << jl_string_data (ir) << " \n }" 
918+ 
919+             SMDiagnostic Err = SMDiagnostic ();
920+             Mod =
921+                 parseAssemblyString (compat_ir_stream.str (), Err, ctx.builder .getContext ());
922+         }
923+ 
887924        if  (!Mod) {
888925            std::string message = " Failed to parse LLVM assembly: \n " 
889926            raw_string_ostream stream (message);
@@ -893,7 +930,7 @@ static jl_cgval_t emit_llvmcall(jl_codectx_t &ctx, jl_value_t **args, size_t nar
893930            return  jl_cgval_t ();
894931        }
895932
896-         Function * f = Mod->getFunction (ir_name);
933+         f = Mod->getFunction (ir_name);
897934        f->addFnAttr (Attribute::AlwaysInline);
898935    }
899936    else  {
@@ -931,21 +968,88 @@ static jl_cgval_t emit_llvmcall(jl_codectx_t &ctx, jl_value_t **args, size_t nar
931968            Mod = std::move (ModuleOrErr.get ());
932969        }
933970
934-         Function * f = Mod->getFunction (jl_string_data (entry));
971+         f = Mod->getFunction (jl_string_data (entry));
935972        if  (!f) {
936973            emit_error (ctx, " Module IR does not contain specified entry function" 
937974            JL_GC_POP ();
938975            return  jl_cgval_t ();
939976        }
977+         assert (!f->isDeclaration ());
940978        f->setName (ir_name);
979+     }
941980
942-         //  verify the function type
943-         assert (!f->isDeclaration ());
944-         assert (f->getReturnType () == rettype);
945-         int  i = 0 ;
946-         for  (SmallVector<Type *, 0 >::iterator it = argtypes.begin ();
947-             it != argtypes.end (); ++it, ++i)
948-             assert (*it == f->getFunctionType ()->getParamType (i));
981+     //  backwards compatibility: support for IR with integer pointers
982+     bool  mismatched_pointers = false ;
983+     for  (size_t  i = 0 ; i < nargt; ++i) {
984+         jl_value_t  *tti = jl_svecref (tt, i);
985+         if  (jl_is_cpointer_type (tti) &&
986+             !f->getFunctionType ()->getParamType (i)->isPointerTy ()) {
987+             mismatched_pointers = true ;
988+             break ;
989+         }
990+     }
991+     if  (mismatched_pointers) {
992+         if  (jl_options.depwarn ) {
993+             if  (jl_options.depwarn  == JL_OPTIONS_DEPWARN_ERROR)
994+                 jl_error (" llvmcall with integer pointers is deprecated, " 
995+                          " use an actual pointer type instead." 
996+             jl_printf (JL_STDERR,
997+                       " WARNING: llvmcall with integer pointers is deprecated.\n " 
998+                       " Use actual pointers instead, replacing i32 or i64 with i8* or ptr" 
999+             if  (jl_lineno != 0 )
1000+                 jl_printf (JL_STDERR, " , likely near %s:%d" 
1001+             jl_printf (JL_STDERR, " \n " 
1002+         }
1003+ 
1004+         //  wrap the function, performing the necesary pointer conversion
1005+ 
1006+         Function *inner = f;
1007+         inner->setName (ir_name + " .inner" 
1008+ 
1009+         FunctionType *wrapper_ft = FunctionType::get (rettype, argtypes, false );
1010+         Function *wrapper =
1011+             Function::Create (wrapper_ft, inner->getLinkage (), ir_name, *Mod);
1012+ 
1013+         wrapper->copyAttributesFrom (inner);
1014+         inner->addFnAttr (Attribute::AlwaysInline);
1015+ 
1016+         BasicBlock *entry = BasicBlock::Create (ctx.builder .getContext (), " " 
1017+         IRBuilder<> irbuilder (entry);
1018+         SmallVector<Value *, 0 > wrapper_args;
1019+         for  (size_t  i = 0 ; i < nargt; ++i) {
1020+             jl_value_t  *tti = jl_svecref (tt, i);
1021+             Value *v = wrapper->getArg (i);
1022+             if  (jl_is_cpointer_type (tti))
1023+                 v = irbuilder.CreatePtrToInt (v, ctx.types ().T_size );
1024+             wrapper_args.push_back (v);
1025+         }
1026+         Value *call = irbuilder.CreateCall (inner, wrapper_args);
1027+         //  check if void
1028+         if  (rettype->isVoidTy ())
1029+             irbuilder.CreateRetVoid ();
1030+         else  {
1031+             if  (jl_is_cpointer_type (rtt))
1032+                 call = irbuilder.CreateIntToPtr (call, ctx.types ().T_ptr );
1033+             irbuilder.CreateRet (call);
1034+         }
1035+ 
1036+         f = wrapper;
1037+     }
1038+ 
1039+     //  verify the function type
1040+     assert (f->getReturnType () == rettype);
1041+     int  i = 0 ;
1042+     for  (SmallVector<Type *, 0 >::iterator it = argtypes.begin (); it != argtypes.end ();
1043+          ++it, ++i) {
1044+         if  (*it != f->getFunctionType ()->getParamType (i)) {
1045+             std::string message;
1046+             raw_string_ostream stream (message);
1047+             stream << " Malformed llvmcall: argument " 1  << "  type " 
1048+                    << *f->getFunctionType ()->getParamType (i)
1049+                    << "  does not match expected argument type " 
1050+             emit_error (ctx, stream.str ());
1051+             return  jl_cgval_t ();
1052+         }
9491053    }
9501054
9511055    //  copy module properties that should always match
@@ -983,7 +1087,7 @@ static jl_cgval_t emit_llvmcall(jl_codectx_t &ctx, jl_value_t **args, size_t nar
9831087    if  (inst->getType () != rettype) {
9841088        std::string message;
9851089        raw_string_ostream stream (message);
986-         stream << " llvmcall return type " getType ()
1090+         stream << " Malformed  llvmcall:  return type " getType ()
9871091               << "  does not match declared return type" 
9881092        emit_error (ctx, stream.str ());
9891093        return  jl_cgval_t ();
0 commit comments