Skip to content

optimizer: propagate callsite inlining annotation across kwfunc #43911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
any_const_result = false
const_results = Union{InferenceResult,Nothing}[]
multiple_matches = napplicable > 1
ft = argtypes[1]
inline_propagation = isa(ft, Kwfunc) ? ft.inline : nothing

if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch)
val = pure_eval_call(f, argtypes)
Expand Down Expand Up @@ -76,14 +78,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if splitunions
splitsigs = switchtupleunion(sig)
for sig_n in splitsigs
result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv)
result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv, inline_propagation)
rt, edge = result.rt, result.edge
if edge !== nothing
push!(edges, edge)
end
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv, false)
const_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv, false, inline_propagation)
if const_result !== nothing
const_rt, const_result = const_result
if const_rt !== rt && const_rt ⊑ rt
Expand All @@ -109,12 +111,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# the use set for the current SSA value.
saved_uses = sv.ssavalue_uses[sv.currpc]
sv.ssavalue_uses[sv.currpc] = empty_bitset
abstract_call_method(interp, method, csig, match.sparams, multiple_matches, sv)
abstract_call_method(interp, method, csig, match.sparams, multiple_matches, sv, inline_propagation)
sv.ssavalue_uses[sv.currpc] = saved_uses
end
end

result = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv)
result = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv, inline_propagation)
this_rt, edge = result.rt, result.edge
if edge !== nothing
push!(edges, edge)
Expand All @@ -123,7 +125,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv, false)
const_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv, false, inline_propagation)
if const_result !== nothing
const_this_rt, const_result = const_result
if const_this_rt !== this_rt && const_this_rt ⊑ this_rt
Expand Down Expand Up @@ -405,7 +407,9 @@ end
const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result."
const RECURSION_MSG = "Bounded recursion detected. Call was widened to force convergence."

function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
function abstract_call_method(interp::AbstractInterpreter,
method::Method, @nospecialize(sig), sparams::SimpleVector,
hardlimit::Bool, sv::InferenceState, inline_propagation::Union{Nothing,Bool} = nothing)
if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base
add_remark!(interp, sv, "Refusing to infer into `depwarn`")
return MethodCallResult(Any, false, false, nothing)
Expand Down Expand Up @@ -564,7 +568,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
sparams = recomputed[2]::SimpleVector
end

rt, edge = typeinf_edge(interp, method, sig, sparams, sv)
rt, edge = typeinf_edge(interp, method, sig, sparams, sv, inline_propagation)
if edge === nothing
edgecycle = edgelimited = true
end
Expand All @@ -585,9 +589,9 @@ struct MethodCallResult
end
end

function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult,
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
sv::InferenceState, va_override::Bool)
function abstract_call_method_with_const_args(interp::AbstractInterpreter,
result::MethodCallResult, @nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
sv::InferenceState, va_override::Bool, inline_propagation::Union{Nothing,Bool} = nothing)
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
mi === nothing && return nothing
# try constant prop'
Expand Down Expand Up @@ -615,6 +619,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
end
frame = InferenceState(inf_result, #=cache=#:local, interp)
frame === nothing && return nothing # this is probably a bad generated function (unsound), but just ignore it
inline_propagation !== nothing && propagate_caller_annotations!(inline_propagation, frame)
frame.parent = sv
typeinf(interp, frame) || return nothing
end
Expand Down Expand Up @@ -1404,7 +1409,11 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
if !isvarargtype(aty)
ft = widenconst(aty)
if isa(ft, DataType) && isdefined(ft.name, :mt) && isdefined(ft.name.mt, :kwsorter)
return CallMeta(Const(ft.name.mt.kwsorter), MethodResultPure())
flag = get_curr_ssaflag(sv)
t = Kwfunc(ft.name.mt.kwsorter,
is_stmt_inline(flag) ? true :
is_stmt_noinline(flag) ? false : nothing)
return CallMeta(t, MethodResultPure())
end
end
end
Expand Down
13 changes: 12 additions & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,9 @@ end
generating_sysimg() = ccall(:jl_generating_output, Cint, ()) != 0 && JLOptions().incremental == 0

# compute (and cache) an inferred AST and return the current best estimate of the result type
function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize(atype), sparams::SimpleVector, caller::InferenceState)
function typeinf_edge(interp::AbstractInterpreter,
method::Method, @nospecialize(atype), sparams::SimpleVector, caller::InferenceState,
inline_propagation::Union{Nothing,Bool} = nothing)
mi = specialize_method(method, atype, sparams)::MethodInstance
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance # return existing rettype if the code is already inferred
Expand Down Expand Up @@ -822,6 +824,7 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
unlock_mi_inference(interp, mi)
return Any, nothing
end
inline_propagation !== nothing && propagate_caller_annotations!(inline_propagation, frame)
if caller.cached || caller.parent !== nothing # don't involve uncached functions in cycle resolution
frame.parent = caller
end
Expand All @@ -838,6 +841,14 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
return frame.bestguess, nothing
end

function propagate_caller_annotations!(inline::Bool, callee::InferenceState)
ssaflags = callee.src.ssaflags
for i = 1:length(ssaflags)
ssaflags[i] |= inline ? IR_FLAG_INLINE : IR_FLAG_NOINLINE
ssaflags[i] &= ~(inline ? IR_FLAG_NOINLINE : IR_FLAG_INLINE)
end
end

#### entry points for inferring a MethodInstance given a type signature ####

# compute an inferred AST and return type
Expand Down
19 changes: 18 additions & 1 deletion base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ struct PartialTypeVar
PartialTypeVar(tv::TypeVar, lb_certain::Bool, ub_certain::Bool) = new(tv, lb_certain, ub_certain)
end

struct Kwfunc
kwsorter
inline::Union{Nothing,Bool}
function Kwfunc(@nospecialize(kwsorter), inline::Union{Nothing,Bool})
return new(kwsorter, inline)
end
end

# Wraps a type and represents that the value may also be undef at this point.
# (only used in optimize, not abstractinterpret)
# N.B. in the lattice, this is epsilon bigger than `typ` (even Any)
Expand Down Expand Up @@ -105,7 +113,7 @@ struct NotFound end

const NOT_FOUND = NotFound()

const CompilerTypes = Union{MaybeUndef, Const, Conditional, NotFound, PartialStruct}
const CompilerTypes = Union{Const, Conditional, NotFound, PartialStruct, MaybeUndef, Kwfunc}
==(x::CompilerTypes, y::CompilerTypes) = x === y
==(x::Type, y::CompilerTypes) = false
==(x::CompilerTypes, y::Type) = false
Expand Down Expand Up @@ -215,6 +223,14 @@ The non-strict partial order over the type inference lattice.
end
return widenconst(a) ⊑ b
end
if isa(a, Kwfunc)
if isa(b, Kwfunc)
return a.kwsorter === b.kwsorter && a.inline === b.inline
end
a = Const(a.kwsorter)
elseif isa(b, Kwfunc)
b = Const(b.kwsorter)
end
if isa(a, Const)
if isa(b, Const)
return a.val === b.val
Expand Down Expand Up @@ -293,6 +309,7 @@ widenconst(c::AnyConditional) = Bool
widenconst((; val)::Const) = isa(val, Type) ? Type{val} : typeof(val)
widenconst(m::MaybeUndef) = widenconst(m.typ)
widenconst(c::PartialTypeVar) = TypeVar
widenconst(t::Kwfunc) = typeof(t.kwsorter)
widenconst(t::PartialStruct) = t.typ
widenconst(t::PartialOpaque) = t.typ
widenconst(t::Type) = t
Expand Down
2 changes: 2 additions & 0 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ end
function singleton_type(@nospecialize(ft))
if isa(ft, Const)
return ft.val
elseif isa(ft, Kwfunc)
return ft.kwsorter
elseif isconstType(ft)
return ft.parameters[1]
elseif ft isa DataType && isdefined(ft, :instance)
Expand Down
30 changes: 30 additions & 0 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,36 @@ end
end
end

# propagate callsite inlining across kwfunc abstraction
function someeval(@nospecialize(x); mod=Main)
v = Core.eval(mod, :(x = $(esc(x)))) # by default this prevents inlining
return v
end
let src = code_typed1((Any,)) do x
someeval(x; mod=Main)
end
@test count(iscall((src, Core._expr)), src.code) == 0

src = code_typed1((Any,)) do x
@inline someeval(x; mod=Main)
end
@test count(iscall((src, Core._expr)), src.code) == 2
end
function someisdefined(x::Symbol; mod=Main)
v = isdefined(mod, :x)
return v
end
let src = code_typed1((Symbol,)) do x
someisdefined(x; mod=Main)
end
@test count(iscall((src, isdefined)), src.code) == 1

src = code_typed1((Symbol,)) do x
@noinline someisdefined(x; mod=Main)
end
@test count(iscall((src, isdefined)), src.code) == 0
end

# Issue #42264 - crash on certain union splits
let f(x) = (x...,)
# Test splatting with a Union of non-{Tuple, SimpleVector} types that require creating new `iterate` calls
Expand Down