From eebacab497dcf9e049293434384690ea904e3e7c Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Fri, 21 Oct 2022 14:18:24 -0400 Subject: [PATCH] fix #46778, precompile() for abstract but compileable signatures --- src/gf.c | 125 ++++++++++++++++++++++++++++++++++++--------- test/precompile.jl | 7 +++ 2 files changed, 107 insertions(+), 25 deletions(-) diff --git a/src/gf.c b/src/gf.c index 1f896921d45f5..668072172a38a 100644 --- a/src/gf.c +++ b/src/gf.c @@ -2255,6 +2255,39 @@ JL_DLLEXPORT jl_value_t *jl_normalize_to_compilable_sig(jl_methtable_t *mt, jl_t return is_compileable ? (jl_value_t*)tt : jl_nothing; } +// return a MethodInstance for a compileable method_match +jl_method_instance_t *jl_method_match_to_mi(jl_method_match_t *match, size_t world, size_t min_valid, size_t max_valid, int mt_cache) +{ + jl_method_t *m = match->method; + jl_svec_t *env = match->sparams; + jl_tupletype_t *ti = match->spec_types; + jl_method_instance_t *mi = NULL; + if (jl_is_datatype(ti)) { + jl_methtable_t *mt = jl_method_get_table(m); + if ((jl_value_t*)mt != jl_nothing) { + // get the specialization without caching it + if (mt_cache && ((jl_datatype_t*)ti)->isdispatchtuple) { + // Since we also use this presence in the cache + // to trigger compilation when producing `.ji` files, + // inject it there now if we think it will be + // used via dispatch later (e.g. because it was hinted via a call to `precompile`) + JL_LOCK(&mt->writelock); + mi = cache_method(mt, &mt->cache, (jl_value_t*)mt, ti, m, world, min_valid, max_valid, env); + JL_UNLOCK(&mt->writelock); + } + else { + jl_value_t *tt = jl_normalize_to_compilable_sig(mt, ti, env, m); + JL_GC_PUSH1(&tt); + if (tt != jl_nothing) { + mi = jl_specializations_get_linfo(m, (jl_value_t*)tt, env); + } + JL_GC_POP(); + } + } + } + return mi; +} + // compile-time method lookup jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid, int mt_cache) { @@ -2274,36 +2307,78 @@ jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types JL_PROPAGATES *max_valid = max_valid2; if (matches == jl_false || jl_array_len(matches) != 1 || ambig) return NULL; - jl_value_t *tt = NULL; - JL_GC_PUSH2(&matches, &tt); + JL_GC_PUSH1(&matches); jl_method_match_t *match = (jl_method_match_t*)jl_array_ptr_ref(matches, 0); - jl_method_t *m = match->method; - jl_svec_t *env = match->sparams; - jl_tupletype_t *ti = match->spec_types; - jl_method_instance_t *nf = NULL; - if (jl_is_datatype(ti)) { - jl_methtable_t *mt = jl_method_get_table(m); - if ((jl_value_t*)mt != jl_nothing) { - // get the specialization without caching it - if (mt_cache && ((jl_datatype_t*)ti)->isdispatchtuple) { - // Since we also use this presence in the cache - // to trigger compilation when producing `.ji` files, - // inject it there now if we think it will be - // used via dispatch later (e.g. because it was hinted via a call to `precompile`) - JL_LOCK(&mt->writelock); - nf = cache_method(mt, &mt->cache, (jl_value_t*)mt, ti, m, world, min_valid2, max_valid2, env); - JL_UNLOCK(&mt->writelock); - } - else { - tt = jl_normalize_to_compilable_sig(mt, ti, env, m); - if (tt != jl_nothing) { - nf = jl_specializations_get_linfo(m, (jl_value_t*)tt, env); + jl_method_instance_t *mi = jl_method_match_to_mi(match, world, min_valid2, max_valid2, mt_cache); + JL_GC_POP(); + return mi; +} + +// Get a MethodInstance for a precompile() call. This uses a special kind of lookup that +// tries to find a method for which the requested signature is compileable. +jl_method_instance_t *jl_get_compile_hint_specialization(jl_tupletype_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid, int mt_cache) +{ + if (jl_has_free_typevars((jl_value_t*)types)) + return NULL; // don't poison the cache due to a malformed query + if (!jl_has_concrete_subtype((jl_value_t*)types)) + return NULL; + + size_t min_valid2 = 1; + size_t max_valid2 = ~(size_t)0; + int ambig = 0; + jl_value_t *matches = jl_matching_methods(types, jl_nothing, -1, 0, world, &min_valid2, &max_valid2, &ambig); + if (*min_valid < min_valid2) + *min_valid = min_valid2; + if (*max_valid > max_valid2) + *max_valid = max_valid2; + size_t i, n = jl_array_len(matches); + if (n == 0) + return NULL; + JL_GC_PUSH1(&matches); + jl_method_match_t *match = NULL; + if (n == 1) { + match = (jl_method_match_t*)jl_array_ptr_ref(matches, 0); + } + else { + // first, select methods for which `types` is compileable + size_t count = 0; + for (i = 0; i < n; i++) { + jl_method_match_t *match1 = (jl_method_match_t*)jl_array_ptr_ref(matches, i); + if (jl_isa_compileable_sig(types, match1->method)) + jl_array_ptr_set(matches, count++, (jl_value_t*)match1); + } + jl_array_del_end((jl_array_t*)matches, n - count); + n = count; + // now remove methods that are more specific than others in the list. + // this is because the intent of precompiling e.g. f(::DataType) is to + // compile that exact method if it exists, and not lots of f(::Type{X}) methods + int exclude; + count = 0; + for (i = 0; i < n; i++) { + jl_method_match_t *match1 = (jl_method_match_t*)jl_array_ptr_ref(matches, i); + exclude = 0; + for (size_t j = n-1; j > i; j--) { // more general methods maybe more likely to be at end + jl_method_match_t *match2 = (jl_method_match_t*)jl_array_ptr_ref(matches, j); + if (jl_type_morespecific(match1->method->sig, match2->method->sig)) { + exclude = 1; + break; } } + if (!exclude) + jl_array_ptr_set(matches, count++, (jl_value_t*)match1); + if (count > 1) + break; } + // at this point if there are 0 matches left we found nothing, or if there are + // more than one the request is ambiguous and we ignore it. + if (count == 1) + match = (jl_method_match_t*)jl_array_ptr_ref(matches, 0); } + jl_method_instance_t *mi = NULL; + if (match != NULL) + mi = jl_method_match_to_mi(match, world, min_valid2, max_valid2, mt_cache); JL_GC_POP(); - return nf; + return mi; } static void _generate_from_hint(jl_method_instance_t *mi, size_t world) @@ -2370,7 +2445,7 @@ JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types) size_t world = jl_atomic_load_acquire(&jl_world_counter); size_t min_valid = 0; size_t max_valid = ~(size_t)0; - jl_method_instance_t *mi = jl_get_specialization1(types, world, &min_valid, &max_valid, 1); + jl_method_instance_t *mi = jl_get_compile_hint_specialization(types, world, &min_valid, &max_valid, 1); if (mi == NULL) return 0; JL_GC_PROMISE_ROOTED(mi); diff --git a/test/precompile.jl b/test/precompile.jl index 098d1ffbba231..5b49ad4a3b31a 100644 --- a/test/precompile.jl +++ b/test/precompile.jl @@ -1556,3 +1556,10 @@ end empty!(Base.DEPOT_PATH) append!(Base.DEPOT_PATH, original_depot_path) + +@testset "issue 46778" begin + f46778(::Any, ::Type{Int}) = 1 + f46778(::Any, ::DataType) = 2 + @test precompile(Tuple{typeof(f46778), Int, DataType}) + @test which(f46778, Tuple{Any,DataType}).specializations[1].cache.invoke != C_NULL +end