Skip to content

Commit dc3953d

Browse files
maleadtvtjnash
andauthored
Reland: expose caller world to GeneratedFunctionStub (#48766)
Expose the demanded world to the GeneratedFunctionStub caller, for users such as Cassette. If this argument is used, the user must return a CodeInfo with the min/max world field set correctly. Make the internal representation a tiny bit more compact also, removing a little bit of unnecessary metadata. Remove support for returning `body isa CodeInfo` via this wrapper, since it is impossible to return a correct object via the GeneratedFunctionStub since it strips off the world argument, which is required for it to do so. This also removes support for not inferring these fully (expand_early=false). Also answer method lookup queries about the future correctly, by refusing to answer them. This helps keeps execution correct as methods get added to the system asynchronously. This reverts "fix #25678: return matters for generated functions (#40778)" (commit 92c84bf), since this is no longer sensible to return here anyways, so it is no longer permitted or supported by this macro. Fixes various issues where we failed to specify the correct world. Note that the passed world may be `typemax(UInt)`, which demands that the generator must return code that is unspecialized (guaranteed to run correctly in any world). Co-authored-by: Jameson Nash <vtjnash@gmail.com>
1 parent d7df15d commit dc3953d

31 files changed

+204
-203
lines changed

base/Base.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ in_sysimage(pkgid::PkgId) = pkgid in _sysimage_modules
479479
for match = _methods(+, (Int, Int), -1, get_world_counter())
480480
m = match.method
481481
delete!(push!(Set{Method}(), m), m)
482-
copy(Core.Compiler.retrieve_code_info(Core.Compiler.specialize_method(match)))
482+
copy(Core.Compiler.retrieve_code_info(Core.Compiler.specialize_method(match), typemax(UInt)))
483483

484484
empty!(Set())
485485
push!(push!(Set{Union{GlobalRef,Symbol}}(), :two), GlobalRef(Base, :two))

base/boot.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -590,28 +590,25 @@ println(@nospecialize a...) = println(stdout, a...)
590590

591591
struct GeneratedFunctionStub
592592
gen
593-
argnames::Array{Any,1}
594-
spnames::Union{Nothing, Array{Any,1}}
595-
line::Int
596-
file::Symbol
597-
expand_early::Bool
593+
argnames::SimpleVector
594+
spnames::SimpleVector
598595
end
599596

600-
# invoke and wrap the results of @generated
601-
function (g::GeneratedFunctionStub)(@nospecialize args...)
597+
# invoke and wrap the results of @generated expression
598+
function (g::GeneratedFunctionStub)(world::UInt, source::LineNumberNode, @nospecialize args...)
599+
# args is (spvals..., argtypes...)
602600
body = g.gen(args...)
603-
if body isa CodeInfo
604-
return body
605-
end
606-
lam = Expr(:lambda, g.argnames,
607-
Expr(Symbol("scope-block"),
601+
file = source.file
602+
file isa Symbol || (file = :none)
603+
lam = Expr(:lambda, Expr(:argnames, g.argnames...).args,
604+
Expr(:var"scope-block",
608605
Expr(:block,
609-
LineNumberNode(g.line, g.file),
610-
Expr(:meta, :push_loc, g.file, Symbol("@generated body")),
606+
source,
607+
Expr(:meta, :push_loc, file, :var"@generated body"),
611608
Expr(:return, body),
612609
Expr(:meta, :pop_loc))))
613610
spnames = g.spnames
614-
if spnames === nothing
611+
if spnames === svec()
615612
return lam
616613
else
617614
return Expr(Symbol("with-static-parameters"), lam, spnames...)

base/compiler/abstractinterpretation.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
572572
break
573573
end
574574
topmost === nothing || continue
575-
if edge_matches_sv(infstate, method, sig, sparams, hardlimit, sv)
575+
if edge_matches_sv(interp, infstate, method, sig, sparams, hardlimit, sv)
576576
topmost = infstate
577577
edgecycle = true
578578
end
@@ -680,12 +680,13 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
680680
return MethodCallResult(rt, edgecycle, edgelimited, edge, effects)
681681
end
682682

683-
function edge_matches_sv(frame::InferenceState, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
683+
function edge_matches_sv(interp::AbstractInterpreter, frame::InferenceState, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
684684
# The `method_for_inference_heuristics` will expand the given method's generator if
685685
# necessary in order to retrieve this field from the generated `CodeInfo`, if it exists.
686686
# The other `CodeInfo`s we inspect will already have this field inflated, so we just
687687
# access it directly instead (to avoid regeneration).
688-
callee_method2 = method_for_inference_heuristics(method, sig, sparams) # Union{Method, Nothing}
688+
world = get_world_counter(interp)
689+
callee_method2 = method_for_inference_heuristics(method, sig, sparams, world) # Union{Method, Nothing}
689690

690691
inf_method2 = frame.src.method_for_inference_limit_heuristics # limit only if user token match
691692
inf_method2 isa Method || (inf_method2 = nothing)
@@ -722,11 +723,11 @@ function edge_matches_sv(frame::InferenceState, method::Method, @nospecialize(si
722723
end
723724

724725
# This function is used for computing alternate limit heuristics
725-
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector)
726-
if isdefined(method, :generator) && method.generator.expand_early && may_invoke_generator(method, sig, sparams)
726+
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector, world::UInt)
727+
if isdefined(method, :generator) && !(method.generator isa Core.GeneratedFunctionStub) && may_invoke_generator(method, sig, sparams)
727728
method_instance = specialize_method(method, sig, sparams)
728729
if isa(method_instance, MethodInstance)
729-
cinfo = get_staged(method_instance)
730+
cinfo = get_staged(method_instance, world)
730731
if isa(cinfo, CodeInfo)
731732
method2 = cinfo.method_for_inference_limit_heuristics
732733
if method2 isa Method

base/compiler/bootstrap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ let interp = NativeInterpreter()
3636
else
3737
tt = Tuple{typeof(f), Vararg{Any}}
3838
end
39-
for m in _methods_by_ftype(tt, 10, typemax(UInt))::Vector
39+
for m in _methods_by_ftype(tt, 10, get_world_counter())::Vector
4040
# remove any TypeVars from the intersection
4141
m = m::MethodMatch
4242
typ = Any[m.spec_types.parameters...]

base/compiler/inferencestate.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,8 @@ end
366366

367367
function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
368368
# prepare an InferenceState object for inferring lambda
369-
src = retrieve_code_info(result.linfo)
369+
world = get_world_counter(interp)
370+
src = retrieve_code_info(result.linfo, world)
370371
src === nothing && return nothing
371372
validate_code_in_debug_mode(result.linfo, src, "lowered")
372373
return InferenceState(result, src, cache, interp)

base/compiler/optimize.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ function OptimizationState(linfo::MethodInstance, src::CodeInfo, interp::Abstrac
181181
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, nothing, false)
182182
end
183183
function OptimizationState(linfo::MethodInstance, interp::AbstractInterpreter)
184-
src = retrieve_code_info(linfo)
184+
world = get_world_counter(interp)
185+
src = retrieve_code_info(linfo, world)
185186
src === nothing && return nothing
186187
return OptimizationState(linfo, src, interp)
187188
end

base/compiler/typeinfer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
10301030
end
10311031
end
10321032
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0 && !generating_sysimg()
1033-
return retrieve_code_info(mi)
1033+
return retrieve_code_info(mi, get_world_counter(interp))
10341034
end
10351035
lock_mi_inference(interp, mi)
10361036
result = InferenceResult(mi, typeinf_lattice(interp))

base/compiler/types.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ function NativeInterpreter(world::UInt = get_world_counter();
341341
inf_params::InferenceParams = InferenceParams(),
342342
opt_params::OptimizationParams = OptimizationParams())
343343
# Sometimes the caller is lazy and passes typemax(UInt).
344-
# we cap it to the current world age
344+
# we cap it to the current world age for correctness
345345
if world == typemax(UInt)
346346
world = get_world_counter()
347347
end

base/compiler/utilities.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,23 +114,23 @@ end
114114
invoke_api(li::CodeInstance) = ccall(:jl_invoke_api, Cint, (Any,), li)
115115
use_const_api(li::CodeInstance) = invoke_api(li) == 2
116116

117-
function get_staged(mi::MethodInstance)
117+
function get_staged(mi::MethodInstance, world::UInt)
118118
may_invoke_generator(mi) || return nothing
119119
try
120120
# user code might throw errors – ignore them
121-
ci = ccall(:jl_code_for_staged, Any, (Any,), mi)::CodeInfo
121+
ci = ccall(:jl_code_for_staged, Any, (Any, UInt), mi, world)::CodeInfo
122122
return ci
123123
catch
124124
return nothing
125125
end
126126
end
127127

128-
function retrieve_code_info(linfo::MethodInstance)
128+
function retrieve_code_info(linfo::MethodInstance, world::UInt)
129129
m = linfo.def::Method
130130
c = nothing
131131
if isdefined(m, :generator)
132132
# user code might throw errors – ignore them
133-
c = get_staged(linfo)
133+
c = get_staged(linfo, world)
134134
end
135135
if c === nothing && isdefined(m, :source)
136136
src = m.source

base/compiler/validation.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,15 +200,14 @@ end
200200

201201
"""
202202
validate_code!(errors::Vector{InvalidCodeError}, mi::MethodInstance,
203-
c::Union{Nothing,CodeInfo} = Core.Compiler.retrieve_code_info(mi))
203+
c::Union{Nothing,CodeInfo})
204204
205205
Validate `mi`, logging any violation by pushing an `InvalidCodeError` into `errors`.
206206
207207
If `isa(c, CodeInfo)`, also call `validate_code!(errors, c)`. It is assumed that `c` is
208-
the `CodeInfo` instance associated with `mi`.
208+
a `CodeInfo` instance associated with `mi`.
209209
"""
210-
function validate_code!(errors::Vector{InvalidCodeError}, mi::Core.MethodInstance,
211-
c::Union{Nothing,CodeInfo} = Core.Compiler.retrieve_code_info(mi))
210+
function validate_code!(errors::Vector{InvalidCodeError}, mi::Core.MethodInstance, c::Union{Nothing,CodeInfo})
212211
is_top_level = mi.def isa Module
213212
if is_top_level
214213
mnargs = 0

0 commit comments

Comments
 (0)