Skip to content

WIP: Allow generated functions to return a CodeInstance #56650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Compiler/extras/CompilerDevTools/Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.12.0-DEV"
manifest_format = "2.0"
project_hash = "84f495a1bf065c95f732a48af36dd0cd2cefb9d5"

[[deps.Compiler]]
path = "../.."
uuid = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
version = "0.0.2"

[[deps.CompilerDevTools]]
path = "."
uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f"
version = "0.0.0"
5 changes: 5 additions & 0 deletions Compiler/extras/CompilerDevTools/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name = "CompilerDevTools"
uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f"

[deps]
Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
62 changes: 62 additions & 0 deletions Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
module CompilerDevTools

using Compiler
using Core.IR

include(joinpath(dirname(pathof(Compiler)), "..", "test", "newinterp.jl"))

@newinterp SplitCacheInterp

function generate_codeinst(world::UInt, #=source=#::LineNumberNode, this, fargtypes)
tt = Base.to_tuple_type(fargtypes)
match = Base._which(tt; raise=false, world)
match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError
mi = Compiler.specialize_method(match)
interp = SplitCacheInterp(; world)
new_compiler_ci = Compiler.typeinf_ext(interp, mi, Compiler.SOURCE_MODE_ABI)

# Construct a CodeInstance that matches `with_new_compiler` and forwards
# to new_compiler_ci

src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
src.slotnames = Symbol[:self, :args]
src.slotflags = fill(zero(UInt8), 2)
src.slottypes = Any[this, fargtypes]
src.isva = true
src.nargs = 2

code = Any[]
ssavaluetypes = Any[]
ncalleeargs = length(fargtypes)
for i = 1:ncalleeargs
push!(code, Expr(:call, getfield, Core.Argument(2), i))
push!(ssavaluetypes, fargtypes[i])
end
push!(code, Expr(:invoke, new_compiler_ci, (SSAValue(i) for i = 1:ncalleeargs)...))
push!(ssavaluetypes, new_compiler_ci.rettype)
push!(code, ReturnNode(SSAValue(ncalleeargs+1)))
push!(ssavaluetypes, Union{})
src.code = code
src.ssavaluetypes = ssavaluetypes

return CodeInstance(
mi, nothing, new_compiler_ci.rettype, new_compiler_ci.exctype,
isdefined(new_compiler_ci, :rettype_const) ? new_compiler_ci.rettype_const : nothing,
src,
isdefined(new_compiler_ci, :rettype_const) ? Int32(0x02) : Int32(0x00),
new_compiler_ci.min_world, new_compiler_ci.max_world,
new_compiler_ci.ipo_purity_bits, nothing, new_compiler_ci.relocatability,
nothing, Core.svec(new_compiler_ci))
end

function refresh()
@eval function with_new_compiler(args...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, generate_codeinst))
end
end
refresh()

export with_new_compiler

end
4 changes: 2 additions & 2 deletions Compiler/src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ end
function call_get_staged(mi::MethodInstance, world::UInt, cache_ci::RefValue{CodeInstance})
token = @_gc_preserve_begin cache_ci
cache_ci_ptr = pointer_from_objref(cache_ci)
src = ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{CodeInstance}), mi, world, cache_ci_ptr)
src = ccall(:jl_code_for_staged, Any, (Any, UInt, Ptr{CodeInstance}), mi, world, cache_ci_ptr)
@_gc_preserve_end token
return src
end
function call_get_staged(mi::MethodInstance, world::UInt, ::Nothing)
return ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{Cvoid}), mi, world, C_NULL)
return ccall(:jl_code_for_staged, Any, (Any, UInt, Ptr{Cvoid}), mi, world, C_NULL)
end

function get_cached_uninferred(mi::MethodInstance, world::UInt)
Expand Down
6 changes: 5 additions & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ function code_lowered(@nospecialize(f), @nospecialize(t=Tuple); generated::Bool=
for m in method_instances(f, t, world)
if generated && hasgenerator(m)
if may_invoke_generator(m)
code = ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{Cvoid}), m, world, C_NULL)
code = ccall(:jl_code_for_staged, Any, (Any, UInt, Ptr{Cvoid}), m, world, C_NULL)
if isa(code, CodeInstance)
error("Generator `@generated` method ", m, " ",
"returned an optimized result")
end
else
error("Could not expand generator for `@generated` method ", m, ". ",
"This can happen if the provided argument types (", t, ") are ",
Expand Down
30 changes: 27 additions & 3 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,28 @@ static jl_value_t *do_invoke(jl_value_t **args, size_t nargs, interpreter_state
argv[i-1] = eval_value(args[i], s);
jl_value_t *c = args[0];
assert(jl_is_code_instance(c) || jl_is_method_instance(c));
jl_method_instance_t *meth = jl_is_method_instance(c) ? (jl_method_instance_t*)c : ((jl_code_instance_t*)c)->def;
jl_value_t *result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, meth);
jl_value_t *result = NULL;
if (jl_is_code_instance(c)) {
jl_code_instance_t *codeinst = (jl_code_instance_t*)c;
assert(jl_atomic_load_relaxed(&codeinst->min_world) <= jl_current_task->world_age &&
jl_current_task->world_age <= jl_atomic_load_relaxed(&codeinst->max_world));
jl_callptr_t invoke = jl_atomic_load_acquire(&codeinst->invoke);
if (!invoke) {
jl_compile_codeinst(codeinst);
invoke = jl_atomic_load_acquire(&codeinst->invoke);
}
if (invoke) {
result = invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst);

} else {
if (codeinst->owner != jl_nothing) {
jl_error("Failed to invoke or compile external codeinst");
}
result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst->def);
}
} else {
result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, (jl_method_instance_t*)c);
}
JL_GC_POP();
return result;
}
Expand Down Expand Up @@ -729,7 +749,11 @@ jl_value_t *jl_code_or_ci_for_interpreter(jl_method_instance_t *mi, size_t world
jl_code_instance_t *uninferred = jl_cached_uninferred(cache, world);
if (!uninferred) {
assert(mi->def.method->generator);
src = jl_code_for_staged(mi, world, &uninferred);
ret = jl_code_for_staged(mi, world, &uninferred);
if (jl_is_code_instance(ret)) {
jl_mi_cache_insert(mi, (jl_code_instance_t*)ret);
return (jl_value_t*)ret;
}
}
ret = (jl_value_t*)uninferred;
src = (jl_code_info_t*)jl_atomic_load_relaxed(&uninferred->inferred);
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1854,7 +1854,7 @@ JL_DLLEXPORT jl_value_t *jl_get_binding_value_if_resolved(jl_binding_t *b JL_PRO
JL_DLLEXPORT jl_value_t *jl_get_binding_value_if_resolved_and_const(jl_binding_t *b JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_declare_const_gf(jl_binding_t *b, jl_module_t *mod, jl_sym_t *name);
JL_DLLEXPORT jl_method_t *jl_method_def(jl_svec_t *argdata, jl_methtable_t *mt, jl_code_info_t *f, jl_module_t *module);
JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo, size_t world, jl_code_instance_t **cache);
JL_DLLEXPORT jl_value_t *jl_code_for_staged(jl_method_instance_t *linfo, size_t world, jl_code_instance_t **cache);
JL_DLLEXPORT jl_code_info_t *jl_copy_code_info(jl_code_info_t *src);
JL_DLLEXPORT size_t jl_get_world_counter(void) JL_NOTSAFEPOINT;
JL_DLLEXPORT size_t jl_get_tls_world_age(void) JL_NOTSAFEPOINT;
Expand Down
18 changes: 15 additions & 3 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ JL_DLLEXPORT jl_code_instance_t *jl_cache_uninferred(jl_method_instance_t *mi, j

// Return a newly allocated CodeInfo for the function signature
// effectively described by the tuple (specTypes, env, Method) inside linfo
JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t world, jl_code_instance_t **cache)
JL_DLLEXPORT jl_value_t *jl_code_for_staged(jl_method_instance_t *mi, size_t world, jl_code_instance_t **cache)
{
jl_code_instance_t *cache_ci = jl_atomic_load_relaxed(&mi->cache);
jl_code_instance_t *uninferred_ci = jl_cached_uninferred(cache_ci, world);
Expand All @@ -753,6 +753,7 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
assert(generator != NULL);
assert(jl_is_method(def));
jl_code_info_t *func = NULL;
jl_value_t *ret = NULL;
jl_value_t *ex = NULL;
jl_value_t *kind = NULL;
jl_code_info_t *uninferred = NULL;
Expand All @@ -774,7 +775,16 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
ex = jl_call_staged(def, generator, world, mi->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt));

// do some post-processing
if (jl_is_code_info(ex)) {
if (jl_is_code_instance(ex)) {
jl_code_instance_t *ci = (jl_code_instance_t*)ex;
if (ci->owner != jl_nothing)
jl_error("CodeInstance returned from generator must have owner == nothing");
if (ci->next)
jl_error("CodeInstance returned from generator must not be in the cache");
ret = ex;
goto done;
}
else if (jl_is_code_info(ex)) {
func = (jl_code_info_t*)ex;
jl_array_t *stmts = (jl_array_t*)func->code;
jl_resolve_globals_in_ir(stmts, def->module, mi->sparam_vals, 1);
Expand Down Expand Up @@ -865,6 +875,8 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
*cache = cached_ci;
}

ret = (jl_value_t*)func;
done:
ct->ptls->in_pure_callback = last_in;
jl_lineno = last_lineno;
ct->world_age = last_age;
Expand All @@ -875,7 +887,7 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
jl_rethrow();
}
JL_GC_POP();
return func;
return ret;
}

JL_DLLEXPORT jl_code_info_t *jl_copy_code_info(jl_code_info_t *src)
Expand Down