Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimizer: support callsite annotations of @inline and @noinline #40754

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ New language features
---------------------

* `Module(:name, false, false)` can be used to create a `module` that does not import `Core`. ([#40110])
* `@noinline` can now be used at function callsites. ([#40754])
* `@inline` and `@noinline` annotations may now be used in function bodies. ([#40754])

Language changes
----------------
Expand Down
6 changes: 4 additions & 2 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,8 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method
if isdefined(code, :inferred) && !cache_inlineable
cache_inf = code.inferred
if !(cache_inf === nothing)
cache_inlineable = inlining_policy(interp)(cache_inf) !== nothing
# TODO maybe we want to respect callsite `@inline`/`@noinline` annotations here ?
cache_inlineable = inlining_policy(interp)(cache_inf, nothing) !== nothing
end
end
if !cache_inlineable
Expand Down Expand Up @@ -1806,7 +1807,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if isa(fname, SlotNumber)
changes = StateUpdate(fname, VarState(Any, false), changes, false)
end
elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect
elseif hd === :code_coverage_effect ||
(hd !== :boundscheck && hd !== nothing && is_meta_expr_head(hd)) # :boundscheck can be narrowed to Bool
# these do not generate code
else
t = abstract_eval_statement(interp, stmt, changes, frame)
Expand Down
38 changes: 35 additions & 3 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T, P}
policy::P
end

function default_inlining_policy(@nospecialize(src))
function default_inlining_policy(@nospecialize(src), stmt_flag::Union{Nothing,UInt8})
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
src_inlineable = is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
return src_inferred && src_inlineable ? src : nothing
end
if isa(src, OptimizationState) && isdefined(src, :ir)
return src.src.inlineable ? src.ir : nothing
return (is_stmt_inline(stmt_flag) || src.src.inlineable) ? src.ir : nothing
Comment on lines +31 to +38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the easy and customizable approach to support callsite @inline annotation is to check the statement flag within inlining_policy, and thus I ended up changing the signature of the interface.
@Keno does it sound reasonable/okay to you ?

end
return nothing
end
Expand Down Expand Up @@ -134,6 +134,10 @@ const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError
# This statement was marked as @inbounds by the user. If replaced by inlining,
# any contained boundschecks may be removed
const IR_FLAG_INBOUNDS = 0x01
# This statement was marked as @inline by the user
const IR_FLAG_INLINE = 0x01 << 1
# This statement was marked as @noinline by the user
const IR_FLAG_NOINLINE = 0x01 << 2
# This statement may be removed if its result is unused. In particular it must
# thus be both pure and effect free.
const IR_FLAG_EFFECT_FREE = 0x01 << 4
Expand Down Expand Up @@ -179,6 +183,11 @@ function isinlineable(m::Method, me::OptimizationState, params::OptimizationPara
return inlineable
end

is_stmt_inline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_INLINE != 0
is_stmt_inline(::Nothing) = false
is_stmt_noinline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_NOINLINE != 0
is_stmt_noinline(::Nothing) = false # not used for now

# These affect control flow within the function (so may not be removed
# if there is no usage within the function), but don't affect the purity
# of the function as a whole.
Expand Down Expand Up @@ -366,6 +375,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
renumber_ir_elements!(code, changemap, labelmap)

inbounds_depth = 0 # Number of stacked inbounds
inline_flags = BitVector()
meta = Any[]
flags = fill(0x00, length(code))
for i = 1:length(code)
Expand All @@ -380,6 +390,20 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
inbounds_depth -= 1
end
stmt = nothing
elseif isexpr(stmt, :inline)
if stmt.args[1]::Bool
push!(inline_flags, true)
else
pop!(inline_flags)
end
stmt = nothing
elseif isexpr(stmt, :noinline)
if stmt.args[1]::Bool
push!(inline_flags, false)
else
pop!(inline_flags)
end
stmt = nothing
else
stmt = normalize(stmt, meta)
end
Expand All @@ -388,8 +412,16 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
if inbounds_depth > 0
flags[i] |= IR_FLAG_INBOUNDS
end
if !isempty(inline_flags)
if last(inline_flags)
flags[i] |= IR_FLAG_INLINE
else
flags[i] |= IR_FLAG_NOINLINE
end
end
end
end
@assert isempty(inline_flags) "malformed meta flags"
strip_trailing_junk!(ci, code, stmtinfo, flags)
cfg = compute_basic_blocks(code)
types = Any[]
Expand Down
57 changes: 26 additions & 31 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
argexprs::Vector{Any}, atypes::Vector{Any}, arginfos::Vector{Any},
arg_start::Int, istate::InliningState)

flag = ir.stmts[idx][:flag]
new_argexprs = Any[argexprs[arg_start]]
new_atypes = Any[atypes[arg_start]]
# loop over original arguments and flatten any known iterators
Expand Down Expand Up @@ -655,8 +656,9 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
info = call.info
handled = false
if isa(info, ConstCallInfo)
if maybe_handle_const_call!(ir, state1.id, new_stmt, info, new_sig,
call.rt, istate, false, todo)
if !is_stmt_noinline(flag) &&
maybe_handle_const_call!(ir, state1.id, new_stmt, info, new_sig,
call.rt, istate, flag, false, todo)
handled = true
else
info = info.call
Expand All @@ -667,7 +669,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
MethodMatchInfo[info] : info.matches
# See if we can inline this call to `iterate`
analyze_single_call!(ir, todo, state1.id, new_stmt,
new_sig, call.rt, info, istate)
new_sig, call.rt, info, istate, flag)
end
if i != length(thisarginfo.each)
valT = getfield_tfunc(call.rt, Const(1))
Expand Down Expand Up @@ -716,7 +718,7 @@ function compileable_specialization(et::Union{EdgeTracker, Nothing}, result::Inf
return mi
end

function resolve_todo(todo::InliningTodo, state::InliningState)
function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
spec = todo.spec::DelayedInliningSpec

#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
Expand Down Expand Up @@ -754,7 +756,7 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
end

if src !== nothing
src = state.policy(src)
src = state.policy(src, flag)
end

if src === nothing
Expand All @@ -769,17 +771,9 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
return InliningTodo(todo.mi, src)
end

function resolve_todo(todo::UnionSplit, state::InliningState)
function resolve_todo(todo::UnionSplit, state::InliningState, flag::UInt8)
UnionSplit(todo.fully_covered, todo.atype,
Pair{Any,Any}[sig=>resolve_todo(item, state) for (sig, item) in todo.cases])
end

function resolve_todo!(todo::Vector{Pair{Int, Any}}, state::InliningState)
for i = 1:length(todo)
idx, item = todo[i]
todo[i] = idx=>resolve_todo(item, state)
end
todo
Pair{Any,Any}[sig=>resolve_todo(item, state, flag) for (sig, item) in todo.cases])
end

function validate_sparams(sparams::SimpleVector)
Expand All @@ -790,7 +784,7 @@ function validate_sparams(sparams::SimpleVector)
end

function analyze_method!(match::MethodMatch, atypes::Vector{Any},
state::InliningState, @nospecialize(stmttyp))
state::InliningState, @nospecialize(stmttyp), flag::UInt8)
method = match.method
methsig = method.sig

Expand All @@ -806,11 +800,9 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
end

# Bail out if any static parameters are left as TypeVar
ok = true
validate_sparams(match.sparams) || return nothing


if !state.params.inlining
if !state.params.inlining || is_stmt_noinline(flag)
return compileable_specialization(state.et, match)
end

Expand All @@ -824,7 +816,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
# If we don't have caches here, delay resolving this MethodInstance
# until the batch inlining step (or an external post-processing pass)
state.mi_cache === nothing && return todo
return resolve_todo(todo, state)
return resolve_todo(todo, state, flag)
end

function InliningTodo(mi::MethodInstance, ir::IRCode)
Expand Down Expand Up @@ -1050,7 +1042,7 @@ is_builtin(s::Signature) =
s.ft ⊑ Builtin

function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallInfo,
state::InliningState, todo::Vector{Pair{Int, Any}})
state::InliningState, todo::Vector{Pair{Int, Any}}, flag::UInt8)
stmt = ir.stmts[idx][:inst]
calltype = ir.stmts[idx][:type]

Expand All @@ -1064,7 +1056,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallIn
atypes = atypes[4:end]
pushfirst!(atypes, atype0)

result = analyze_method!(info.match, atypes, state, calltype)
result = analyze_method!(info.match, atypes, state, calltype, flag)
handle_single_case!(ir, stmt, idx, result, true, todo)
return nothing
end
Expand Down Expand Up @@ -1159,7 +1151,7 @@ end

function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt),
sig::Signature, @nospecialize(calltype), infos::Vector{MethodMatchInfo},
state::InliningState)
state::InliningState, flag::UInt8)
cases = Pair{Any, Any}[]
signature_union = Union{}
only_method = nothing # keep track of whether there is one matching method
Expand Down Expand Up @@ -1192,7 +1184,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
fully_covered = false
continue
end
case = analyze_method!(match, sig.atypes, state, calltype)
case = analyze_method!(match, sig.atypes, state, calltype, flag)
if case === nothing
fully_covered = false
continue
Expand All @@ -1219,7 +1211,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
match = meth[1]
end
fully_covered = true
case = analyze_method!(match, sig.atypes, state, calltype)
case = analyze_method!(match, sig.atypes, state, calltype, flag)
case === nothing && return
push!(cases, Pair{Any,Any}(match.spec_types, case))
end
Expand All @@ -1241,7 +1233,7 @@ end

function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
info::ConstCallInfo, sig::Signature, @nospecialize(calltype),
state::InliningState,
state::InliningState, flag::UInt8,
isinvoke::Bool, todo::Vector{Pair{Int, Any}})
# when multiple matches are found, bail out and later inliner will union-split this signature
# TODO effectively use multiple constant analysis results here
Expand All @@ -1253,7 +1245,7 @@ function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
validate_sparams(item.mi.sparam_vals) || return true
mthd_sig = item.mi.def.sig
mistypes = item.mi.specTypes
state.mi_cache !== nothing && (item = resolve_todo(item, state))
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
if sig.atype <: mthd_sig
handle_single_case!(ir, stmt, idx, item, isinvoke, todo)
return true
Expand Down Expand Up @@ -1291,6 +1283,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
info = info.info
end

flag = ir.stmts[idx][:flag]

# Inference determined this couldn't be analyzed. Don't question it.
if info === false
continue
Expand All @@ -1300,23 +1294,24 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
# it'll have performed a specialized analysis for just this case. Use its
# result.
if isa(info, ConstCallInfo)
if maybe_handle_const_call!(ir, idx, stmt, info, sig, calltype, state, sig.f === Core.invoke, todo)
if !is_stmt_noinline(flag) &&
maybe_handle_const_call!(ir, idx, stmt, info, sig, calltype, state, flag, sig.f === Core.invoke, todo)
continue
else
info = info.call
end
end

if isa(info, OpaqueClosureCallInfo)
result = analyze_method!(info.match, sig.atypes, state, calltype)
result = analyze_method!(info.match, sig.atypes, state, calltype, flag)
handle_single_case!(ir, stmt, idx, result, false, todo)
continue
end

# Handle invoke
if sig.f === Core.invoke
if isa(info, InvokeCallInfo)
inline_invoke!(ir, idx, sig, info, state, todo)
inline_invoke!(ir, idx, sig, info, state, todo, flag)
end
continue
end
Expand All @@ -1330,7 +1325,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
continue
end

analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state)
analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state, flag)
end
todo
end
Expand Down
5 changes: 3 additions & 2 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ end

# Meta expression head, these generally can't be deleted even when they are
# in a dead branch but can be ignored when analyzing uses/liveness.
is_meta_expr_head(head::Symbol) = (head === :inbounds || head === :boundscheck || head === :meta || head === :loopinfo)
is_meta_expr_head(head::Symbol) = (head === :inbounds || head === :boundscheck || head === :meta ||
head === :loopinfo || head === :inline || head === :noinline)

sym_isless(a::Symbol, b::Symbol) = ccall(:strcmp, Int32, (Ptr{UInt8}, Ptr{UInt8}), a, b) < 0

Expand Down Expand Up @@ -187,7 +188,7 @@ function specialize_method(method::Method, @nospecialize(atypes), sparams::Simpl
if preexisting
# check cached specializations
# for an existing result stored there
return ccall(:jl_specializations_lookup, Any, (Any, Any), method, atypes)
return ccall(:jl_specializations_lookup, Any, (Any, Any), method, atypes)::Union{Nothing,MethodInstance}
end
return ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), method, atypes, sparams)
end
Expand Down
4 changes: 3 additions & 1 deletion base/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange}(
:leave => 1:1,
:pop_exception => 1:1,
:inbounds => 1:1,
:inline => 1:1,
:noinline => 1:1,
:boundscheck => 0:0,
:copyast => 1:1,
:meta => 0:typemax(Int),
Expand Down Expand Up @@ -141,7 +143,7 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_
head === :const || head === :enter || head === :leave || head === :pop_exception ||
head === :method || head === :global || head === :static_parameter ||
head === :new || head === :splatnew || head === :thunk || head === :loopinfo ||
head === :throw_undef_if_not || head === :code_coverage_effect
head === :throw_undef_if_not || head === :code_coverage_effect || head === :inline || head === :noinline
validate_val!(x)
else
# TODO: nothing is actually in statement position anymore
Expand Down
Loading