Skip to content

Commit fbd73c0

Browse files
committed
generators: expose caller world to GeneratedFunctionStub
Expose the demanded world to the GeneratedFunctionStub caller, for users such as Cassette. If this argument is used, the uesr must return a CodeInfo with the min/max world field set correctly. Make the internal representation a tiny bit more compact also, removing a little bit of unnecessary metadata. Remove support for returning `body isa CodeInfo` via this wrapper, since it is impossible to return a correct object via the GeneratedFunctionStub since it strips off the world argument, which is required for it to do so. This also removes support for not inferring these fully (expand_early=false). Also answer method lookup queries about the future correctly, by refusing to answer them. This helps keeps execution correct as methods get added to the system asynchronously. This reverts "fix #25678: return matters for generated functions (#40778)" (commit 92c84bf), since this is no longer sensible to return here anyways, so it is no longer permitted or supported by this macro. Fixes various issues where we failed to specify the correct world.
1 parent 43c6f75 commit fbd73c0

31 files changed

+209
-208
lines changed

base/Base.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ in_sysimage(pkgid::PkgId) = pkgid in _sysimage_modules
468468
for match = _methods(+, (Int, Int), -1, get_world_counter())
469469
m = match.method
470470
delete!(push!(Set{Method}(), m), m)
471-
copy(Core.Compiler.retrieve_code_info(Core.Compiler.specialize_method(match)))
471+
copy(Core.Compiler.retrieve_code_info(Core.Compiler.specialize_method(match), typemax(UInt)))
472472

473473
empty!(Set())
474474
push!(push!(Set{Union{GlobalRef,Symbol}}(), :two), GlobalRef(Base, :two))

base/boot.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -590,28 +590,25 @@ println(@nospecialize a...) = println(stdout, a...)
590590

591591
struct GeneratedFunctionStub
592592
gen
593-
argnames::Array{Any,1}
594-
spnames::Union{Nothing, Array{Any,1}}
595-
line::Int
596-
file::Symbol
597-
expand_early::Bool
593+
argnames::SimpleVector
594+
spnames::SimpleVector
598595
end
599596

600-
# invoke and wrap the results of @generated
601-
function (g::GeneratedFunctionStub)(@nospecialize args...)
597+
# invoke and wrap the results of @generated expression
598+
function (g::GeneratedFunctionStub)(world::UInt, source::LineNumberNode, @nospecialize args...)
599+
# args is (spvals..., argtypes...)
602600
body = g.gen(args...)
603-
if body isa CodeInfo
604-
return body
605-
end
606-
lam = Expr(:lambda, g.argnames,
607-
Expr(Symbol("scope-block"),
601+
file = source.file
602+
file isa Symbol || (file = :none)
603+
lam = Expr(:lambda, Expr(:argnames, g.argnames...).args,
604+
Expr(:var"scope-block",
608605
Expr(:block,
609-
LineNumberNode(g.line, g.file),
610-
Expr(:meta, :push_loc, g.file, Symbol("@generated body")),
606+
source,
607+
Expr(:meta, :push_loc, file, :var"@generated body"),
611608
Expr(:return, body),
612609
Expr(:meta, :pop_loc))))
613610
spnames = g.spnames
614-
if spnames === nothing
611+
if spnames === svec()
615612
return lam
616613
else
617614
return Expr(Symbol("with-static-parameters"), lam, spnames...)

base/compiler/abstractinterpretation.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
573573
break
574574
end
575575
topmost === nothing || continue
576-
if edge_matches_sv(infstate, method, sig, sparams, hardlimit, sv)
576+
if edge_matches_sv(interp, infstate, method, sig, sparams, hardlimit, sv)
577577
topmost = infstate
578578
edgecycle = true
579579
end
@@ -681,12 +681,13 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
681681
return MethodCallResult(rt, edgecycle, edgelimited, edge, effects)
682682
end
683683

684-
function edge_matches_sv(frame::InferenceState, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
684+
function edge_matches_sv(interp::AbstractInterpreter, frame::InferenceState, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
685685
# The `method_for_inference_heuristics` will expand the given method's generator if
686686
# necessary in order to retrieve this field from the generated `CodeInfo`, if it exists.
687687
# The other `CodeInfo`s we inspect will already have this field inflated, so we just
688688
# access it directly instead (to avoid regeneration).
689-
callee_method2 = method_for_inference_heuristics(method, sig, sparams) # Union{Method, Nothing}
689+
world = get_world_counter(interp)
690+
callee_method2 = method_for_inference_heuristics(method, sig, sparams, world) # Union{Method, Nothing}
690691

691692
inf_method2 = frame.src.method_for_inference_limit_heuristics # limit only if user token match
692693
inf_method2 isa Method || (inf_method2 = nothing)
@@ -723,11 +724,11 @@ function edge_matches_sv(frame::InferenceState, method::Method, @nospecialize(si
723724
end
724725

725726
# This function is used for computing alternate limit heuristics
726-
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector)
727-
if isdefined(method, :generator) && method.generator.expand_early && may_invoke_generator(method, sig, sparams)
727+
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector, world::UInt)
728+
if isdefined(method, :generator) && !(method.generator isa Core.GeneratedFunctionStub) && may_invoke_generator(method, sig, sparams)
728729
method_instance = specialize_method(method, sig, sparams)
729730
if isa(method_instance, MethodInstance)
730-
cinfo = get_staged(method_instance)
731+
cinfo = get_staged(method_instance, world)
731732
if isa(cinfo, CodeInfo)
732733
method2 = cinfo.method_for_inference_limit_heuristics
733734
if method2 isa Method
@@ -791,24 +792,25 @@ end
791792
function pure_eval_eligible(interp::AbstractInterpreter,
792793
@nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo)
793794
# XXX we need to check that this pure function doesn't call any overlayed method
795+
world = get_world_counter(interp)
794796
return f !== nothing &&
795797
length(applicable) == 1 &&
796-
is_method_pure(applicable[1]::MethodMatch) &&
798+
is_method_pure(applicable[1]::MethodMatch, world) &&
797799
is_all_const_arg(arginfo, #=start=#2)
798800
end
799801

800-
function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector)
802+
function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector, world::UInt)
801803
if isdefined(method, :generator)
802-
method.generator.expand_early || return false
804+
method.generator isa Core.GeneratedFunctionStub && return false
803805
mi = specialize_method(method, sig, sparams)
804806
isa(mi, MethodInstance) || return false
805-
staged = get_staged(mi)
807+
staged = get_staged(mi, world)
806808
(staged isa CodeInfo && (staged::CodeInfo).pure) || return false
807809
return true
808810
end
809811
return method.pure
810812
end
811-
is_method_pure(match::MethodMatch) = is_method_pure(match.method, match.spec_types, match.sparams)
813+
is_method_pure(match::MethodMatch, world::UInt) = is_method_pure(match.method, match.spec_types, match.sparams, world)
812814

813815
function pure_eval_call(interp::AbstractInterpreter,
814816
@nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo)

base/compiler/bootstrap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ let interp = NativeInterpreter()
3636
else
3737
tt = Tuple{typeof(f), Vararg{Any}}
3838
end
39-
for m in _methods_by_ftype(tt, 10, typemax(UInt))::Vector
39+
for m in _methods_by_ftype(tt, 10, get_world_counter())::Vector
4040
# remove any TypeVars from the intersection
4141
m = m::MethodMatch
4242
typ = Any[m.spec_types.parameters...]

base/compiler/inferencestate.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,8 @@ end
342342

343343
function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
344344
# prepare an InferenceState object for inferring lambda
345-
src = retrieve_code_info(result.linfo)
345+
world = get_world_counter(interp)
346+
src = retrieve_code_info(result.linfo, world)
346347
src === nothing && return nothing
347348
validate_code_in_debug_mode(result.linfo, src, "lowered")
348349
return InferenceState(result, src, cache, interp)

base/compiler/optimize.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::Optimiz
182182
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, nothing)
183183
end
184184
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
185-
src = retrieve_code_info(linfo)
185+
world = get_world_counter(interp)
186+
src = retrieve_code_info(linfo, world)
186187
src === nothing && return nothing
187188
return OptimizationState(linfo, src, params, interp)
188189
end

base/compiler/typeinfer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
10411041
end
10421042
end
10431043
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0 && !generating_sysimg()
1044-
return retrieve_code_info(mi)
1044+
return retrieve_code_info(mi, get_world_counter(interp))
10451045
end
10461046
lock_mi_inference(interp, mi)
10471047
result = InferenceResult(mi)

base/compiler/types.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ struct NativeInterpreter <: AbstractInterpreter
317317
cache = Vector{InferenceResult}() # Initially empty cache
318318

319319
# Sometimes the caller is lazy and passes typemax(UInt).
320-
# we cap it to the current world age
320+
# we cap it to the current world age for correctness
321321
if world == typemax(UInt)
322322
world = get_world_counter()
323323
end

base/compiler/utilities.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,23 +114,23 @@ end
114114
invoke_api(li::CodeInstance) = ccall(:jl_invoke_api, Cint, (Any,), li)
115115
use_const_api(li::CodeInstance) = invoke_api(li) == 2
116116

117-
function get_staged(mi::MethodInstance)
117+
function get_staged(mi::MethodInstance, world::UInt)
118118
may_invoke_generator(mi) || return nothing
119119
try
120120
# user code might throw errors – ignore them
121-
ci = ccall(:jl_code_for_staged, Any, (Any,), mi)::CodeInfo
121+
ci = ccall(:jl_code_for_staged, Any, (Any, UInt), mi, world)::CodeInfo
122122
return ci
123123
catch
124124
return nothing
125125
end
126126
end
127127

128-
function retrieve_code_info(linfo::MethodInstance)
128+
function retrieve_code_info(linfo::MethodInstance, world::UInt)
129129
m = linfo.def::Method
130130
c = nothing
131131
if isdefined(m, :generator)
132132
# user code might throw errors – ignore them
133-
c = get_staged(linfo)
133+
c = get_staged(linfo, world)
134134
end
135135
if c === nothing && isdefined(m, :source)
136136
src = m.source

base/compiler/validation.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,15 +200,14 @@ end
200200

201201
"""
202202
validate_code!(errors::Vector{>:InvalidCodeError}, mi::MethodInstance,
203-
c::Union{Nothing,CodeInfo} = Core.Compiler.retrieve_code_info(mi))
203+
c::Union{Nothing,CodeInfo})
204204
205205
Validate `mi`, logging any violation by pushing an `InvalidCodeError` into `errors`.
206206
207207
If `isa(c, CodeInfo)`, also call `validate_code!(errors, c)`. It is assumed that `c` is
208-
the `CodeInfo` instance associated with `mi`.
208+
a `CodeInfo` instance associated with `mi`.
209209
"""
210-
function validate_code!(errors::Vector{>:InvalidCodeError}, mi::Core.MethodInstance,
211-
c::Union{Nothing,CodeInfo} = Core.Compiler.retrieve_code_info(mi))
210+
function validate_code!(errors::Vector{>:InvalidCodeError}, mi::Core.MethodInstance, c::Union{Nothing,CodeInfo})
212211
is_top_level = mi.def isa Module
213212
if is_top_level
214213
mnargs = 0

0 commit comments

Comments
 (0)