From e4ea0731f857be3e4de847d147c249f0cca64def Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Wed, 23 Jun 2021 20:22:15 +0900 Subject: [PATCH] optimizer: supports callsite annotations of inlining, fixes #18773 Enable `@inline`/`@noinline` annotations on function callsites. From #40754. Now `@inline` and `@noinline` can be applied to a code block, and the compiler will then try to (not) inline the calls within that block: ```julia @inline f(...) # The compiler will try to inline `f` @inline f(...) + g(...) # The compiler will try to inline `f`, `g` and `+` @inline f(args...) = ... # Of course, an annotation on a definition is still allowed ``` Here are a couple of notes on how these callsite annotations work: - a callsite annotation always takes precedence over the annotation applied to the definition of the called function, whichever of `@inline`/`@noinline` is used: ```julia @inline function explicit_inline(args...) # body end let @noinline explicit_inline(args...) # this call will not be inlined end ``` - when callsite annotations are nested, the innermost annotation takes precedence: ```julia @noinline let a0, b0 = ... a = @inline f(a0) # the compiler will try to inline this call b = notinlined(b0) # the compiler will NOT try to inline this call return a, b end ``` Both behaviors are tested and covered in the documentation.
--- base/compiler/abstractinterpretation.jl | 11 +- base/compiler/optimize.jl | 52 ++++++++-- base/compiler/ssair/inlining.jl | 73 ++++++------- base/compiler/typeinfer.jl | 2 +- base/compiler/types.jl | 8 +- base/compiler/utilities.jl | 5 +- base/compiler/validation.jl | 4 +- base/expr.jl | 132 +++++++++++++++++++++--- base/meta.jl | 3 +- src/ast.scm | 2 +- src/codegen.cpp | 4 +- src/interpreter.c | 2 +- src/julia-syntax.scm | 6 +- src/macroexpand.scm | 2 +- src/method.c | 3 +- test/compiler/inline.jl | 90 ++++++++++++++++ 16 files changed, 317 insertions(+), 82 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 23b00134c6071..d2ec266b2c1ab 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -554,7 +554,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me return nothing end mi = mi::MethodInstance - if !force && !const_prop_methodinstance_heuristic(interp, method, mi) + if !force && !const_prop_methodinstance_heuristic(interp, match, mi) add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic") return nothing end @@ -656,7 +656,8 @@ end # This is a heuristic to avoid trying to const prop through complicated functions # where we would spend a lot of time, but are probably unlikely to get an improved # result anyway. -function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance) +function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance) + method = match.method if method.is_for_opaque_closure # Not inlining an opaque closure can be very expensive, so be generous # with the const-prop-ability. 
It is quite possible that we can't infer @@ -674,7 +675,8 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method if isdefined(code, :inferred) && !cache_inlineable cache_inf = code.inferred if !(cache_inf === nothing) - cache_inlineable = inlining_policy(interp)(cache_inf) !== nothing + # TODO maybe we want to respect callsite `@inline`/`@noinline` annotations here ? + cache_inlineable = inlining_policy(interp)(cache_inf, nothing, match) !== nothing end end if !cache_inlineable @@ -1844,7 +1846,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) if isa(fname, SlotNumber) changes = StateUpdate(fname, VarState(Any, false), changes, false) end - elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect + elseif hd === :code_coverage_effect || + (hd !== :boundscheck && hd !== nothing && is_meta_expr_head(hd)) # :boundscheck can be narrowed to Bool # these do not generate code else t = abstract_eval_statement(interp, stmt, changes, frame) diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 6d059247a43ea..ef7fde422fc98 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -28,14 +28,20 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T, P} policy::P end -function default_inlining_policy(@nospecialize(src)) +function default_inlining_policy(@nospecialize(src), stmt_flag::Union{Nothing,UInt8}, match::Union{MethodMatch,InferenceResult}) if isa(src, CodeInfo) || isa(src, Vector{UInt8}) src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src) - src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), src) + src_inlineable = is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src) return src_inferred && src_inlineable ? src : nothing - end - if isa(src, OptimizationState) && isdefined(src, :ir) - return src.src.inlineable ? 
src.ir : nothing + elseif isa(src, OptimizationState) && isdefined(src, :ir) + return (is_stmt_inline(stmt_flag) || src.src.inlineable) ? src.ir : nothing + elseif src === nothing && is_stmt_inline(stmt_flag) && isa(match, MethodMatch) + # when the source isn't available at this moment, try to re-infer and inline it + # HACK in order to avoid cycles here, we disable inlining and makes sure the following inference never comes here + # TODO sort out `AbstractInterpreter` interface to handle this well, and also inference should try to keep the source if the statement will be inlined + interp = NativeInterpreter(; opt_params = OptimizationParams(; inlining = false)) + src, rt = typeinf_code(interp, match.method, match.spec_types, match.sparams, true) + return src end return nothing end @@ -134,6 +140,10 @@ const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError # This statement was marked as @inbounds by the user. If replaced by inlining, # any contained boundschecks may be removed const IR_FLAG_INBOUNDS = 0x01 +# This statement was marked as @inline by the user +const IR_FLAG_INLINE = 0x01 << 1 +# This statement was marked as @noinline by the user +const IR_FLAG_NOINLINE = 0x01 << 2 # This statement may be removed if its result is unused. In particular it must # thus be both pure and effect free. 
const IR_FLAG_EFFECT_FREE = 0x01 << 4 @@ -174,11 +184,16 @@ function isinlineable(m::Method, me::OptimizationState, params::OptimizationPara end end if !inlineable - inlineable = inline_worthy(me.ir, params, union_penalties, cost_threshold + bonus) + inlineable = inline_worthy(me.ir::IRCode, params, union_penalties, cost_threshold + bonus) end return inlineable end +is_stmt_inline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_INLINE != 0 +is_stmt_inline(::Nothing) = false +is_stmt_noinline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_NOINLINE != 0 +is_stmt_noinline(::Nothing) = false # not used for now + # These affect control flow within the function (so may not be removed # if there is no usage within the function), but don't affect the purity # of the function as a whole. @@ -366,6 +381,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg renumber_ir_elements!(code, changemap, labelmap) inbounds_depth = 0 # Number of stacked inbounds + inline_flags = BitVector() meta = Any[] flags = fill(0x00, length(code)) for i = 1:length(code) @@ -380,16 +396,38 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg inbounds_depth -= 1 end stmt = nothing + elseif isexpr(stmt, :inline) + if stmt.args[1]::Bool + push!(inline_flags, true) + else + pop!(inline_flags) + end + stmt = nothing + elseif isexpr(stmt, :noinline) + if stmt.args[1]::Bool + push!(inline_flags, false) + else + pop!(inline_flags) + end + stmt = nothing else stmt = normalize(stmt, meta) end code[i] = stmt - if !(stmt === nothing) + if stmt !== nothing if inbounds_depth > 0 flags[i] |= IR_FLAG_INBOUNDS end + if !isempty(inline_flags) + if last(inline_flags) + flags[i] |= IR_FLAG_INLINE + else + flags[i] |= IR_FLAG_NOINLINE + end + end end end + @assert isempty(inline_flags) "malformed meta flags" strip_trailing_junk!(ci, code, stmtinfo, flags) cfg = compute_basic_blocks(code) types = Any[] diff --git a/base/compiler/ssair/inlining.jl 
b/base/compiler/ssair/inlining.jl index 78edef88439e9..999b67e26a1ab 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -600,6 +600,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx:: argexprs::Vector{Any}, atypes::Vector{Any}, arginfos::Vector{Any}, arg_start::Int, istate::InliningState) + flag = ir.stmts[idx][:flag] new_argexprs = Any[argexprs[arg_start]] new_atypes = Any[atypes[arg_start]] # loop over original arguments and flatten any known iterators @@ -655,8 +656,9 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx:: info = call.info handled = false if isa(info, ConstCallInfo) - if maybe_handle_const_call!(ir, state1.id, new_stmt, info, new_sig, - call.rt, istate, false, todo) + if !is_stmt_noinline(flag) && maybe_handle_const_call!( + ir, state1.id, new_stmt, info, new_sig,call.rt, istate, flag, false, todo) + handled = true else info = info.call @@ -667,7 +669,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx:: MethodMatchInfo[info] : info.matches # See if we can inline this call to `iterate` analyze_single_call!(ir, todo, state1.id, new_stmt, - new_sig, call.rt, info, istate) + new_sig, call.rt, info, istate, flag) end if i != length(thisarginfo.each) valT = getfield_tfunc(call.rt, Const(1)) @@ -716,16 +718,16 @@ function compileable_specialization(et::Union{EdgeTracker, Nothing}, result::Inf return mi end -function resolve_todo(todo::InliningTodo, state::InliningState) - spec = todo.spec::DelayedInliningSpec +function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8) + (; match) = todo.spec::DelayedInliningSpec #XXX: update_valid_age!(min_valid[1], max_valid[1], sv) isconst, src = false, nothing - if isa(spec.match, InferenceResult) - let inferred_src = spec.match.src + if isa(match, InferenceResult) + let inferred_src = match.src if isa(inferred_src, Const) if !is_inlineable_constant(inferred_src.val) - 
return compileable_specialization(state.et, spec.match) + return compileable_specialization(state.et, match) end isconst, src = true, quoted(inferred_src.val) else @@ -753,12 +755,10 @@ function resolve_todo(todo::InliningTodo, state::InliningState) return ConstantCase(src) end - if src !== nothing - src = state.policy(src) - end + src = state.policy(src, flag, match) if src === nothing - return compileable_specialization(et, spec.match) + return compileable_specialization(et, match) end if isa(src, IRCode) @@ -769,17 +769,9 @@ function resolve_todo(todo::InliningTodo, state::InliningState) return InliningTodo(todo.mi, src) end -function resolve_todo(todo::UnionSplit, state::InliningState) +function resolve_todo(todo::UnionSplit, state::InliningState, flag::UInt8) UnionSplit(todo.fully_covered, todo.atype, - Pair{Any,Any}[sig=>resolve_todo(item, state) for (sig, item) in todo.cases]) -end - -function resolve_todo!(todo::Vector{Pair{Int, Any}}, state::InliningState) - for i = 1:length(todo) - idx, item = todo[i] - todo[i] = idx=>resolve_todo(item, state) - end - todo + Pair{Any,Any}[sig=>resolve_todo(item, state, flag) for (sig, item) in todo.cases]) end function validate_sparams(sparams::SimpleVector) @@ -790,7 +782,7 @@ function validate_sparams(sparams::SimpleVector) end function analyze_method!(match::MethodMatch, atypes::Vector{Any}, - state::InliningState, @nospecialize(stmttyp)) + state::InliningState, @nospecialize(stmttyp), flag::UInt8) method = match.method methsig = method.sig @@ -806,11 +798,9 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any}, end # Bail out if any static parameters are left as TypeVar - ok = true validate_sparams(match.sparams) || return nothing - - if !state.params.inlining + if !state.params.inlining || is_stmt_noinline(flag) return compileable_specialization(state.et, match) end @@ -824,7 +814,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any}, # If we don't have caches here, delay resolving this 
MethodInstance # until the batch inlining step (or an external post-processing pass) state.mi_cache === nothing && return todo - return resolve_todo(todo, state) + return resolve_todo(todo, state, flag) end function InliningTodo(mi::MethodInstance, ir::IRCode) @@ -1050,7 +1040,7 @@ is_builtin(s::Signature) = s.ft ⊑ Builtin function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result)::InvokeCallInfo, - state::InliningState, todo::Vector{Pair{Int, Any}}) + state::InliningState, todo::Vector{Pair{Int, Any}}, flag::UInt8) stmt = ir.stmts[idx][:inst] calltype = ir.stmts[idx][:type] @@ -1064,17 +1054,17 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result): atypes = atypes[4:end] pushfirst!(atypes, atype0) - if isa(result, InferenceResult) + if isa(result, InferenceResult) && !is_stmt_noinline(flag) item = InliningTodo(result, atypes, calltype) validate_sparams(item.mi.sparam_vals) || return nothing if argtypes_to_type(atypes) <: item.mi.def.sig - state.mi_cache !== nothing && (item = resolve_todo(item, state)) + state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) handle_single_case!(ir, stmt, idx, item, true, todo) return nothing end end - result = analyze_method!(match, atypes, state, calltype) + result = analyze_method!(match, atypes, state, calltype, flag) handle_single_case!(ir, stmt, idx, result, true, todo) return nothing end @@ -1169,7 +1159,7 @@ end function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt), sig::Signature, @nospecialize(calltype), infos::Vector{MethodMatchInfo}, - state::InliningState) + state::InliningState, flag::UInt8) cases = Pair{Any, Any}[] signature_union = Union{} only_method = nothing # keep track of whether there is one matching method @@ -1202,7 +1192,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int fully_covered = false continue end - case = analyze_method!(match, sig.atypes, state, 
calltype) + case = analyze_method!(match, sig.atypes, state, calltype, flag) if case === nothing fully_covered = false continue @@ -1229,7 +1219,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int match = meth[1] end fully_covered = true - case = analyze_method!(match, sig.atypes, state, calltype) + case = analyze_method!(match, sig.atypes, state, calltype, flag) case === nothing && return push!(cases, Pair{Any,Any}(match.spec_types, case)) end @@ -1251,7 +1241,7 @@ end function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr, info::ConstCallInfo, sig::Signature, @nospecialize(calltype), - state::InliningState, + state::InliningState, flag::UInt8, isinvoke::Bool, todo::Vector{Pair{Int, Any}}) # when multiple matches are found, bail out and later inliner will union-split this signature # TODO effectively use multiple constant analysis results here @@ -1263,7 +1253,7 @@ function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr, validate_sparams(item.mi.sparam_vals) || return true mthd_sig = item.mi.def.sig mistypes = item.mi.specTypes - state.mi_cache !== nothing && (item = resolve_todo(item, state)) + state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) if sig.atype <: mthd_sig handle_single_case!(ir, stmt, idx, item, isinvoke, todo) return true @@ -1301,6 +1291,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) info = info.info end + flag = ir.stmts[idx][:flag] + # Inference determined this couldn't be analyzed. Don't question it. if info === false continue @@ -1310,7 +1302,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) # it'll have performed a specialized analysis for just this case. Use its # result. 
if isa(info, ConstCallInfo) - if maybe_handle_const_call!(ir, idx, stmt, info, sig, calltype, state, sig.f === Core.invoke, todo) + if !is_stmt_noinline(flag) && maybe_handle_const_call!( + ir, idx, stmt, info, sig, calltype, state, flag, sig.f === Core.invoke, todo) continue else info = info.call @@ -1318,7 +1311,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) end if isa(info, OpaqueClosureCallInfo) - result = analyze_method!(info.match, sig.atypes, state, calltype) + result = analyze_method!(info.match, sig.atypes, state, calltype, flag) handle_single_case!(ir, stmt, idx, result, false, todo) continue end @@ -1326,7 +1319,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) # Handle invoke if sig.f === Core.invoke if isa(info, InvokeCallInfo) - inline_invoke!(ir, idx, sig, info, state, todo) + inline_invoke!(ir, idx, sig, info, state, todo, flag) end continue end @@ -1340,7 +1333,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) continue end - analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state) + analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state, flag) end todo end diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 4ad96ae2e72f0..7a6df9e6bc231 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -343,7 +343,7 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta nslots = length(ci.slotflags) resize!(ci.slottypes, nslots) resize!(ci.slotnames, nslots) - return ccall(:jl_compress_ir, Any, (Any, Any), def, ci) + return ccall(:jl_compress_ir, Vector{UInt8}, (Any, Any), def, ci) else return ci end diff --git a/base/compiler/types.jl b/base/compiler/types.jl index 773047d2b00e5..a579fa13f4989 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -10,10 +10,10 @@ swapped in as long as they follow the AbstractInterpreter API. 
All AbstractInterpreters are expected to provide at least the following methods: -- InferenceParams(interp) - return an `InferenceParams` instance -- OptimizationParams(interp) - return an `OptimizationParams` instance -- get_world_counter(interp) - return the world age for this interpreter -- get_inference_cache(interp) - return the runtime inference cache +- `InferenceParams(interp)` - return an `InferenceParams` instance +- `OptimizationParams(interp)` - return an `OptimizationParams` instance +- `get_world_counter(interp)` - return the world age for this interpreter +- `get_inference_cache(interp)` - return the runtime inference cache """ abstract type AbstractInterpreter; end diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index 3b84395c676d2..58b2eeeeb9aef 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -59,7 +59,8 @@ end # Meta expression head, these generally can't be deleted even when they are # in a dead branch but can be ignored when analyzing uses/liveness. 
-is_meta_expr_head(head::Symbol) = (head === :inbounds || head === :boundscheck || head === :meta || head === :loopinfo) +is_meta_expr_head(head::Symbol) = (head === :inbounds || head === :boundscheck || head === :meta || + head === :loopinfo || head === :inline || head === :noinline) sym_isless(a::Symbol, b::Symbol) = ccall(:strcmp, Int32, (Ptr{UInt8}, Ptr{UInt8}), a, b) < 0 @@ -187,7 +188,7 @@ function specialize_method(method::Method, @nospecialize(atypes), sparams::Simpl if preexisting # check cached specializations # for an existing result stored there - return ccall(:jl_specializations_lookup, Any, (Any, Any), method, atypes) + return ccall(:jl_specializations_lookup, Any, (Any, Any), method, atypes)::Union{Nothing,MethodInstance} end return ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), method, atypes, sparams) end diff --git a/base/compiler/validation.jl b/base/compiler/validation.jl index f6b89f8f5cd04..0c49ccd2bb269 100644 --- a/base/compiler/validation.jl +++ b/base/compiler/validation.jl @@ -16,6 +16,8 @@ const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange}( :leave => 1:1, :pop_exception => 1:1, :inbounds => 1:1, + :inline => 1:1, + :noinline => 1:1, :boundscheck => 0:0, :copyast => 1:1, :meta => 0:typemax(Int), @@ -141,7 +143,7 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_ head === :const || head === :enter || head === :leave || head === :pop_exception || head === :method || head === :global || head === :static_parameter || head === :new || head === :splatnew || head === :thunk || head === :loopinfo || - head === :throw_undef_if_not || head === :code_coverage_effect + head === :throw_undef_if_not || head === :code_coverage_effect || head === :inline || head === :noinline validate_val!(x) else # TODO: nothing is actually in statement position anymore diff --git a/base/expr.jl b/base/expr.jl index 9df363714679e..483e75067e478 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -209,11 +209,52 @@ 
end !!! compat "Julia 1.8" The usage within a function body requires at least Julia 1.8. + +--- + @inline block + +Give a hint to the compiler that calls within `block` are worth inlining. + +```julia +# The compiler will try to inline `f` +@inline f(...) + +# The compiler will try to inline `f`, `g` and `+` +@inline f(...) + g(...) +``` + +!!! note + A callsite annotation always has the precedence over the annotation applied to the + definition of the called function: + ```julia + @noinline function explicit_noinline(args...) + # body + end + + let + @inline explicit_noinline(args...) # will be inlined + end + ``` + +!!! note + When there are nested callsite annotations, the innermost annotation has the precedence: + ```julia + @noinline let a0, b0 = ... + a = @inline f(a0) # the compiler will try to inline this call + b = f(b0) # the compiler will NOT try to inline this call + return a, b + end + ``` + +!!! compat "Julia 1.8" + The callsite annotation requires at least Julia 1.8. """ -macro inline(ex) - esc(isa(ex, Expr) ? pushmeta!(ex, :inline) : ex) +macro inline(x) + return annotate_meta_def_or_block(x, :inline) +end +macro inline() + return Expr(:meta, :inline) end -macro inline() Expr(:meta, :inline) end """ @noinline @@ -245,13 +286,55 @@ end !!! compat "Julia 1.8" The usage within a function body requires at least Julia 1.8. +--- + @noinline block + +Give a hint to the compiler that it should not inline the calls within `block`. + +```julia +# The compiler will try to not inline `f` +@noinline f(...) + +# The compiler will try to not inline `f`, `g` and `+` +@noinline f(...) + g(...) +``` + +!!! note + A callsite annotation always has the precedence over the annotation applied to the + definition of the called function: + ```julia + @inline function explicit_inline(args...) + # body + end + + let + @noinline explicit_inline(args...) # will not be inlined + end + ``` + +!!! 
note + When there are nested callsite annotations, the innermost annotation has the precedence: + ```julia + @inline let a0, b0 = ... + a = @noinline f(a0) # the compiler will NOT try to inline this call + b = f(b0) # the compiler will try to inline this call + return a, b + end + ``` + +!!! compat "Julia 1.8" + The callsite annotation requires at least Julia 1.8. + +--- !!! note If the function is trivial (for example returning a constant) it might get inlined anyway. """ -macro noinline(ex) - esc(isa(ex, Expr) ? pushmeta!(ex, :noinline) : ex) +macro noinline(x) + return annotate_meta_def_or_block(x, :noinline) +end +macro noinline() + return Expr(:meta, :noinline) end -macro noinline() Expr(:meta, :noinline) end """ @pure ex @@ -303,6 +386,15 @@ end ## some macro utilities ## +unwrap_macrocalls(@nospecialize(x)) = x +function unwrap_macrocalls(ex::Expr) + inner = ex + while inner.head === :macrocall + inner = inner.args[end]::Expr + end + return inner +end + function pushmeta!(ex::Expr, sym::Symbol, args::Any...) if isempty(args) tag = sym @@ -310,10 +402,7 @@ function pushmeta!(ex::Expr, sym::Symbol, args::Any...) 
tag = Expr(sym, args...)::Expr end - inner = ex - while inner.head === :macrocall - inner = inner.args[end]::Expr - end + inner = unwrap_macrocalls(ex) idx, exargs = findmeta(inner) if idx != 0 @@ -363,8 +452,23 @@ function findmetaarg(metaargs, sym) return 0 end -function is_short_function_def(ex) - ex.head === :(=) || return false +function annotate_meta_def_or_block(@nospecialize(ex), meta::Symbol) + inner = unwrap_macrocalls(ex) + if is_function_def(inner) + # annotation on a definition + return esc(pushmeta!(ex, meta)) + else + # annotation on a block + return Expr(:block, + Expr(meta, true), + Expr(:local, Expr(:(=), :val, esc(ex))), + Expr(meta, false), + :val) + end +end + +function is_short_function_def(@nospecialize(ex)) + isexpr(ex, :(=)) || return false while length(ex.args) >= 1 && isa(ex.args[1], Expr) (ex.args[1].head === :call) && return true (ex.args[1].head === :where || ex.args[1].head === :(::)) || return false @@ -372,9 +476,11 @@ function is_short_function_def(ex) end return false end +is_function_def(@nospecialize(ex)) = + return isexpr(ex, :function) || is_short_function_def(ex) || isexpr(ex, :->) function findmeta(ex::Expr) - if ex.head === :function || is_short_function_def(ex) || ex.head === :-> + if is_function_def(ex) body = ex.args[2]::Expr body.head === :block || error(body, " is not a block expression") return findmeta_block(ex.args) diff --git a/base/meta.jl b/base/meta.jl index b483630a92f8f..3fe815cd0cbc0 100644 --- a/base/meta.jl +++ b/base/meta.jl @@ -450,6 +450,7 @@ end _instantiate_type_in_env(x, spsig, spvals) = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), x, spsig, spvals) -is_meta_expr_head(head::Symbol) = (head === :inbounds || head === :boundscheck || head === :meta || head === :loopinfo) +is_meta_expr_head(head::Symbol) = (head === :inbounds || head === :boundscheck || head === :meta + || head === :loopinfo || head === :inline || head === :noinline) end # module diff --git a/src/ast.scm b/src/ast.scm 
index bc8d847279fc9..e5148a507a4fd 100644 --- a/src/ast.scm +++ b/src/ast.scm @@ -289,7 +289,7 @@ ;; predicates and accessors (define (quoted? e) - (memq (car e) '(quote top core globalref outerref line break inert meta inbounds loopinfo))) + (memq (car e) '(quote top core globalref outerref line break inert meta inbounds inline noinline loopinfo))) (define (quotify e) `',e) (define (unquote e) (if (and (pair? e) (memq (car e) '(quote inert))) diff --git a/src/codegen.cpp b/src/codegen.cpp index d59b312716b2d..21daeee2a73b6 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -4401,7 +4401,7 @@ static void emit_stmtpos(jl_codectx_t &ctx, jl_value_t *expr, int ssaval_result) jl_value_t **args = (jl_value_t**)jl_array_data(ex->args); jl_sym_t *head = ex->head; if (head == meta_sym || head == inbounds_sym || head == coverageeffect_sym - || head == aliasscope_sym || head == popaliasscope_sym) { + || head == aliasscope_sym || head == popaliasscope_sym || head == inline_sym || head == noinline_sym) { // some expression types are metadata and can be ignored // in statement position return; @@ -4836,7 +4836,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval) } else if (head == leave_sym || head == coverageeffect_sym || head == pop_exception_sym || head == enter_sym || head == inbounds_sym - || head == aliasscope_sym || head == popaliasscope_sym) { + || head == aliasscope_sym || head == popaliasscope_sym || head == inline_sym || head == noinline_sym) { jl_errorf("Expr(:%s) in value position", jl_symbol_name(head)); } else if (head == boundscheck_sym) { diff --git a/src/interpreter.c b/src/interpreter.c index 7858bd6ddc4ea..2be907e82513c 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -311,7 +311,7 @@ static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s) return jl_true; } else if (head == meta_sym || head == coverageeffect_sym || head == inbounds_sym || head == loopinfo_sym || - head == aliasscope_sym || head == 
popaliasscope_sym) { + head == aliasscope_sym || head == popaliasscope_sym || head == inline_sym || head == noinline_sym) { return jl_nothing; } else if (head == gc_preserve_begin_sym || head == gc_preserve_end_sym) { diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index f00ea0c9ba6d9..428b0513b7e52 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -3498,7 +3498,7 @@ f(x) = yt(x) thunk with-static-parameters toplevel-only global globalref outerref const-if-global thismodule const atomic null true false ssavalue isdefined toplevel module lambda - error gc_preserve_begin gc_preserve_end import using export))) + error gc_preserve_begin gc_preserve_end import using export inline noinline))) (define (local-in? s lam) (or (assq s (car (lam:vinfo lam))) @@ -4592,7 +4592,7 @@ f(x) = yt(x) (cons (car e) args))) ;; metadata expressions - ((line meta inbounds loopinfo gc_preserve_end aliasscope popaliasscope) + ((line meta inbounds loopinfo gc_preserve_end aliasscope popaliasscope inline noinline) (let ((have-ret? (and (pair? code) (pair? (car code)) (eq? (caar code) 'return)))) (cond ((eq? (car e) 'line) (set! current-loc e) @@ -4737,7 +4737,7 @@ f(x) = yt(x) (begin (set! linetable (cons (make-lineinfo name file line) linetable)) (set! current-loc 1))) (if (or reachable - (and (pair? e) (memq (car e) '(meta inbounds gc_preserve_begin gc_preserve_end aliasscope popaliasscope)))) + (and (pair? e) (memq (car e) '(meta inbounds gc_preserve_begin gc_preserve_end aliasscope popaliasscope inline noinline)))) (begin (set! code (cons e code)) (set! i (+ i 1)) (set! 
locs (cons current-loc locs))))) diff --git a/src/macroexpand.scm b/src/macroexpand.scm index 5e55c7bbb29c1..f17f4d3510dc6 100644 --- a/src/macroexpand.scm +++ b/src/macroexpand.scm @@ -352,7 +352,7 @@ ,(resolve-expansion-vars-with-new-env (caddr arg) env m parent-scope inarg)))) (else `(global ,(resolve-expansion-vars-with-new-env arg env m parent-scope inarg)))))) - ((using import export meta line inbounds boundscheck loopinfo) (map unescape e)) + ((using import export meta line inbounds boundscheck loopinfo inline noinline) (map unescape e)) ((macrocall) e) ; invalid syntax anyways, so just act like it's quoted. ((symboliclabel) e) ((symbolicgoto) e) diff --git a/src/method.c b/src/method.c index 48b074e800904..852a7ff88208a 100644 --- a/src/method.c +++ b/src/method.c @@ -84,7 +84,8 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve e->head == quote_sym || e->head == inert_sym || e->head == meta_sym || e->head == inbounds_sym || e->head == boundscheck_sym || e->head == loopinfo_sym || - e->head == aliasscope_sym || e->head == popaliasscope_sym) { + e->head == aliasscope_sym || e->head == popaliasscope_sym || + e->head == inline_sym || e->head == noinline_sym) { // ignore these } else { diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index 00797304ce5c0..1282fe5f2dd07 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -498,6 +498,96 @@ end end end +@testset "callsite @inline/@noinline annotations" begin + m = Module() + @eval m begin + # this global variable prevents inference to fold everything as constant, and/or the optimizer to inline the call accessing to this + g = 0 + + @noinline noinlined_explicit(x) = x + force_inline_explicit(x) = @inline noinlined_explicit(x) + force_inline_block_explicit(x) = @inline noinlined_explicit(x) + noinlined_explicit(x) + noinlined_implicit(x) = g + force_inline_implicit(x) = @inline noinlined_implicit(x) + force_inline_block_implicit(x) = @inline 
noinlined_implicit(x) + noinlined_implicit(x) + + @inline inlined_explicit(x) = x + force_noinline_explicit(x) = @noinline inlined_explicit(x) + force_noinline_block_explicit(x) = @noinline inlined_explicit(x) + inlined_explicit(x) + inlined_implicit(x) = x + force_noinline_implicit(x) = @noinline inlined_implicit(x) + force_noinline_block_implicit(x) = @noinline inlined_implicit(x) + inlined_implicit(x) + + # test callsite annotations for constant-prop'ed calls + + @noinline Base.@aggressive_constprop noinlined_constprop_explicit(a) = a+g + force_inline_constprop_explicit() = @inline noinlined_constprop_explicit(0) + Base.@aggressive_constprop noinlined_constprop_implicit(a) = a+g + force_inline_constprop_implicit() = @inline noinlined_constprop_implicit(0) + + @inline Base.@aggressive_constprop inlined_constprop_explicit(a) = a+g + force_noinline_constprop_explicit() = @noinline inlined_constprop_explicit(0) + @inline Base.@aggressive_constprop inlined_constprop_implicit(a) = a+g + force_noinline_constprop_implicit() = @noinline inlined_constprop_implicit(0) + + @noinline notinlined(a) = a + function nested(a0, b0) + @noinline begin + a = @inline notinlined(a0) # this call should be inlined + b = notinlined(b0) # this call should NOT be inlined + return a, b + end + end + end + + let ci = code_typed1(m.force_inline_explicit, (Int,)) + @test all(x->!isinvoke(x, :noinlined_explicit), ci.code) + end + let ci = code_typed1(m.force_inline_block_explicit, (Int,)) + @test all(ci.code) do x + !isinvoke(x, :noinlined_explicit) && + !isinvoke(x, :(+)) + end + end + let ci = code_typed1(m.force_inline_implicit, (Int,)) + @test all(x->!isinvoke(x, :noinlined_implicit), ci.code) + end + let ci = code_typed1(m.force_inline_block_implicit, (Int,)) + @test all(x->!isinvoke(x, :noinlined_explicit), ci.code) + end + + let ci = code_typed1(m.force_noinline_explicit, (Int,)) + @test any(x->isinvoke(x, :inlined_explicit), ci.code) + end + let ci = 
code_typed1(m.force_noinline_block_explicit, (Int,)) + @test count(x->isinvoke(x, :inlined_explicit), ci.code) == 2 + end + let ci = code_typed1(m.force_noinline_implicit, (Int,)) + @test any(x->isinvoke(x, :inlined_implicit), ci.code) + end + let ci = code_typed1(m.force_noinline_block_implicit, (Int,)) + @test count(x->isinvoke(x, :inlined_implicit), ci.code) == 2 + end + + let ci = code_typed1(m.force_inline_constprop_explicit) + @test all(x->!isinvoke(x, :noinlined_constprop_explicit), ci.code) + end + let ci = code_typed1(m.force_inline_constprop_implicit) + @test all(x->!isinvoke(x, :noinlined_constprop_implicit), ci.code) + end + + let ci = code_typed1(m.force_noinline_constprop_explicit) + @test any(x->isinvoke(x, :inlined_constprop_explicit), ci.code) + end + let ci = code_typed1(m.force_noinline_constprop_implicit) + @test any(x->isinvoke(x, :inlined_constprop_implicit), ci.code) + end + + let ci = code_typed1(m.nested, (Int,Int)) + @test count(x->isinvoke(x, :notinlined), ci.code) == 1 + end +end + # Issue #41299 - inlining deletes error check in :> g41299(f::Tf, args::Vararg{Any,N}) where {Tf,N} = f(args...) @test_throws TypeError g41299(>:, 1, 2)