Skip to content

Commit aa05c98

Browse files
authored
precompile: do better union split and concrete compilation search (#56496)
This fixes some bugs that prevent compile-all from working correctly at all, and uses more of it for normal compile. Increases sysimg size from about 140 to 170 MB of data and 11 to 15 MB of code
2 parents 072d9d1 + 882f940 commit aa05c98

File tree

5 files changed

+126
-90
lines changed

5 files changed

+126
-90
lines changed

src/aotcompile.cpp

+34-30
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,8 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
524524
// Const returns do not do codegen, but juliac inspects codegen results so make a dummy fvar entry to represent it
525525
if (jl_options.trim != JL_TRIM_NO && jl_atomic_load_relaxed(&codeinst->invoke) == jl_fptr_const_return_addr) {
526526
data->jl_fvar_map[codeinst] = std::make_tuple((uint32_t)-3, (uint32_t)-3);
527-
} else {
527+
}
528+
else {
528529
JL_GC_PROMISE_ROOTED(codeinst->rettype);
529530
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(codeinst->def),
530531
params.tsctx, clone.getModuleUnlocked()->getDataLayout(),
@@ -609,6 +610,9 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
609610
else if (func == "jl_fptr_sparam") {
610611
func_id = -2;
611612
}
613+
else if (decls.functionObject == "jl_f_opaque_closure_call") {
614+
func_id = -4;
615+
}
612616
else {
613617
//Safe b/c context is locked by params
614618
data->jl_sysimg_fvars.push_back(cast<Function>(clone.getModuleUnlocked()->getNamedValue(func)));
@@ -896,19 +900,18 @@ struct Partition {
896900
size_t weight;
897901
};
898902

899-
static bool canPartition(const GlobalValue &G) {
900-
if (auto F = dyn_cast<Function>(&G)) {
901-
if (F->hasFnAttribute(Attribute::AlwaysInline))
902-
return false;
903-
}
904-
return true;
903+
static bool canPartition(const Function &F)
904+
{
905+
return !F.hasFnAttribute(Attribute::AlwaysInline);
905906
}
906907

907-
static inline bool verify_partitioning(const SmallVectorImpl<Partition> &partitions, const Module &M, size_t fvars_size, size_t gvars_size) {
908+
static inline bool verify_partitioning(const SmallVectorImpl<Partition> &partitions, const Module &M, DenseMap<GlobalValue *, unsigned> &fvars, DenseMap<GlobalValue *, unsigned> &gvars) {
908909
bool bad = false;
909910
#ifndef JL_NDEBUG
910-
SmallVector<uint32_t, 0> fvars(fvars_size);
911-
SmallVector<uint32_t, 0> gvars(gvars_size);
911+
size_t fvars_size = fvars.size();
912+
size_t gvars_size = gvars.size();
913+
SmallVector<uint32_t, 0> fvars_partition(fvars_size);
914+
SmallVector<uint32_t, 0> gvars_partition(gvars_size);
912915
StringMap<uint32_t> GVNames;
913916
for (uint32_t i = 0; i < partitions.size(); i++) {
914917
for (auto &name : partitions[i].globals) {
@@ -919,18 +922,18 @@ static inline bool verify_partitioning(const SmallVectorImpl<Partition> &partiti
919922
GVNames[name.getKey()] = i;
920923
}
921924
for (auto &fvar : partitions[i].fvars) {
922-
if (fvars[fvar.second] != 0) {
925+
if (fvars_partition[fvar.second] != 0) {
923926
bad = true;
924-
dbgs() << "Duplicate fvar " << fvar.first() << " in partitions " << i << " and " << fvars[fvar.second] - 1 << "\n";
927+
dbgs() << "Duplicate fvar " << fvar.first() << " in partitions " << i << " and " << fvars_partition[fvar.second] - 1 << "\n";
925928
}
926-
fvars[fvar.second] = i+1;
929+
fvars_partition[fvar.second] = i+1;
927930
}
928931
for (auto &gvar : partitions[i].gvars) {
929-
if (gvars[gvar.second] != 0) {
932+
if (gvars_partition[gvar.second] != 0) {
930933
bad = true;
931-
dbgs() << "Duplicate gvar " << gvar.first() << " in partitions " << i << " and " << gvars[gvar.second] - 1 << "\n";
934+
dbgs() << "Duplicate gvar " << gvar.first() << " in partitions " << i << " and " << gvars_partition[gvar.second] - 1 << "\n";
932935
}
933-
gvars[gvar.second] = i+1;
936+
gvars_partition[gvar.second] = i+1;
934937
}
935938
}
936939
for (auto &GV : M.global_values()) {
@@ -941,13 +944,6 @@ static inline bool verify_partitioning(const SmallVectorImpl<Partition> &partiti
941944
}
942945
} else {
943946
// Local global values are not partitioned
944-
if (!canPartition(GV)) {
945-
if (GVNames.count(GV.getName())) {
946-
bad = true;
947-
dbgs() << "Shouldn't have partitioned " << GV.getName() << ", but is in partition " << GVNames[GV.getName()] << "\n";
948-
}
949-
continue;
950-
}
951947
if (!GVNames.count(GV.getName())) {
952948
bad = true;
953949
dbgs() << "Global " << GV << " not in any partition\n";
@@ -967,13 +963,14 @@ static inline bool verify_partitioning(const SmallVectorImpl<Partition> &partiti
967963
}
968964
}
969965
for (uint32_t i = 0; i < fvars_size; i++) {
970-
if (fvars[i] == 0) {
966+
if (fvars_partition[i] == 0) {
967+
auto gv = find_if(fvars.begin(), fvars.end(), [i](auto var) { return var.second == i; });
971968
bad = true;
972-
dbgs() << "fvar " << i << " not in any partition\n";
969+
dbgs() << "fvar " << gv->first->getName() << " at " << i << " not in any partition\n";
973970
}
974971
}
975972
for (uint32_t i = 0; i < gvars_size; i++) {
976-
if (gvars[i] == 0) {
973+
if (gvars_partition[i] == 0) {
977974
bad = true;
978975
dbgs() << "gvar " << i << " not in any partition\n";
979976
}
@@ -1035,8 +1032,6 @@ static SmallVector<Partition, 32> partitionModule(Module &M, unsigned threads) {
10351032
for (auto &G : M.global_values()) {
10361033
if (G.isDeclaration())
10371034
continue;
1038-
if (!canPartition(G))
1039-
continue;
10401035
// Currently ccallable global aliases have extern linkage, we only want to make the
10411036
// internally linked functions/global variables extern+hidden
10421037
if (G.hasLocalLinkage()) {
@@ -1045,7 +1040,8 @@ static SmallVector<Partition, 32> partitionModule(Module &M, unsigned threads) {
10451040
}
10461041
if (auto F = dyn_cast<Function>(&G)) {
10471042
partitioner.make(&G, getFunctionWeight(*F).weight);
1048-
} else {
1043+
}
1044+
else {
10491045
partitioner.make(&G, 1);
10501046
}
10511047
}
@@ -1117,7 +1113,9 @@ static SmallVector<Partition, 32> partitionModule(Module &M, unsigned threads) {
11171113
}
11181114
}
11191115

1120-
bool verified = verify_partitioning(partitions, M, fvars.size(), gvars.size());
1116+
bool verified = verify_partitioning(partitions, M, fvars, gvars);
1117+
if (!verified)
1118+
M.dump();
11211119
assert(verified && "Partitioning failed to partition globals correctly");
11221120
(void) verified;
11231121

@@ -1371,6 +1369,12 @@ static void materializePreserved(Module &M, Partition &partition) {
13711369
continue;
13721370
if (Preserve.contains(&F))
13731371
continue;
1372+
if (!canPartition(F)) {
1373+
F.setLinkage(GlobalValue::AvailableExternallyLinkage);
1374+
F.setVisibility(GlobalValue::HiddenVisibility);
1375+
F.setDSOLocal(true);
1376+
continue;
1377+
}
13741378
F.deleteBody();
13751379
F.setLinkage(GlobalValue::ExternalLinkage);
13761380
F.setVisibility(GlobalValue::HiddenVisibility);

src/gf.c

+7-1
Original file line numberDiff line numberDiff line change
@@ -3188,6 +3188,12 @@ JL_DLLEXPORT void jl_compile_method_instance(jl_method_instance_t *mi, jl_tuplet
31883188
}
31893189
}
31903190

3191+
JL_DLLEXPORT void jl_compile_method_sig(jl_method_t *m, jl_value_t *types, jl_svec_t *env, size_t world)
3192+
{
3193+
jl_method_instance_t *mi = jl_specializations_get_linfo(m, types, env);
3194+
jl_compile_method_instance(mi, NULL, world);
3195+
}
3196+
31913197
JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types)
31923198
{
31933199
size_t world = jl_atomic_load_acquire(&jl_world_counter);
@@ -3197,7 +3203,7 @@ JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types)
31973203
if (mi == NULL)
31983204
return 0;
31993205
JL_GC_PROMISE_ROOTED(mi);
3200-
jl_compile_method_instance(mi, types, world);
3206+
jl_compile_method_instance(mi, NULL, world);
32013207
return 1;
32023208
}
32033209

src/julia_internal.h

+1
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,7 @@ JL_DLLEXPORT jl_module_t *jl_debuginfo_module1(jl_value_t *debuginfo_def) JL_NOT
695695
JL_DLLEXPORT const char *jl_debuginfo_name(jl_value_t *func) JL_NOTSAFEPOINT;
696696

697697
JL_DLLEXPORT void jl_compile_method_instance(jl_method_instance_t *mi, jl_tupletype_t *types, size_t world);
698+
JL_DLLEXPORT void jl_compile_method_sig(jl_method_t *m, jl_value_t *types, jl_svec_t *sparams, size_t world);
698699
JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types);
699700
JL_DLLEXPORT int jl_add_entrypoint(jl_tupletype_t *types);
700701
jl_code_info_t *jl_code_for_interpreter(jl_method_instance_t *lam JL_PROPAGATES_ROOT, size_t world);

src/precompile_utils.c

+72-58
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
// f{<:Union{...}}(...) is a common pattern
2-
// and expanding the Union may give a leaf function
3-
static void _compile_all_tvar_union(jl_value_t *methsig)
1+
// This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
// f(...) where {T<:Union{...}} is a common pattern
4+
// and expanding the Union may give some leaf functions
5+
static int _compile_all_tvar_union(jl_value_t *methsig)
46
{
57
int tvarslen = jl_subtype_env_size(methsig);
68
jl_value_t *sigbody = methsig;
@@ -13,86 +15,94 @@ static void _compile_all_tvar_union(jl_value_t *methsig)
1315
assert(jl_is_unionall(sigbody));
1416
idx[i] = 0;
1517
env[2 * i] = (jl_value_t*)((jl_unionall_t*)sigbody)->var;
16-
env[2 * i + 1] = jl_bottom_type; // initialize the list with Union{}, since T<:Union{} is always a valid option
18+
jl_value_t *tv = env[2 * i];
19+
while (jl_is_typevar(tv))
20+
tv = ((jl_tvar_t*)tv)->ub;
21+
if (jl_is_abstracttype(tv) && !jl_is_type_type(tv)) {
22+
JL_GC_POP();
23+
return 0; // Any as TypeVar is common and not useful here to try to analyze further
24+
}
25+
env[2 * i + 1] = tv;
1726
sigbody = ((jl_unionall_t*)sigbody)->body;
1827
}
1928

20-
for (i = 0; i < tvarslen; /* incremented by inner loop */) {
21-
jl_value_t **sig = &roots[0];
29+
int all = 1;
30+
int incr = 0;
31+
while (!incr) {
32+
for (i = 0, incr = 1; i < tvarslen; i++) {
33+
jl_value_t *tv = env[2 * i];
34+
while (jl_is_typevar(tv))
35+
tv = ((jl_tvar_t*)tv)->ub;
36+
if (jl_is_uniontype(tv)) {
37+
size_t l = jl_count_union_components(tv);
38+
size_t j = idx[i];
39+
env[2 * i + 1] = jl_nth_union_component(tv, j);
40+
++j;
41+
if (incr) {
42+
if (j == l) {
43+
idx[i] = 0;
44+
}
45+
else {
46+
idx[i] = j;
47+
incr = 0;
48+
}
49+
}
50+
}
51+
}
52+
jl_value_t *sig = NULL;
2253
JL_TRY {
2354
// TODO: wrap in UnionAll for each tvar in env[2*i + 1] ?
2455
// currently doesn't matter much, since jl_compile_hint doesn't work on abstract types
25-
*sig = (jl_value_t*)jl_instantiate_type_with(sigbody, env, tvarslen);
56+
sig = (jl_value_t*)jl_instantiate_type_with(sigbody, env, tvarslen);
2657
}
2758
JL_CATCH {
28-
goto getnext; // sigh, we found an invalid type signature. should we warn the user?
29-
}
30-
if (!jl_has_concrete_subtype(*sig))
31-
goto getnext; // signature wouldn't be callable / is invalid -- skip it
32-
if (jl_is_concrete_type(*sig)) {
33-
if (jl_compile_hint((jl_tupletype_t *)*sig))
34-
goto getnext; // success
59+
sig = NULL;
3560
}
36-
37-
getnext:
38-
for (i = 0; i < tvarslen; i++) {
39-
jl_tvar_t *tv = (jl_tvar_t*)env[2 * i];
40-
if (jl_is_uniontype(tv->ub)) {
41-
size_t l = jl_count_union_components(tv->ub);
42-
size_t j = idx[i];
43-
if (j == l) {
44-
env[2 * i + 1] = jl_bottom_type;
45-
idx[i] = 0;
46-
}
47-
else {
48-
jl_value_t *ty = jl_nth_union_component(tv->ub, j);
49-
if (!jl_is_concrete_type(ty))
50-
ty = (jl_value_t*)jl_new_typevar(tv->name, tv->lb, ty);
51-
env[2 * i + 1] = ty;
52-
idx[i] = j + 1;
53-
break;
54-
}
55-
}
56-
else {
57-
env[2 * i + 1] = (jl_value_t*)tv;
58-
}
61+
if (sig) {
62+
roots[0] = sig;
63+
if (jl_is_datatype(sig) && jl_has_concrete_subtype(sig))
64+
all = all && jl_compile_hint((jl_tupletype_t*)sig);
65+
else
66+
all = 0;
5967
}
6068
}
6169
JL_GC_POP();
70+
return all;
6271
}
6372

6473
// f(::Union{...}, ...) is a common pattern
6574
// and expanding the Union may give a leaf function
66-
static void _compile_all_union(jl_value_t *sig)
75+
static int _compile_all_union(jl_value_t *sig)
6776
{
6877
jl_tupletype_t *sigbody = (jl_tupletype_t*)jl_unwrap_unionall(sig);
6978
size_t count_unions = 0;
79+
size_t union_size = 1;
7080
size_t i, l = jl_svec_len(sigbody->parameters);
7181
jl_svec_t *p = NULL;
7282
jl_value_t *methsig = NULL;
7383

7484
for (i = 0; i < l; i++) {
7585
jl_value_t *ty = jl_svecref(sigbody->parameters, i);
76-
if (jl_is_uniontype(ty))
77-
++count_unions;
78-
else if (ty == jl_bottom_type)
79-
return; // why does this method exist?
80-
else if (jl_is_datatype(ty) && !jl_has_free_typevars(ty) &&
81-
((!jl_is_kind(ty) && ((jl_datatype_t*)ty)->isconcretetype) ||
82-
((jl_datatype_t*)ty)->name == jl_type_typename))
83-
return; // no amount of union splitting will make this a leaftype signature
86+
if (jl_is_uniontype(ty)) {
87+
count_unions += 1;
88+
union_size *= jl_count_union_components(ty);
89+
}
90+
else if (jl_is_datatype(ty) &&
91+
((!((jl_datatype_t*)ty)->isconcretetype || jl_is_kind(ty)) &&
92+
((jl_datatype_t*)ty)->name != jl_type_typename))
93+
return 0; // no amount of union splitting will make this a dispatch signature
8494
}
8595

86-
if (count_unions == 0 || count_unions >= 6) {
87-
_compile_all_tvar_union(sig);
88-
return;
96+
if (union_size <= 1 || union_size > 8) {
97+
return _compile_all_tvar_union(sig);
8998
}
9099

91100
int *idx = (int*)alloca(sizeof(int) * count_unions);
92101
for (i = 0; i < count_unions; i++) {
93102
idx[i] = 0;
94103
}
95104

105+
int all = 1;
96106
JL_GC_PUSH2(&p, &methsig);
97107
int idx_ctr = 0, incr = 0;
98108
while (!incr) {
@@ -122,10 +132,12 @@ static void _compile_all_union(jl_value_t *sig)
122132
}
123133
methsig = jl_apply_tuple_type(p, 1);
124134
methsig = jl_rewrap_unionall(methsig, sig);
125-
_compile_all_tvar_union(methsig);
135+
if (!_compile_all_tvar_union(methsig))
136+
all = 0;
126137
}
127138

128139
JL_GC_POP();
140+
return all;
129141
}
130142

131143
static int compile_all_collect__(jl_typemap_entry_t *ml, void *env)
@@ -147,29 +159,32 @@ static int compile_all_collect_(jl_methtable_t *mt, void *env)
147159
return 1;
148160
}
149161

150-
static void jl_compile_all_defs(jl_array_t *mis)
162+
static void jl_compile_all_defs(jl_array_t *mis, int all)
151163
{
152164
jl_array_t *allmeths = jl_alloc_vec_any(0);
153165
JL_GC_PUSH1(&allmeths);
154166

155167
jl_foreach_reachable_mtable(compile_all_collect_, allmeths);
156168

169+
size_t world = jl_atomic_load_acquire(&jl_world_counter);
157170
size_t i, l = jl_array_nrows(allmeths);
158171
for (i = 0; i < l; i++) {
159172
jl_method_t *m = (jl_method_t*)jl_array_ptr_ref(allmeths, i);
160173
if (jl_is_datatype(m->sig) && jl_isa_compileable_sig((jl_tupletype_t*)m->sig, jl_emptysvec, m)) {
161174
// method has a single compilable specialization, e.g. its definition
162175
// signature is concrete. in this case we can just hint it.
163-
jl_compile_hint((jl_tupletype_t*)m->sig);
176+
jl_compile_method_sig(m, m->sig, jl_emptysvec, world);
164177
}
165178
else {
166179
// first try to create leaf signatures from the signature declaration and compile those
167180
_compile_all_union(m->sig);
168181

169-
// finally, compile a fully generic fallback that can work for all arguments
170-
jl_method_instance_t *unspec = jl_get_unspecialized(m);
171-
if (unspec)
172-
jl_array_ptr_1d_push(mis, (jl_value_t*)unspec);
182+
if (all) {
183+
// finally, compile a fully generic fallback that can work for all arguments (even invoke)
184+
jl_method_instance_t *unspec = jl_get_unspecialized(m);
185+
if (unspec)
186+
jl_array_ptr_1d_push(mis, (jl_value_t*)unspec);
187+
}
173188
}
174189
}
175190

@@ -273,8 +288,7 @@ static void *jl_precompile(int all)
273288
// array of MethodInstances and ccallable aliases to include in the output
274289
jl_array_t *m = jl_alloc_vec_any(0);
275290
JL_GC_PUSH1(&m);
276-
if (all)
277-
jl_compile_all_defs(m);
291+
jl_compile_all_defs(m, all);
278292
jl_foreach_reachable_mtable(precompile_enq_all_specializations_, m);
279293
void *native_code = jl_precompile_(m, 0);
280294
JL_GC_POP();

0 commit comments

Comments
 (0)