Skip to content

Commit

Permalink
Get better type info from partially generated functions
Browse files Browse the repository at this point in the history
Consider the following function:
```
julia> function foo(a, b)
           ntuple(i->(a+b; i), Val(4))
       end
foo (generic function with 1 method)
```

(In particular note that the return type of the closure does not depend on the types
of `a` and b`). Unfortunately, prior to this change, inference was unable to determine
the return type in this situation:

```
julia> code_typed(foo, Tuple{Any, Any}, trace=true)
Refused to call generated function with non-concrete argument types ntuple(::getfield(Main, Symbol("##15#16")){_A,_B} where _B where _A, ::Val{4}) [GeneratedNotConcrete]

1-element Array{Any,1}:
 CodeInfo(
1 ─ %1 = Main.:(##15#16)::Const(##15#16, false)
│   %2 = Core.typeof(a)::DataType
│   %3 = Core.typeof(b)::DataType
│   %4 = Core.apply_type(%1, %2, %3)::Type{##15#16{_A,_B}} where _B where _A
│   %5 = %new(%4, a, b)::##15#16{_A,_B} where _B where _A
│   %6 = Main.ntuple(%5, $(QuoteNode(Val{4}())))::Any
└──      return %6
) => Any
```

Looking at the definition of ntuple

https://github.com/JuliaLang/julia/blob/abb09f88804c4e74c752a66157e767c9b0f8945d/base/ntuple.jl#L45-L56

we see that it is a generated function an inference thus refuses to invoke it,
unless it can prove the concrete type of *all* arguments to the function. As
the above example illustrates, this restriction is more stringent than necessary.
It is true that we cannot invoke generated functions on arbitrary abstract
signatures (because we neither want to the user to have to be able to nor
do we trust that users are able to preverse monotonicity - i.e. that the return
type of the generated code will always be a subtype of the return type of a more
abstract signature).

However, if some piece of information is not used (the type of the passed function
in this case), there is no problem with calling the generated function (since
information that is unnused cannot possibly affect monotnicity).

This PR allows us to recognize pieces of information that are *syntactically* unused,
and call the generated functions, even if we do not have those pieces of information.

As a result, we are now able to infer the return type of the above function:
```
julia> code_typed(foo, Tuple{Any, Any})
1-element Array{Any,1}:
 CodeInfo(
1 ─ %1 = Main.:(##3#4)::Const(##3#4, false)
│   %2 = Core.typeof(a)::DataType
│   %3 = Core.typeof(b)::DataType
│   %4 = Core.apply_type(%1, %2, %3)::Type{##3#4{_A,_B}} where _B where _A
│   %5 = %new(%4, a, b)::##3#4{_A,_B} where _B where _A
│   %6 = Main.ntuple(%5, $(QuoteNode(Val{4}())))::NTuple{4,Int64}
└──      return %6
) => NTuple{4,Int64}
```

In particular, we use the new frontent `used` flags from the previous commit.
One additional complication is that we want to accesss these flags without
uncompressing the generator source, so we change the compression scheme to
place the flags at a known location.

Fixes #31004
  • Loading branch information
Keno committed Feb 11, 2019
1 parent 635b8c5 commit 3bc4ea7
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 15 deletions.
5 changes: 1 addition & 4 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,7 @@ function code_for_method(method::Method, @nospecialize(atypes), sparams::SimpleV
if world < min_world(method) || world > max_world(method)
return nothing
end
if isdefined(method, :generator) && !isdispatchtuple(atypes)
# don't call staged functions on abstract types.
# (see issues #8504, #10230)
# we can't guarantee that their type behavior is monotonic.
if isdefined(method, :generator) && !may_invoke_generator(method, atypes, sparams)
return nothing
end
if preexisting
Expand Down
71 changes: 68 additions & 3 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -941,15 +941,80 @@ struct CodegenParams
emit_function, emitted_function)
end

const SLOT_USED = 0x8
ast_slotflag(@nospecialize(code), i) = ccall(:jl_ast_slotflag, UInt8, (Any, Csize_t), code, i - 1)

"""
may_invoke_generator(method, atypes, sparams)
Computes whether or not we may invoke the generator for the given `method` on
the given atypes and sparams. For correctness, all generated function are
required to return monotonic answers. However, since we don't expect users to
be able to successfully implement this criterion, we only call generated
functions on concrete types. The one exception to this is that we allow calling
generators with abstract types if the generator does not use said abstract type
(and thus cannot incorrectly use it to break monotonicity). This function
computes whether we are in either of these cases.
"""
function may_invoke_generator(method::Method, @nospecialize(atypes), sparams::SimpleVector)
# If we have complete information, we may always call the generator
isdispatchtuple(atypes) && return true

# We don't have complete information, but it is possible that the generator
# syntactically doesn't make use of the information we don't have. Check
# for that.

# For now, only handle the (common, generated by the frontend case) that the
# generator only has one method
isa(method.generator, Core.GeneratedFunctionStub) || return false
generator_mt = typeof(method.generator.gen).name.mt
length(generator_mt) == 1 || return false

generator_method = first(MethodList(generator_mt))
nsparams = length(sparams)
isdefined(generator_method, :source) || return false
code = generator_method.source
nslots = ccall(:jl_ast_nslots, Int, (Any,), code)
at = unwrap_unionall(atypes)
(nslots >= 1 + length(sparams) + length(at.parameters)) || return false

for i = 1:nsparams
if isa(sparams[i], TypeVar)
if (ast_slotflag(code, 1 + i) & SLOT_USED) != 0
return false
end
end
end
for i = 1:length(at.parameters)
if !isdispatchelem(at.parameters[i])
if (ast_slotflag(code, 1 + i + nsparams) & SLOT_USED) != 0
return false
end
end
end
return true
end

# give a decent error message if we try to instantiate a staged function on non-leaf types
function func_for_method_checked(m::Method, @nospecialize types)
function func_for_method_checked(m::Method, @nospecialize(types), sparams::SimpleVector)
if isdefined(m, :generator) && !Core.Compiler.may_invoke_generator(m, types, sparams)
error("cannot call @generated function `", m, "` ",
"with abstract argument types: ", types)
end
return m
end

function func_for_method_checked(m::Method, @nospecialize(types))
depwarn("The two argument form of `func_for_method_checked` is deprecated. Pass sparams in addition.",
:func_for_method_checked)
if isdefined(m, :generator) && !isdispatchtuple(types)
error("cannot call @generated function `", m, "` ",
"with abstract argument types: ", types)
end
return m
end


"""
code_typed(f, types; optimize=true, debuginfo=:default)
Expand Down Expand Up @@ -978,7 +1043,7 @@ function code_typed(@nospecialize(f), @nospecialize(types=Tuple);
types = to_tuple_type(types)
asts = []
for x in _methods(f, types, -1, world)
meth = func_for_method_checked(x[3], types)
meth = func_for_method_checked(x[3], types, x[2])
(code, ty) = Core.Compiler.typeinf_code(meth, x[1], x[2], optimize, params)
code === nothing && error("inference not successful") # inference disabled?
debuginfo == :none && remove_linenums!(code)
Expand All @@ -997,7 +1062,7 @@ function return_types(@nospecialize(f), @nospecialize(types=Tuple))
world = ccall(:jl_get_world_counter, UInt, ())
params = Core.Compiler.Params(world)
for x in _methods(f, types, -1, world)
meth = func_for_method_checked(x[3], types)
meth = func_for_method_checked(x[3], types, x[2])
ty = Core.Compiler.typeinf_type(meth, x[1], x[2], params)
ty === nothing && error("inference not successful") # inference disabled?
push!(rt, ty)
Expand Down
26 changes: 22 additions & 4 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -1212,7 +1212,7 @@ static void write_mod_list(ios_t *s, jl_array_t *a)
}

// "magic" string and version header of .ji file
static const int JI_FORMAT_VERSION = 7;
static const int JI_FORMAT_VERSION = 8;
static const char JI_MAGIC[] = "\373jli\r\n\032\n"; // based on PNG signature
static const uint16_t BOM = 0xFEFF; // byte-order marker
static void write_header(ios_t *s)
Expand Down Expand Up @@ -2459,6 +2459,13 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
size_t nsyms = jl_array_len(code->slotnames);
assert(nsyms >= m->nargs && nsyms < INT32_MAX); // required by generated functions
write_int32(s.s, nsyms);
assert(nsyms == jl_array_len(code->slotflags));
ios_write(s.s, (char*)jl_array_data(code->slotflags), nsyms);

// N.B.: The layout of everything before this point is explicitly referenced
// by the various jl_ast_ accessors. Make sure to adjust those if you change
// the data layout.

for (i = 0; i < nsyms; i++) {
jl_sym_t *name = (jl_sym_t*)jl_array_ptr_ref(code->slotnames, i);
assert(jl_is_symbol(name));
Expand All @@ -2468,7 +2475,7 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
}

size_t nf = jl_datatype_nfields(jl_code_info_type);
for (i = 0; i < nf - 5; i++) {
for (i = 0; i < nf - 6; i++) {
if (i == 1) // skip codelocs
continue;
int copy = (i != 2); // don't copy contents of method_for_inference_limit_heuristics field
Expand Down Expand Up @@ -2536,6 +2543,9 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data)
code->pure = !!(flags & (1 << 0));

size_t nslots = read_int32(&src);
code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots);
ios_read(s.s, (char*)jl_array_data(code->slotflags), nslots);

jl_array_t *syms = jl_alloc_vec_any(nslots);
code->slotnames = syms;
for (i = 0; i < nslots; i++) {
Expand All @@ -2547,7 +2557,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data)
}

size_t nf = jl_datatype_nfields(jl_code_info_type);
for (i = 0; i < nf - 5; i++) {
for (i = 0; i < nf - 6; i++) {
if (i == 1)
continue;
assert(jl_field_isptr(jl_code_info_type, i));
Expand Down Expand Up @@ -2620,6 +2630,14 @@ JL_DLLEXPORT ssize_t jl_ast_nslots(jl_array_t *data)
}
}

JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i)
{
assert(i < jl_ast_nslots(data));
if (jl_is_code_info(data))
return ((uint8_t*)((jl_code_info_t*)data)->slotflags->data)[i];
return ((uint8_t*)data->data)[1 + sizeof(int32_t) + i];
}

JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names)
{
size_t i, nargs = jl_array_len(names);
Expand All @@ -2637,7 +2655,7 @@ JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names)
int nslots = jl_load_unaligned_i32(d + 1);
assert(nslots >= nargs);
(void)nslots;
char *namestr = d + 5;
char *namestr = d + 5 + nslots;
for (i = 0; i < nargs; i++) {
size_t namelen = strlen(namestr);
jl_sym_t *name = jl_symbol_n(namestr, namelen);
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data)
JL_DLLEXPORT uint8_t jl_ast_flag_inferred(jl_array_t *data);
JL_DLLEXPORT uint8_t jl_ast_flag_inlineable(jl_array_t *data);
JL_DLLEXPORT uint8_t jl_ast_flag_pure(jl_array_t *data);
JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i);
JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names);

JL_DLLEXPORT int jl_is_operator(char *sym);
Expand Down
5 changes: 3 additions & 2 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ STATIC_INLINE jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator
JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
{
JL_TIMING(STAGED_FUNCTION);
jl_tupletype_t *tt = (jl_tupletype_t*)linfo->specTypes;
jl_value_t *tt = linfo->specTypes;
jl_method_t *def = linfo->def.method;
jl_value_t *generator = def->generator;
assert(generator != NULL);
Expand All @@ -402,7 +402,8 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
ptls->world_age = def->min_world;

// invoke code generator
ex = jl_call_staged(linfo->def.method, generator, linfo->sparam_vals, jl_svec_data(tt->parameters), jl_nparams(tt));
jl_tupletype_t *ttdt = (jl_tupletype_t*)jl_unwrap_unionall(tt);
ex = jl_call_staged(linfo->def.method, generator, linfo->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt));

if (jl_is_code_info(ex)) {
func = (jl_code_info_t*)ex;
Expand Down
2 changes: 1 addition & 1 deletion stdlib/InteractiveUtils/src/codeview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function _dump_function(@nospecialize(f), @nospecialize(t), native::Bool, wrappe
t = to_tuple_type(t)
tt = signature_type(f, t)
(ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), tt, meth.sig)::Core.SimpleVector
meth = Base.func_for_method_checked(meth, ti)
meth = Base.func_for_method_checked(meth, ti, env)
linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any, UInt), meth, ti, env, world)
# get the code for it
return _dump_function_linfo(linfo, world, native, wrapper, strip_ir_metadata, dump_module, syntax, optimize, debuginfo, params)
Expand Down
23 changes: 22 additions & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ function get_linfo(@nospecialize(f), @nospecialize(t))
tt = Tuple{ft, t.parameters...}
precompile(tt)
(ti, env) = ccall(:jl_type_intersection_with_env, Ref{Core.SimpleVector}, (Any, Any), tt, meth.sig)
meth = Base.func_for_method_checked(meth, tt)
meth = Base.func_for_method_checked(meth, tt, env)
return ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
(Any, Any, Any, UInt), meth, tt, env, world)
end
Expand Down Expand Up @@ -2224,3 +2224,24 @@ _call_rttf_test() = Core.Compiler.return_type(_rttf_test, Tuple{Any})
f_with_Type_arg(::Type{T}) where {T} = T
@test Base.return_types(f_with_Type_arg, (Any,)) == Any[Type]
@test Base.return_types(f_with_Type_arg, (Type{Vector{T}} where T,)) == Any[Type{Vector{T}} where T]

# Generated functions that only reference some of their arguments
@inline function my_ntuple(f::F, ::Val{N}) where {F,N}
N::Int
(N >= 0) || throw(ArgumentError(string("tuple length should be ≥0, got ", N)))
if @generated
quote
@Base.nexprs $N i -> t_i = f(i)
@Base.ncall $N tuple t
end
else
Tuple(f(i) for i = 1:N)
end
end
call_ntuple(a, b) = my_ntuple(i->(a+b; i), Val(4))
@test Base.return_types(call_ntuple, Tuple{Any,Any}) == [NTuple{4, Int}]
@test length(code_typed(my_ntuple, Tuple{Any, Val{4}})) == 1
@test_throws ErrorException code_typed(my_ntuple, Tuple{Any, Val})

@generated unionall_sig_generated(::Vector{T}, b::Vector{S}) where {T, S} = :($b)
@test length(code_typed(unionall_sig_generated, Tuple{Any, Vector{Int}})) == 1

0 comments on commit 3bc4ea7

Please sign in to comment.