Skip to content

Commit b422883

Browse files
Keno
and
Cédric Belmant
authored
Move inlinability determination into cache transform (JuliaLang#57979)
Currently the inlineability determination is in a bit of an odd spot - just after the optimizers while everything is still in IRCode. It seems more sensible to move this code into the cache transformation code, which is the first place that makes an actual decision based on inlineability. If an external AbstractInterpreter does not need to convert to CodeInfo for compilation purposes, this also potentially saves that extra conversion. While we're at it, clean up some naming to deconflict it with other uses. --------- Co-authored-by: Cédric Belmant <cedric.belmant@juliahub.com>
1 parent 92fc06f commit b422883

File tree

6 files changed

+149
-89
lines changed

6 files changed

+149
-89
lines changed

Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ end
4747

4848
function Compiler.transform_result_for_cache(interp::SplitCacheInterp, result::Compiler.InferenceResult, edges::Compiler.SimpleVector)
4949
opt = result.src::Compiler.OptimizationState
50-
ir = opt.result.ir::Compiler.IRCode
50+
ir = opt.optresult.ir::Compiler.IRCode
5151
override = with_new_compiler
5252
for inst in ir.stmts
5353
stmt = inst[:stmt]

Compiler/src/optimize.jl

+34-64
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,14 @@ function inline_cost_clamp(x::Int)
116116
return convert(InlineCostType, x)
117117
end
118118

119+
const SRC_FLAG_DECLARED_INLINE = 0x1
120+
const SRC_FLAG_DECLARED_NOINLINE = 0x2
121+
119122
is_declared_inline(@nospecialize src::MaybeCompressed) =
120-
ccall(:jl_ir_flag_inlining, UInt8, (Any,), src) == 1
123+
ccall(:jl_ir_flag_inlining, UInt8, (Any,), src) == SRC_FLAG_DECLARED_INLINE
121124

122125
is_declared_noinline(@nospecialize src::MaybeCompressed) =
123-
ccall(:jl_ir_flag_inlining, UInt8, (Any,), src) == 2
126+
ccall(:jl_ir_flag_inlining, UInt8, (Any,), src) == SRC_FLAG_DECLARED_NOINLINE
124127

125128
#####################
126129
# OptimizationState #
@@ -157,6 +160,7 @@ code_cache(state::InliningState) = WorldView(code_cache(state.interp), state.wor
157160

158161
mutable struct OptimizationResult
159162
ir::IRCode
163+
inline_flag::UInt8
160164
simplified::Bool # indicates whether the IR was processed with `cfg_simplify!`
161165
end
162166

@@ -168,7 +172,7 @@ end
168172
mutable struct OptimizationState{Interp<:AbstractInterpreter}
169173
linfo::MethodInstance
170174
src::CodeInfo
171-
result::Union{Nothing, OptimizationResult}
175+
optresult::Union{Nothing, OptimizationResult}
172176
stmt_info::Vector{CallInfo}
173177
mod::Module
174178
sptypes::Vector{VarState}
@@ -236,13 +240,29 @@ include("ssair/EscapeAnalysis.jl")
236240
include("ssair/passes.jl")
237241
include("ssair/irinterp.jl")
238242

243+
function ir_to_codeinf!(opt::OptimizationState, frame::InferenceState, edges::SimpleVector)
244+
ir_to_codeinf!(opt, edges, compute_inlining_cost(frame.interp, frame.result, opt.optresult))
245+
end
246+
247+
function ir_to_codeinf!(opt::OptimizationState, edges::SimpleVector, inlining_cost::InlineCostType)
248+
src = ir_to_codeinf!(opt, edges)
249+
src.inlining_cost = inlining_cost
250+
src
251+
end
252+
253+
function ir_to_codeinf!(opt::OptimizationState, edges::SimpleVector)
254+
src = ir_to_codeinf!(opt)
255+
src.edges = edges
256+
src
257+
end
258+
239259
function ir_to_codeinf!(opt::OptimizationState)
240-
(; linfo, src, result) = opt
241-
if result === nothing
260+
(; linfo, src, optresult) = opt
261+
if optresult === nothing
242262
return src
243263
end
244-
src = ir_to_codeinf!(src, result.ir)
245-
opt.result = nothing
264+
src = ir_to_codeinf!(src, optresult.ir)
265+
opt.optresult = nothing
246266
opt.src = src
247267
maybe_validate_code(linfo, src, "optimized")
248268
return src
@@ -485,63 +505,12 @@ end
485505
abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s]
486506

487507
"""
488-
finish(interp::AbstractInterpreter, opt::OptimizationState,
489-
ir::IRCode, caller::InferenceResult)
508+
finishopt!(interp::AbstractInterpreter, opt::OptimizationState, ir::IRCode)
490509
491-
Post-process information derived by Julia-level optimizations for later use.
492-
In particular, this function determines the inlineability of the optimized code.
510+
Called at the end of optimization to store the resulting IR back into the OptimizationState.
493511
"""
494-
function finish(interp::AbstractInterpreter, opt::OptimizationState,
495-
ir::IRCode, caller::InferenceResult)
496-
(; src, linfo) = opt
497-
(; def, specTypes) = linfo
498-
499-
force_noinline = is_declared_noinline(src)
500-
501-
# compute inlining and other related optimizations
502-
result = caller.result
503-
@assert !(result isa LimitedAccuracy)
504-
result = widenslotwrapper(result)
505-
506-
opt.result = OptimizationResult(ir, false)
507-
508-
# determine and cache inlineability
509-
if !force_noinline
510-
sig = unwrap_unionall(specTypes)
511-
if !(isa(sig, DataType) && sig.name === Tuple.name)
512-
force_noinline = true
513-
end
514-
if !is_declared_inline(src) && result === Bottom
515-
force_noinline = true
516-
end
517-
end
518-
if force_noinline
519-
set_inlineable!(src, false)
520-
elseif isa(def, Method)
521-
if is_declared_inline(src) && isdispatchtuple(specTypes)
522-
# obey @inline declaration if a dispatch barrier would not help
523-
set_inlineable!(src, true)
524-
else
525-
# compute the cost (size) of inlining this code
526-
params = OptimizationParams(interp)
527-
cost_threshold = default = params.inline_cost_threshold
528-
if ⊑(optimizer_lattice(interp), result, Tuple) && !isconcretetype(widenconst(result))
529-
cost_threshold += params.inline_tupleret_bonus
530-
end
531-
# if the method is declared as `@inline`, increase the cost threshold 20x
532-
if is_declared_inline(src)
533-
cost_threshold += 19*default
534-
end
535-
# a few functions get special treatment
536-
if def.module === _topmod(def.module)
537-
name = def.name
538-
if name === :iterate || name === :unsafe_convert || name === :cconvert
539-
cost_threshold += 4*default
540-
end
541-
end
542-
src.inlining_cost = inline_cost(ir, params, cost_threshold)
543-
end
544-
end
512+
function finishopt!(interp::AbstractInterpreter, opt::OptimizationState, ir::IRCode)
513+
opt.optresult = OptimizationResult(ir, ccall(:jl_ir_flag_inlining, UInt8, (Any,), opt.src), false)
545514
return nothing
546515
end
547516

@@ -1015,7 +984,8 @@ end
1015984
function optimize(interp::AbstractInterpreter, opt::OptimizationState, caller::InferenceResult)
1016985
@timeit "optimizer" ir = run_passes_ipo_safe(opt.src, opt)
1017986
ipo_dataflow_analysis!(interp, opt, ir, caller)
1018-
return finish(interp, opt, ir, caller)
987+
finishopt!(interp, opt, ir)
988+
return nothing
1019989
end
1020990

1021991
macro pass(name, expr)
@@ -1459,7 +1429,7 @@ function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{Cod
14591429
return thiscost
14601430
end
14611431

1462-
function inline_cost(ir::IRCode, params::OptimizationParams, cost_threshold::Int)
1432+
function inline_cost_model(ir::IRCode, params::OptimizationParams, cost_threshold::Int)
14631433
bodycost = 0
14641434
for i = 1:length(ir.stmts)
14651435
stmt = ir[SSAValue(i)][:stmt]

Compiler/src/ssair/inlining.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -976,7 +976,7 @@ function retrieve_ir_for_inlining(mi::MethodInstance, ir::IRCode, preserve_local
976976
return ir, spec_info, DebugInfo(ir.debuginfo, length(ir.stmts))
977977
end
978978
function retrieve_ir_for_inlining(mi::MethodInstance, opt::OptimizationState, preserve_local_sources::Bool)
979-
result = opt.result
979+
result = opt.optresult
980980
if result !== nothing
981981
!result.simplified && simplify_ir!(result)
982982
return retrieve_ir_for_inlining(mi, result.ir, preserve_local_sources)

Compiler/src/typeinfer.jl

+108-23
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ end
104104
function finish!(interp::AbstractInterpreter, caller::InferenceState, validation_world::UInt, time_before::UInt64)
105105
result = caller.result
106106
#@assert last(result.valid_worlds) <= get_world_counter() || isempty(caller.edges)
107-
if isdefined(result, :ci)
107+
if caller.cache_mode === CACHE_MODE_LOCAL
108+
@assert !isdefined(result, :ci)
109+
result.src = transform_result_for_local_cache(interp, result)
110+
elseif isdefined(result, :ci)
108111
edges = result_edges(interp, caller)
109112
ci = result.ci
110113
# if we aren't cached, we don't need this edge
@@ -115,11 +118,16 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState, validation
115118
store_backedges(ci, edges)
116119
end
117120
inferred_result = nothing
118-
uncompressed = inferred_result
121+
uncompressed = result.src
119122
const_flag = is_result_constabi_eligible(result)
123+
debuginfo = get_debuginfo(result.src)
120124
discard_src = caller.cache_mode === CACHE_MODE_NULL || const_flag
121125
if !discard_src
122126
inferred_result = transform_result_for_cache(interp, result, edges)
127+
if inferred_result !== nothing
128+
uncompressed = inferred_result
129+
debuginfo = get_debuginfo(inferred_result)
130+
end
123131
# TODO: do we want to augment edges here with any :invoke targets that we got from inlining (such that we didn't have a direct edge to it already)?
124132
if inferred_result isa CodeInfo
125133
result.src = inferred_result
@@ -128,27 +136,28 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState, validation
128136
resize!(inferred_result.slottypes::Vector{Any}, nslots)
129137
resize!(inferred_result.slotnames, nslots)
130138
end
131-
di = inferred_result.debuginfo
132-
uncompressed = inferred_result
133139
inferred_result = maybe_compress_codeinfo(interp, result.linfo, inferred_result)
134140
result.is_src_volatile = false
135141
elseif ci.owner === nothing
136142
# The global cache can only handle objects that codegen understands
137143
inferred_result = nothing
138144
end
139145
end
140-
if !@isdefined di
141-
di = DebugInfo(result.linfo)
146+
if debuginfo === nothing
147+
debuginfo = DebugInfo(result.linfo)
142148
end
143149
time_now = _time_ns()
144150
time_self_ns = caller.time_self_ns + (time_now - time_before)
145151
time_total = (time_now - caller.time_start - caller.time_paused) * 1e-9
146152
ccall(:jl_update_codeinst, Cvoid, (Any, Any, Int32, UInt, UInt, UInt32, Any, Float64, Float64, Float64, Any, Any),
147153
ci, inferred_result, const_flag, first(result.valid_worlds), last(result.valid_worlds), encode_effects(result.ipo_effects),
148-
result.analysis_results, time_total, caller.time_caches, time_self_ns * 1e-9, di, edges)
154+
result.analysis_results, time_total, caller.time_caches, time_self_ns * 1e-9, debuginfo, edges)
149155
engine_reject(interp, ci)
150156
codegen = codegen_cache(interp)
151-
if !discard_src && codegen !== nothing && uncompressed isa CodeInfo
157+
if !discard_src && codegen !== nothing && (isa(uncompressed, CodeInfo) || isa(uncompressed, OptimizationState))
158+
if isa(uncompressed, OptimizationState)
159+
uncompressed = ir_to_codeinf!(uncompressed, edges)
160+
end
152161
# record that the caller could use this result to generate code when required, if desired, to avoid repeating n^2 work
153162
codegen[ci] = uncompressed
154163
if bootstrapping_compiler && inferred_result == nothing
@@ -299,36 +308,113 @@ function adjust_cycle_frame!(sv::InferenceState, cycle_valid_worlds::WorldRange,
299308
return nothing
300309
end
301310

311+
function get_debuginfo(src)
312+
isa(src, CodeInfo) && return src.debuginfo
313+
isa(src, OptimizationState) && return src.src.debuginfo
314+
return nothing
315+
end
316+
302317
function is_result_constabi_eligible(result::InferenceResult)
303318
result_type = result.result
304319
return isa(result_type, Const) && is_foldable_nothrow(result.ipo_effects) && is_inlineable_constant(result_type.val)
305320
end
306321

307-
function transform_result_for_cache(::AbstractInterpreter, result::InferenceResult, edges::SimpleVector)
322+
function compute_inlining_cost(interp::AbstractInterpreter, result::InferenceResult)
323+
src = result.src
324+
isa(src, OptimizationState) || return MAX_INLINE_COST
325+
compute_inlining_cost(interp, result, src.optresult)
326+
end
327+
328+
function compute_inlining_cost(interp::AbstractInterpreter, result::InferenceResult, optresult#=::OptimizationResult=#)
329+
return inline_cost_model(interp, result, optresult.inline_flag, optresult.ir)
330+
end
331+
332+
function inline_cost_model(interp::AbstractInterpreter, result::InferenceResult,
333+
inline_flag::UInt8, ir::IRCode)
334+
335+
inline_flag === SRC_FLAG_DECLARED_NOINLINE && return MAX_INLINE_COST
336+
337+
mi = result.linfo
338+
(; def, specTypes) = mi
339+
if !isa(def, Method)
340+
return MAX_INLINE_COST
341+
end
342+
343+
declared_inline = inline_flag === SRC_FLAG_DECLARED_INLINE
344+
345+
rt = result.result
346+
@assert !(rt isa LimitedAccuracy)
347+
rt = widenslotwrapper(rt)
348+
349+
sig = unwrap_unionall(specTypes)
350+
if !(isa(sig, DataType) && sig.name === Tuple.name)
351+
return MAX_INLINE_COST
352+
end
353+
if !declared_inline && rt === Bottom
354+
return MAX_INLINE_COST
355+
end
356+
357+
if declared_inline && isdispatchtuple(specTypes)
358+
# obey @inline declaration if a dispatch barrier would not help
359+
return MIN_INLINE_COST
360+
else
361+
# compute the cost (size) of inlining this code
362+
params = OptimizationParams(interp)
363+
cost_threshold = default = params.inline_cost_threshold
364+
if ⊑(optimizer_lattice(interp), rt, Tuple) && !isconcretetype(widenconst(rt))
365+
cost_threshold += params.inline_tupleret_bonus
366+
end
367+
# if the method is declared as `@inline`, increase the cost threshold 20x
368+
if declared_inline
369+
cost_threshold += 19*default
370+
end
371+
# a few functions get special treatment
372+
if def.module === _topmod(def.module)
373+
name = def.name
374+
if name === :iterate || name === :unsafe_convert || name === :cconvert
375+
cost_threshold += 4*default
376+
end
377+
end
378+
return inline_cost_model(ir, params, cost_threshold)
379+
end
380+
end
381+
382+
function transform_result_for_local_cache(interp::AbstractInterpreter, result::InferenceResult)
308383
src = result.src
309384
if isa(src, OptimizationState)
310-
src = ir_to_codeinf!(src)
385+
# Compute and store any information required to determine the inlineability of the callee.
386+
opt = src
387+
opt.src.inlining_cost = compute_inlining_cost(interp, result)
388+
end
389+
return src
390+
end
391+
392+
function transform_result_for_cache(interp::AbstractInterpreter, result::InferenceResult, edges::SimpleVector)
393+
inlining_cost = nothing
394+
src = result.src
395+
if isa(src, OptimizationState)
396+
opt = src
397+
inlining_cost = compute_inlining_cost(interp, result, opt.optresult)
398+
discard_optimized_result(interp, opt, inlining_cost) && return nothing
399+
src = ir_to_codeinf!(opt)
311400
end
312401
if isa(src, CodeInfo)
313402
src.edges = edges
403+
src.inlining_cost = inlining_cost !== nothing ? inlining_cost : compute_inlining_cost(interp, result)
314404
end
315405
return src
316406
end
317407

408+
function discard_optimized_result(interp::AbstractInterpreter, opt#=::OptimizationState=#, inlining_cost#=::InlineCostType=#)
409+
may_discard_trees(interp) || return false
410+
return inlining_cost == MAX_INLINE_COST
411+
end
412+
318413
function maybe_compress_codeinfo(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInfo)
319414
def = mi.def
320415
isa(def, Method) || return ci # don't compress toplevel code
321-
can_discard_trees = may_discard_trees(interp)
322-
cache_the_tree = !can_discard_trees || is_inlineable(ci)
323-
if cache_the_tree
324-
if may_compress(interp)
325-
return ccall(:jl_compress_ir, String, (Any, Any), def, ci)
326-
else
327-
return ci
328-
end
329-
else
330-
return nothing
331-
end
416+
may_compress(interp) && return ccall(:jl_compress_ir, String, (Any, Any), def, ci)
417+
return ci
332418
end
333419

334420
function cache_result!(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance)
@@ -1101,8 +1187,7 @@ function typeinf_frame(interp::AbstractInterpreter, mi::MethodInstance, run_opti
11011187
else
11021188
opt = OptimizationState(frame, interp)
11031189
optimize(interp, opt, frame.result)
1104-
src = ir_to_codeinf!(opt)
1105-
src.edges = Core.svec(opt.inlining.edges...)
1190+
src = ir_to_codeinf!(opt, frame, Core.svec(opt.inlining.edges...))
11061191
end
11071192
result.src = frame.src = src
11081193
end

Compiler/test/codegen.jl

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using Random
66
using InteractiveUtils
7+
using InteractiveUtils: code_llvm, code_native
78
using Libdl
89
using Test
910

Compiler/test/inline.jl

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3+
module inline_tests
4+
35
using Test
46
using Base.Meta
57
using Core: ReturnNode
@@ -2311,3 +2313,5 @@ g_noinline_invoke(x) = f_noinline_invoke(x)
23112313
let src = code_typed1(g_noinline_invoke, (Union{Symbol,Nothing},))
23122314
@test !any(@nospecialize(x)->isa(x,GlobalRef), src.code)
23132315
end
2316+
2317+
end # module inline_tests

0 commit comments

Comments
 (0)