Skip to content

Commit 2ed4acd

Browse files
committed
Handle external functions in other images
When we find an edge to an external function (already cached in a loaded pkgimage), we emit a global variable which we will patch during loading with the address of the function to call.
1 parent 381f5bd commit 2ed4acd

File tree

9 files changed

+160
-32
lines changed

9 files changed

+160
-32
lines changed

src/aotcompile.cpp

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ typedef struct {
9393
std::vector<GlobalValue*> jl_sysimg_gvars;
9494
std::map<jl_code_instance_t*, std::tuple<uint32_t, uint32_t>> jl_fvar_map;
9595
std::vector<void*> jl_value_to_llvm;
96+
std::vector<jl_code_instance_t*> jl_external_to_llvm;
9697
} jl_native_code_desc_t;
9798

9899
extern "C" JL_DLLEXPORT
@@ -118,6 +119,15 @@ void jl_get_llvm_gvs_impl(void *native_code, arraylist_t *gvs)
118119
memcpy(gvs->items, data->jl_value_to_llvm.data(), gvs->len * sizeof(void*));
119120
}
120121

122+
// Copies the code-instance pointers stored in the `jl_external_to_llvm` vector of
// the native code descriptor into the caller-supplied arraylist `external_fns`.
//   native_code:  opaque handle; it is cast to `jl_native_code_desc_t*` here.
//   external_fns: destination list; it is grown by the number of recorded entries
//                 and then filled via memcpy into `external_fns->items`.
// NOTE(review): the number of bytes copied is computed from `external_fns->len`,
// not from the vector's size, so this presumably expects the caller to pass an
// empty list -- confirm against the callers and the semantics of arraylist_grow.
extern "C" JL_DLLEXPORT
123+
void jl_get_llvm_external_fns_impl(void *native_code, arraylist_t *external_fns)
124+
{
125+
jl_native_code_desc_t *data = (jl_native_code_desc_t*)native_code;
126+
arraylist_grow(external_fns, data->jl_external_to_llvm.size());
127+
memcpy(external_fns->items, data->jl_external_to_llvm.data(),
128+
external_fns->len * sizeof(jl_code_instance_t*));
129+
}
130+
121131
extern "C" JL_DLLEXPORT
122132
LLVMOrcThreadSafeModuleRef jl_get_llvm_module_impl(void *native_code)
123133
{
@@ -251,10 +261,12 @@ static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance
251261
// takes the running content that has collected in the shadow module and dump it to disk
252262
// this builds the object file portion of the sysimage files for fast startup, and can
253263
// also be used by external consumers like GPUCompiler.jl to obtain a module containing
254-
// all reachable & inferrrable functions. The `policy` flag switches between the default
255-
// mode `0`, the extern mode `1`.
264+
// all reachable & inferable functions.
265+
// The `policy` flag switches between the default mode `0` and the extern mode `1` used by GPUCompiler.
266+
// `_imaging_mode` controls if raw pointers can be embedded (e.g. the code will be loaded into the same session).
267+
// `_external_linkage` creates linkages between pkgimages.
256268
extern "C" JL_DLLEXPORT
257-
void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int _policy, int _imaging_mode)
269+
void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int _policy, int _imaging_mode, int _external_linkage)
258270
{
259271
++CreateNativeCalls;
260272
CreateNativeMax.updateMax(jl_array_len(methods));
@@ -284,6 +296,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
284296
compiler_start_time = jl_hrtime();
285297

286298
params.imaging = imaging;
299+
params.external_linkage = _external_linkage;
287300

288301
// compile all methods for the current world and type-inference world
289302
size_t compile_for[] = { jl_typeinf_world, jl_atomic_load_acquire(&jl_world_counter) };
@@ -342,6 +355,46 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
342355
}
343356
CreateNativeMethods += emitted.size();
344357

358+
size_t offset = gvars.size();
359+
data->jl_external_to_llvm.resize(params.external_fns.size());
360+
361+
for (auto &extern_fn : params.external_fns) {
362+
jl_code_instance_t *this_code = std::get<0>(extern_fn.first);
363+
bool specsig = std::get<1>(extern_fn.first);
364+
assert(specsig && "Error external_fns doesn't handle non-specsig yet");
365+
Function *F = extern_fn.second;
366+
Module *M = F->getParent();
367+
368+
Type *T_funcp = F->getFunctionType()->getPointerTo();
369+
// Can't create a GV with type FunctionType. Alias also doesn't work
370+
GlobalVariable *GV = new GlobalVariable(*M, T_funcp, false,
371+
GlobalVariable::ExternalLinkage,
372+
Constant::getNullValue(T_funcp),
373+
F->getName());
374+
375+
376+
// Need to insert load instruction... can't RAUW
377+
for (Value *Use: F->users()) {
378+
if (auto CI = dyn_cast<CallInst>(Use)) {
379+
auto Callee = new LoadInst(T_funcp, GV, "", false, Align(1), CI); // TODO correct Align?
380+
CI->setCalledOperand(Callee);
381+
continue;
382+
} else {
383+
llvm::outs() << *Use << "\n";
384+
assert(false);
385+
}
386+
}
387+
388+
assert(F->getNumUses() == 0); // declaration counts as use
389+
GV->takeName(F);
390+
F->eraseFromParent();
391+
392+
size_t idx = gvars.size() - offset;
393+
assert(idx >= 0);
394+
data->jl_external_to_llvm.at(idx) = this_code;
395+
gvars.push_back(std::string(GV->getName()));
396+
}
397+
345398
// clones the contents of the module `m` to the shadow_output collector
346399
// while examining and recording what kind of function pointer we have
347400
for (auto &def : emitted) {

src/codegen-stubs.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ JL_DLLEXPORT void jl_dump_native_fallback(void *native_code,
1414
const char *bc_fname, const char *unopt_bc_fname, const char *obj_fname, const char *asm_fname,
1515
const char *sysimg_data, size_t sysimg_len) UNAVAILABLE
1616
JL_DLLEXPORT void jl_get_llvm_gvs_fallback(void *native_code, arraylist_t *gvs) UNAVAILABLE
17+
JL_DLLEXPORT void jl_get_llvm_external_fns_fallback(void *native_code, arraylist_t *gvs) UNAVAILABLE
1718

1819
JL_DLLEXPORT void jl_extern_c_fallback(jl_function_t *f, jl_value_t *rt, jl_value_t *argt, char *name) UNAVAILABLE
1920
JL_DLLEXPORT jl_value_t *jl_dump_method_asm_fallback(jl_method_instance_t *linfo, size_t world,
@@ -66,7 +67,7 @@ JL_DLLEXPORT size_t jl_jit_total_bytes_fallback(void)
6667
return 0;
6768
}
6869

69-
JL_DLLEXPORT void *jl_create_native_fallback(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int _policy, int _imaging_mode) UNAVAILABLE
70+
JL_DLLEXPORT void *jl_create_native_fallback(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int _policy, int _imaging_mode, int _external_linkage) UNAVAILABLE
7071

7172
JL_DLLEXPORT void jl_dump_compiles_fallback(void *s)
7273
{

src/codegen.cpp

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,7 @@ class jl_codectx_t {
14181418
jl_codegen_params_t &emission_context;
14191419
llvm::MapVector<jl_code_instance_t*, jl_codegen_call_target_t> call_targets;
14201420
std::map<void*, GlobalVariable*> &global_targets;
1421+
std::map<std::tuple<jl_code_instance_t*, bool>, Function*> &external_calls;
14211422
Function *f = NULL;
14221423
// local var info. globals are not in here.
14231424
std::vector<jl_varinfo_t> slots;
@@ -1454,6 +1455,7 @@ class jl_codectx_t {
14541455

14551456
bool debug_enabled = false;
14561457
bool use_cache = false;
1458+
bool external_linkage = false;
14571459
const jl_cgparams_t *params = NULL;
14581460

14591461
std::vector<orc::ThreadSafeModule> llvmcall_modules;
@@ -1463,8 +1465,10 @@ class jl_codectx_t {
14631465
emission_context(params),
14641466
call_targets(),
14651467
global_targets(params.globals),
1468+
external_calls(params.external_fns),
14661469
world(params.world),
14671470
use_cache(params.cache),
1471+
external_linkage(params.external_linkage),
14681472
params(params.params) { }
14691473

14701474
jl_typecache_t &types() {
@@ -4017,9 +4021,18 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
40174021
std::string name;
40184022
StringRef protoname;
40194023
bool need_to_emit = true;
4020-
// TODO: We should check if the code is available externally
4021-
// and then emit a trampoline.
4022-
if (ctx.use_cache) {
4024+
bool cache_valid = ctx.use_cache;
4025+
bool external = false;
4026+
if (ctx.external_linkage) {
4027+
if (jl_object_in_image((jl_value_t*)codeinst)) {
4028+
// Target is present in another pkgimage
4029+
jl_printf(JL_STDERR, "\n (emit_invoke:) Want to resolve method!\n");
4030+
cache_valid = true;
4031+
external = true;
4032+
}
4033+
}
4034+
4035+
if (cache_valid) {
40234036
// optimization: emit the correct name immediately, if we know it
40244037
// TODO: use `emitted` map here too to try to consolidate names?
40254038
auto invoke = jl_atomic_load_relaxed(&codeinst->invoke);
@@ -4046,6 +4059,13 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
40464059
result = emit_call_specfun_other(ctx, mi, codeinst->rettype, protoname, argv, nargs, &cc, &return_roots, rt);
40474060
else
40484061
result = emit_call_specfun_boxed(ctx, codeinst->rettype, protoname, argv, nargs, rt);
4062+
if (external) {
4063+
assert(!need_to_emit);
4064+
auto calledF = jl_Module->getFunction(protoname);
4065+
assert(calledF);
4066+
// TODO: Check if already present?
4067+
ctx.external_calls[std::make_tuple(codeinst, specsig)] = calledF;
4068+
}
40494069
handled = true;
40504070
if (need_to_emit) {
40514071
Function *trampoline_decl = cast<Function>(jl_Module->getNamedValue(protoname));
@@ -5365,7 +5385,17 @@ static Function *emit_tojlinvoke(jl_code_instance_t *codeinst, Module *M, jl_cod
53655385
Function *theFunc;
53665386
Value *theFarg;
53675387
auto invoke = jl_atomic_load_relaxed(&codeinst->invoke);
5368-
if (params.cache && invoke != NULL) {
5388+
5389+
bool cache_valid = params.cache;
5390+
if (params.external_linkage) {
5391+
if (jl_object_in_image((jl_value_t*)codeinst)) {
5392+
// Target is present in another pkgimage
5393+
jl_printf(JL_STDERR, "\n (emit_jlinvoke) Want to resolve method\n");
5394+
cache_valid = true;
5395+
}
5396+
}
5397+
5398+
if (cache_valid && invoke != NULL) {
53695399
StringRef theFptrName = jl_ExecutionEngine->getFunctionAtAddress((uintptr_t)invoke, codeinst);
53705400
theFunc = cast<Function>(
53715401
M->getOrInsertFunction(theFptrName, jlinvoke_func->_type(ctx.builder.getContext())).getCallee());
@@ -8262,11 +8292,11 @@ void jl_compile_workqueue(
82628292
StringRef preal_decl = "";
82638293
bool preal_specsig = false;
82648294
auto invoke = jl_atomic_load_relaxed(&codeinst->invoke);
8265-
// TODO: available_extern
8266-
// We need to emit a trampoline that loads the target address in an extern_module from a GV
8267-
// Right now we will unecessarily emit a function we have already compiled in a native module
8268-
// again in a calling module.
8269-
if (params.cache && invoke != NULL) {
8295+
bool cache_valid = params.cache;
8296+
if (params.external_linkage) {
8297+
cache_valid = jl_object_in_image((jl_value_t*)codeinst);
8298+
}
8299+
if (cache_valid && invoke != NULL) {
82708300
auto fptr = jl_atomic_load_relaxed(&codeinst->specptr.fptr);
82718301
if (invoke == jl_fptr_args_addr) {
82728302
preal_decl = jl_ExecutionEngine->getFunctionAtAddress((uintptr_t)fptr, codeinst);
@@ -8275,7 +8305,7 @@ void jl_compile_workqueue(
82758305
preal_decl = jl_ExecutionEngine->getFunctionAtAddress((uintptr_t)fptr, codeinst);
82768306
preal_specsig = true;
82778307
}
8278-
}
8308+
}
82798309
else {
82808310
auto &result = emitted[codeinst];
82818311
jl_llvm_functions_t *decls = NULL;

src/jitlayers.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ typedef struct _jl_codegen_params_t {
173173
// outputs
174174
std::vector<std::pair<jl_code_instance_t*, jl_codegen_call_target_t>> workqueue;
175175
std::map<void*, GlobalVariable*> globals;
176+
std::map<std::tuple<jl_code_instance_t*,bool>, Function*> external_fns;
176177
std::map<jl_datatype_t*, DIType*> ditypes;
177178
std::map<jl_datatype_t*, Type*> llvmtypes;
178179
DenseMap<Constant*, GlobalVariable*> mergedConstants;
@@ -200,6 +201,7 @@ typedef struct _jl_codegen_params_t {
200201
size_t world = 0;
201202
const jl_cgparams_t *params = &jl_default_cgparams;
202203
bool cache = false;
204+
bool external_linkage = false;
203205
bool imaging;
204206
_jl_codegen_params_t(orc::ThreadSafeContext ctx) : tsctx(std::move(ctx)), tsctx_lock(tsctx.getLock()), imaging(imaging_default()) {}
205207
} jl_codegen_params_t;

src/jl_exported_data.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@
127127
XX(jl_voidpointer_type) \
128128
XX(jl_void_type) \
129129
XX(jl_weakref_type) \
130+
XX(jl_build_ids) \
131+
XX(jl_linkage_blobs) \
130132

131133
// Data symbols that are defined inside the public libjulia
132134
#define JL_EXPORTED_DATA_SYMBOLS(XX) \

src/jl_exported_funcs.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@
535535
YY(jl_get_LLVM_VERSION) \
536536
YY(jl_dump_native) \
537537
YY(jl_get_llvm_gvs) \
538+
YY(jl_get_llvm_external_fns) \
538539
YY(jl_dump_function_asm) \
539540
YY(jl_LLVMCreateDisasm) \
540541
YY(jl_LLVMDisasmInstruction) \

src/julia_internal.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -970,11 +970,12 @@ JL_DLLEXPORT jl_value_t *jl_dump_fptr_asm(uint64_t fptr, char raw_mc, const char
970970
JL_DLLEXPORT jl_value_t *jl_dump_function_ir(jl_llvmf_dump_t *dump, char strip_ir_metadata, char dump_module, const char *debuginfo);
971971
JL_DLLEXPORT jl_value_t *jl_dump_function_asm(jl_llvmf_dump_t *dump, char raw_mc, const char* asm_variant, const char *debuginfo, char binary);
972972

973-
void *jl_create_native(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int policy, int imaging_mode);
973+
void *jl_create_native(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int policy, int imaging_mode, int cache);
974974
void jl_dump_native(void *native_code,
975975
const char *bc_fname, const char *unopt_bc_fname, const char *obj_fname, const char *asm_fname,
976976
const char *sysimg_data, size_t sysimg_len);
977977
void jl_get_llvm_gvs(void *native_code, arraylist_t *gvs);
978+
void jl_get_llvm_external_fns(void *native_code, arraylist_t *gvs);
978979
JL_DLLEXPORT void jl_get_function_id(void *native_code, jl_code_instance_t *ncode,
979980
int32_t *func_idx, int32_t *specfunc_idx);
980981

src/precompile.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ static int precompile_enq_all_specializations_(jl_methtable_t *mt, void *env)
343343
return jl_typemap_visitor(jl_atomic_load_relaxed(&mt->defs), precompile_enq_all_specializations__, env);
344344
}
345345

346-
static void *jl_precompile_(jl_array_t *m)
346+
static void *jl_precompile_(jl_array_t *m, int external_linkage)
347347
{
348348
jl_array_t *m2 = NULL;
349349
jl_method_instance_t *mi = NULL;
@@ -366,7 +366,7 @@ static void *jl_precompile_(jl_array_t *m)
366366
jl_array_ptr_1d_push(m2, item);
367367
}
368368
}
369-
void *native_code = jl_create_native(m2, NULL, NULL, 0, 1);
369+
void *native_code = jl_create_native(m2, NULL, NULL, 0, 1, external_linkage);
370370
JL_GC_POP();
371371
return native_code;
372372
}
@@ -379,7 +379,7 @@ static void *jl_precompile(int all)
379379
if (all)
380380
jl_compile_all_defs(m);
381381
jl_foreach_reachable_mtable(precompile_enq_all_specializations_, m);
382-
void *native_code = jl_precompile_(m);
382+
void *native_code = jl_precompile_(m, 0);
383383
JL_GC_POP();
384384
return native_code;
385385
}
@@ -398,7 +398,7 @@ static void *jl_precompile_worklist(jl_array_t *worklist)
398398
assert(jl_is_module(mod));
399399
foreach_mtable_in_module(mod, precompile_enq_all_specializations_, m);
400400
}
401-
void *native_code = jl_precompile_(m);
401+
void *native_code = jl_precompile_(m, 1);
402402
JL_GC_POP();
403403
return native_code;
404404
}

0 commit comments

Comments
 (0)