Skip to content

Commit 981b542

Browse files
Kenostevengj
authored and committed
Extend invoke to accept CodeInstance (#56660)
This is an alternative mechanism to #56650 that largely achieves the same result, but by hooking into `invoke` rather than a generated function. They are orthogonal mechanisms, and it's possible we want both. However, in #56650, both Jameson and Valentin were skeptical of the generated function signature bottleneck. This PR is sort of a hybrid of the mechanism in #52964 and what I proposed in #56650 (comment). In particular, this PR: 1. Extends `invoke` to support a CodeInstance in place of its usual `types` argument. 2. Adds a new `typeinf` optimized generic. The semantics of this optimized generic allow the compiler to instead call a companion `typeinf_edge` function, allowing a mid-inference interpreter switch (like #52964), without being forced through a concrete signature bottleneck. However, if calling `typeinf_edge` does not work (e.g. because the compiler version is mismatched), this still has well defined semantics, you just don't get inference support. The additional benefit of the `typeinf` optimized generic is that it lets custom cache owners tell the runtime how to "cure" code instances that have lost their native code. Currently the runtime only knows how to do that for `owner == nothing` CodeInstances (by re-running inference). This extension is not implemented, but the idea is that the runtime would be permitted to call the `typeinf` optimized generic on the dead CodeInstance's `owner` and `def` fields to obtain a cured CodeInstance (or a user-actionable error from the plugin). This PR includes an implementation of `with_new_compiler` from #56650. 
This PR includes just enough compiler support to make the compiler optimize this to the same code that #56650 produced: ``` julia> @code_typed with_new_compiler(sin, 1.0) CodeInfo( 1 ─ $(Expr(:foreigncall, :(:jl_get_tls_world_age), UInt64, svec(), 0, :(:ccall)))::UInt64 │ %2 = builtin Core.getfield(args, 1)::Float64 │ %3 = invoke sin(%2::Float64)::Float64 └── return %3 ) => Float64 ``` However, the implementation here is extremely incomplete. I'm putting it up only as a directional sketch to see if people prefer it over #56650. If so, I would prepare a cleaned up version of this PR that has the optimized generics as well as the curing support, but not the full inference integration (which needs a fair bit more work).
1 parent 2895c3d commit 981b542

File tree

14 files changed

+242
-15
lines changed

14 files changed

+242
-15
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
julia_version = "1.12.0-DEV"
4+
manifest_format = "2.0"
5+
project_hash = "84f495a1bf065c95f732a48af36dd0cd2cefb9d5"
6+
7+
[[deps.Compiler]]
8+
path = "../.."
9+
uuid = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
10+
version = "0.0.2"
11+
12+
[[deps.CompilerDevTools]]
13+
path = "."
14+
uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f"
15+
version = "0.0.0"
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
name = "CompilerDevTools"
2+
uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f"
3+
4+
[deps]
5+
Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
module CompilerDevTools

using Compiler
using Core.IR

# Singleton used as the cache-owner token for `SplitCacheInterp`
# (see the `Compiler.cache_owner` method below).
struct SplitCacheOwner; end

"""
    SplitCacheInterp(; world, inf_params, opt_params, inf_cache)

A `Compiler.AbstractInterpreter` whose `Compiler.cache_owner` is
`SplitCacheOwner()`.

# Keywords
- `world::UInt`: inference world; defaults to `Base.get_world_counter()`.
- `inf_params::Compiler.InferenceParams`: defaults to `Compiler.InferenceParams()`.
- `opt_params::Compiler.OptimizationParams`: defaults to `Compiler.OptimizationParams()`.
- `inf_cache::Vector{Compiler.InferenceResult}`: defaults to a fresh empty vector.
"""
struct SplitCacheInterp <: Compiler.AbstractInterpreter
    world::UInt
    inf_params::Compiler.InferenceParams
    opt_params::Compiler.OptimizationParams
    inf_cache::Vector{Compiler.InferenceResult}
    function SplitCacheInterp(;
        world::UInt = Base.get_world_counter(),
        inf_params::Compiler.InferenceParams = Compiler.InferenceParams(),
        opt_params::Compiler.OptimizationParams = Compiler.OptimizationParams(),
        inf_cache::Vector{Compiler.InferenceResult} = Compiler.InferenceResult[])
        new(world, inf_params, opt_params, inf_cache)
    end
end

# `AbstractInterpreter` accessors: each one simply returns the stored field.
Compiler.InferenceParams(interp::SplitCacheInterp) = interp.inf_params
Compiler.OptimizationParams(interp::SplitCacheInterp) = interp.opt_params
Compiler.get_inference_world(interp::SplitCacheInterp) = interp.world
Compiler.get_inference_cache(interp::SplitCacheInterp) = interp.inf_cache
Compiler.cache_owner(::SplitCacheInterp) = SplitCacheOwner()

import Core.OptimizedGenerics.CompilerPlugins: typeinf, typeinf_edge

# `typeinf` for `SplitCacheOwner`: calls `Compiler.typeinf_ext` through
# `Base.invoke_in_world`, using the `primary_world` of this very method as the
# world to run in, and an interpreter built for the caller's current TLS world.
@eval @noinline typeinf(::SplitCacheOwner, mi::MethodInstance, source_mode::UInt8) =
    Base.invoke_in_world(which(typeinf, Tuple{SplitCacheOwner, MethodInstance, UInt8}).primary_world, Compiler.typeinf_ext, SplitCacheInterp(; world=Base.tls_world_age()), mi, source_mode)

# `typeinf_edge` for `SplitCacheOwner`: forwards to `Compiler.typeinf_edge` with
# a fresh `SplitCacheInterp` for `world`. Note that `source_mode` is accepted
# but not used by this sketch.
@eval @noinline function typeinf_edge(::SplitCacheOwner, mi::MethodInstance, parent_frame::Compiler.InferenceState, world::UInt, source_mode::UInt8)
    # TODO: This isn't quite right, we're just sketching things for now
    interp = SplitCacheInterp(; world)
    Compiler.typeinf_edge(interp, mi.def, mi.specTypes, Core.svec(), parent_frame, false, false)
end

# TODO: This needs special compiler support to properly case split for multiple
#       method matches, etc.
"""
    mi_for_tt(tt, world=Base.tls_world_age())

Look up `tt` with `Compiler.findsup` in the method table of a
`SplitCacheInterp` for `world` and return `Base.specialize_method` of the match.
"""
@noinline function mi_for_tt(tt, world=Base.tls_world_age())
    interp = SplitCacheInterp(; world)
    match, _ = Compiler.findsup(tt, Compiler.method_table(interp))
    Base.specialize_method(match)
end

"""
    with_new_compiler(f, args...)

Call `f` on `args` through `invoke(f, ci, args...)`, where `ci` is the
`CodeInstance` returned by `CompilerPlugins.typeinf` for `SplitCacheOwner()`
with `Compiler.SOURCE_MODE_ABI`.
"""
function with_new_compiler(f, args...)
    tt = Base.signature_type(f, typeof(args))
    # Read the world once and hand it to the method lookup explicitly, instead
    # of leaving it unused and letting `mi_for_tt` re-read it via its default.
    world = Base.tls_world_age()
    new_compiler_ci = Core.OptimizedGenerics.CompilerPlugins.typeinf(
        SplitCacheOwner(), mi_for_tt(tt, world), Compiler.SOURCE_MODE_ABI
    )
    invoke(f, new_compiler_ci, args...)
end

export with_new_compiler

end

Compiler/src/abstractinterpretation.jl

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2218,16 +2218,46 @@ function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::Stmt
22182218
ft = widenconst(ft′)
22192219
ft === Bottom && return Future(CallMeta(Bottom, Any, EFFECTS_THROWS, NoCallInfo()))
22202220
types = argtype_by_index(argtypes, 3)
2221-
if types isa Const && types.val isa Method
2222-
method = types.val::Method
2223-
types = method # argument value
2224-
lookupsig = method.sig # edge kind
2225-
argtype = argtypes_to_type(pushfirst!(argtype_tail(argtypes, 4), ft))
2226-
nargtype = typeintersect(lookupsig, argtype)
2227-
nargtype === Bottom && return Future(CallMeta(Bottom, TypeError, EFFECTS_THROWS, NoCallInfo()))
2228-
nargtype isa DataType || return Future(CallMeta(Any, Any, Effects(), NoCallInfo())) # other cases are not implemented below
2221+
if types isa Const && types.val isa Union{Method, CodeInstance}
2222+
method_or_ci = types.val
2223+
if isa(method_or_ci, CodeInstance)
    # `invoke(f, ci::CodeInstance, args...)`: the callee is fixed, so the
    # result is read off the CodeInstance instead of running method lookup.
    our_world = sv.world.this
    argtype = argtypes_to_type(pushfirst!(argtype_tail(argtypes, 4), ft))
    sig = method_or_ci.def.specTypes
    exct = method_or_ci.exctype
    if !hasintersect(argtype, sig)
        # No argument tuple can match the CodeInstance's signature.
        return Future(CallMeta(Bottom, TypeError, EFFECTS_THROWS, NoCallInfo()))
    elseif !(argtype <: sig)
        # Some, but not all, argument tuples match: a TypeError is possible.
        exct = Union{exct, TypeError}
    end
    callee_valid_range = WorldRange(method_or_ci.min_world, method_or_ci.max_world)
    if !(our_world in callee_valid_range)
        # The CodeInstance is not valid in the world being inferred. Restrict
        # this frame's validity to the side of the callee's range we are on
        # and model the call as always throwing.
        if our_world < first(callee_valid_range)
            update_valid_age!(sv, WorldRange(first(sv.world.valid_worlds), first(callee_valid_range)-1))
        else
            update_valid_age!(sv, WorldRange(last(callee_valid_range)+1, last(sv.world.valid_worlds)))
        end
        return Future(CallMeta(Bottom, ErrorException, EFFECTS_THROWS, NoCallInfo()))
    end
    # TODO: When we add curing, we may want to assume this is nothrow
    # Compare against the value `nothing` (not the type `Nothing`) and read the
    # definition from `method_or_ci`; the previous spelling made this branch
    # unreachable and referenced an undefined variable.
    if method_or_ci.owner === nothing && method_or_ci.def.def isa Method
        exct = Union{exct, ErrorException}
    end
    update_valid_age!(sv, callee_valid_range)
    return Future(CallMeta(method_or_ci.rettype, exct, Effects(decode_effects(method_or_ci.ipo_purity_bits), nothrow=(exct===Bottom)),
        InvokeCICallInfo(method_or_ci)))
else
    # `invoke(f, m::Method, args...)`: derive the lookup signature from the
    # method and continue with the shared invoke handling after this block.
    method = method_or_ci::Method
    types = method # argument value
    lookupsig = method.sig # edge kind
    argtype = argtypes_to_type(pushfirst!(argtype_tail(argtypes, 4), ft))
    nargtype = typeintersect(lookupsig, argtype)
    nargtype === Bottom && return Future(CallMeta(Bottom, TypeError, EFFECTS_THROWS, NoCallInfo()))
    nargtype isa DataType || return Future(CallMeta(Any, Any, Effects(), NoCallInfo())) # other cases are not implemented below
    # Fall through to generic invoke handling
end
22292259
else
2230-
widenconst(types) >: Method && return Future(CallMeta(Any, Any, Effects(), NoCallInfo()))
2260+
widenconst(types) >: Union{Method, CodeInstance} && return Future(CallMeta(Any, Any, Effects(), NoCallInfo()))
22312261
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3), false)
22322262
isexact || return Future(CallMeta(Any, Any, Effects(), NoCallInfo()))
22332263
unwrapped = unwrap_unionall(types)

Compiler/src/abstractlattice.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ end
229229
if isa(t, Const)
230230
# don't consider mutable values useful constants
231231
val = t.val
232-
return isa(val, Symbol) || isa(val, Type) || isa(val, Method) || !ismutable(val)
232+
return isa(val, Symbol) || isa(val, Type) || isa(val, Method) || isa(val, CodeInstance) || !ismutable(val)
233233
end
234234
isa(t, PartialTypeVar) && return false # this isn't forwardable
235235
return is_const_prop_profitable_arg(widenlattice(𝕃), t)

Compiler/src/bootstrap.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@
55
# especially try to make sure any recursive and leaf functions have concrete signatures,
66
# since we won't be able to specialize & infer them at runtime
77

8-
activate_codegen!() = ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)
8+
# Register this compiler with the runtime and install the `typeinf` plugin
# method for the `nothing` cache owner.
function activate_codegen!()
    # Hand `typeinf_ext_toplevel` to the runtime as its type-inference entry point.
    ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)
    Core.eval(Compiler, quote
        # Capture the world age at the time this quoted code is evaluated.
        let typeinf_world_age = Base.tls_world_age()
            # `Expr(:$, :typeinf_world_age)` places a literal `$typeinf_world_age`
            # into the quoted code, so the interpolation is performed by the
            # inner `@eval` (splicing in the captured value), not by the outer
            # `quote`. The defined method then calls `typeinf_ext_toplevel` via
            # `Base.invoke_in_world` in that captured world, passing the
            # caller's current TLS world age as an argument.
            @eval Core.OptimizedGenerics.CompilerPlugins.typeinf(::Nothing, mi::MethodInstance, source_mode::UInt8) =
                Base.invoke_in_world($(Expr(:$, :typeinf_world_age)), typeinf_ext_toplevel, mi, Base.tls_world_age(), source_mode)
        end
    end)
end
917

1018
function bootstrap!()
1119
let time() = ccall(:jl_clock_now, Float64, ())

Compiler/src/stmtinfo.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,17 @@ end
268268
add_edges_impl(edges::Vector{Any}, info::UnionSplitApplyCallInfo) =
269269
for split in info.infos; add_edges!(edges, split); end
270270

271+
"""
272+
info::InvokeCICallInfo
273+
274+
Represents a resolved call to `Core.invoke` targeting a `Core.CodeInstance`
275+
"""
276+
struct InvokeCICallInfo <: CallInfo
    # The `CodeInstance` that the `Core.invoke` call was resolved to.
    edge::CodeInstance
end

# The only edge contributed by this call info is the targeted `CodeInstance`.
function add_edges_impl(edges::Vector{Any}, info::InvokeCICallInfo)
    return add_one_edge!(edges, info.edge)
end
281+
271282
"""
272283
info::InvokeCallInfo
273284

Compiler/src/utilities.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ function count_const_size(@nospecialize(x), count_self::Bool = true)
5454
# No definite size
5555
(isa(x, GenericMemory) || isa(x, String) || isa(x, SimpleVector)) &&
5656
return MAX_INLINE_CONST_SIZE + 1
57-
if isa(x, Module) || isa(x, Method)
58-
# We allow modules and methods, because we already assume they are externally
57+
if isa(x, Module) || isa(x, Method) || isa(x, CodeInstance)
58+
# We allow modules, methods and CodeInstance, because we already assume they are externally
5959
# rooted, so we count their contents as 0 size.
6060
return sizeof(Ptr{Cvoid})
6161
end

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ New library features
103103
* New `ltruncate`, `rtruncate` and `ctruncate` functions for truncating strings to text width, accounting for char widths ([#55351])
104104
* `isless` (and thus `cmp`, sorting, etc.) is now supported for zero-dimensional `AbstractArray`s ([#55772])
105105
* `invoke` now supports passing a Method instead of a type signature making this interface somewhat more flexible for certain uncommon use cases ([#56692]).
106+
* `invoke` now supports passing a CodeInstance instead of a type, which can enable
107+
certain compiler plugin workflows ([#56660]).
106108

107109
Standard library changes
108110
------------------------

base/docs/basedocs.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,6 +2031,7 @@ applicable
20312031
"""
20322032
invoke(f, argtypes::Type, args...; kwargs...)
20332033
invoke(f, argtypes::Method, args...; kwargs...)
2034+
invoke(f, argtypes::CodeInstance, args...; kwargs...)
20342035
20352036
Invoke a method for the given generic function `f` matching the specified types `argtypes` on the
20362037
specified arguments `args` and passing the keyword arguments `kwargs`. The arguments `args` must
@@ -2056,6 +2057,22 @@ Note in particular that the specified `Method` may be entirely unreachable from
20562057
If the method is part of the ordinary method table, this call behaves similar
20572058
to `invoke(f, method.sig, args...)`.
20582059
2060+
!!! compat "Julia 1.12"
2061+
Passing a `Method` requires Julia 1.12.
2062+
2063+
# Passing a `CodeInstance` instead of a signature
2064+
The `argtypes` argument may be a `CodeInstance`, bypassing both method lookup and specialization.
2065+
The semantics of this invocation are similar to a function pointer call of the `CodeInstance`'s
2066+
`invoke` pointer. It is an error to invoke a `CodeInstance` with arguments that do not match its
2067+
parent MethodInstance or from a world age not included in the `min_world`/`max_world` range.
2068+
It is undefined behavior to invoke a CodeInstance whose behavior does not match the constraints
2069+
specified in its fields. For some code instances with `owner !== nothing` (i.e. those generated
2070+
by external compilers), it may be an error to invoke them after passing through precompilation.
2071+
This is an advanced interface intended for use with external compiler plugins.
2072+
2073+
!!! compat "Julia 1.12"
2074+
Passing a `CodeInstance` requires Julia 1.12.
2075+
20592076
# Examples
20602077
```jldoctest
20612078
julia> f(x::Real) = x^2;

base/optimized_generics.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,31 @@ module KeyValue
5454
function get end
5555
end
5656

57+
# Compiler-recognized intrinsics for compiler plugins
"""
    module CompilerPlugins

Implements a pair of functions `typeinf`/`typeinf_edge`. When the optimizer sees
a call to `typeinf`, it has license to instead call `typeinf_edge`, supplying the
current inference stack in `parent_frame` (but otherwise supplying the arguments
to `typeinf`). `typeinf_edge` will return the `CodeInstance` that `typeinf` would
have returned at runtime. The optimizer may perform a non-IPO replacement of
the call to `typeinf` by the result of `typeinf_edge`. In addition, the IPO-safe
fields of the `CodeInstance` may be propagated in IPO mode.
"""
module CompilerPlugins
    """
        typeinf(owner, mi, source_mode)::CodeInstance

    Return a `CodeInstance` for the given `mi` whose valid world range includes
    at least the current TLS world age and which satisfies the requirements of
    `source_mode`.
    """
    function typeinf end

    """
        typeinf_edge(owner, mi, parent_frame, world, abi_mode)::CodeInstance

    Inference-time companion of `typeinf`: takes the arguments of `typeinf` plus
    the current inference stack in `parent_frame`, and returns the `CodeInstance`
    that `typeinf` would have returned at runtime.

    NOTE(review): `world` and `abi_mode` are not described anywhere in this
    module; `abi_mode` presumably corresponds to `typeinf`'s `source_mode` —
    confirm and align the names.
    """
    function typeinf_edge end
end
83+
5784
end

src/builtins.c

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,6 +1587,28 @@ JL_CALLABLE(jl_f_invoke)
15871587
if (!jl_tuple1_isa(args[0], &args[2], nargs - 1, (jl_datatype_t*)m->sig))
15881588
jl_type_error("invoke: argument type error", argtypes, arg_tuple(args[0], &args[2], nargs - 1));
15891589
return jl_gf_invoke_by_method(m, args[0], &args[2], nargs - 1);
1590+
} else if (jl_is_code_instance(argtypes)) {
1591+
jl_code_instance_t *codeinst = (jl_code_instance_t*)args[1];
1592+
jl_callptr_t invoke = jl_atomic_load_acquire(&codeinst->invoke);
1593+
if (jl_tuple1_isa(args[0], &args[2], nargs - 2, (jl_datatype_t*)codeinst->def->specTypes)) {
1594+
jl_type_error("invoke: argument type error", codeinst->def->specTypes, arg_tuple(args[0], &args[2], nargs - 2));
1595+
}
1596+
if (jl_atomic_load_relaxed(&codeinst->min_world) > jl_current_task->world_age ||
1597+
jl_current_task->world_age > jl_atomic_load_relaxed(&codeinst->max_world)) {
1598+
jl_error("invoke: CodeInstance not valid for this world");
1599+
}
1600+
if (!invoke) {
1601+
jl_compile_codeinst(codeinst);
1602+
invoke = jl_atomic_load_acquire(&codeinst->invoke);
1603+
}
1604+
if (invoke) {
1605+
return invoke(args[0], &args[2], nargs - 2, codeinst);
1606+
} else {
1607+
if (codeinst->owner != jl_nothing || !jl_is_method(codeinst->def->def.value)) {
1608+
jl_error("Failed to invoke or compile external codeinst");
1609+
}
1610+
return jl_gf_invoke_by_method(codeinst->def->def.method, args[0], &args[2], nargs - 1);
1611+
}
15901612
}
15911613
if (!jl_is_tuple_type(jl_unwrap_unionall(argtypes)))
15921614
jl_type_error("invoke", (jl_value_t*)jl_anytuple_type_type, argtypes);

src/interpreter.c

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,28 @@ static jl_value_t *do_invoke(jl_value_t **args, size_t nargs, interpreter_state
137137
argv[i-1] = eval_value(args[i], s);
138138
jl_value_t *c = args[0];
139139
assert(jl_is_code_instance(c) || jl_is_method_instance(c));
140-
jl_method_instance_t *meth = jl_is_method_instance(c) ? (jl_method_instance_t*)c : ((jl_code_instance_t*)c)->def;
141-
jl_value_t *result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, meth);
140+
jl_value_t *result = NULL;
141+
if (jl_is_code_instance(c)) {
142+
jl_code_instance_t *codeinst = (jl_code_instance_t*)c;
143+
assert(jl_atomic_load_relaxed(&codeinst->min_world) <= jl_current_task->world_age &&
144+
jl_current_task->world_age <= jl_atomic_load_relaxed(&codeinst->max_world));
145+
jl_callptr_t invoke = jl_atomic_load_acquire(&codeinst->invoke);
146+
if (!invoke) {
147+
jl_compile_codeinst(codeinst);
148+
invoke = jl_atomic_load_acquire(&codeinst->invoke);
149+
}
150+
if (invoke) {
151+
result = invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst);
152+
153+
} else {
154+
if (codeinst->owner != jl_nothing) {
155+
jl_error("Failed to invoke or compile external codeinst");
156+
}
157+
result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst->def);
158+
}
159+
} else {
160+
result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, (jl_method_instance_t*)c);
161+
}
142162
JL_GC_POP();
143163
return result;
144164
}

test/core.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8353,9 +8353,23 @@ end
83538353
@test eval(Expr(:toplevel, :(@define_call(f_macro_defined1)))) == 1
83548354
@test @define_call(f_macro_defined2) == 1
83558355

8356+
# `invoke` of `Method`
83568357
let m = which(+, (Int, Int))
83578358
@eval f56692(i) = invoke(+, $m, i, 4)
83588359
global g56692() = f56692(5) == 9 ? "true" : false
83598360
end
83608361
@test @inferred(f56692(3)) == 7
83618362
@test @inferred(g56692()) == "true"
8363+
8364+
# `invoke` of `CodeInstance`
8365+
f_invalidate_me() = return 1
8366+
f_invoke_me() = return f_invalidate_me()
8367+
@test f_invoke_me() == 1
8368+
const f_invoke_me_ci = Base.specialize_method(Base._which(Tuple{typeof(f_invoke_me)})).cache
8369+
f_call_me() = invoke(f_invoke_me, f_invoke_me_ci)
8370+
@test invoke(f_invoke_me, f_invoke_me_ci) == 1
8371+
@test f_call_me() == 1
8372+
@test_throws TypeError invoke(f_invoke_me, f_invoke_me_ci, 1)
8373+
f_invalidate_me() = 2
8374+
@test_throws ErrorException invoke(f_invoke_me, f_invoke_me_ci)
8375+
@test_throws ErrorException f_call_me()

0 commit comments

Comments
 (0)