Skip to content

Commit 8eb7175

Browse files
committed
AbstractInterpreter: refactor the lifetimes of OptimizationState and IRCode
This commit limits the lifetimes of `OptimizationState` and `IRCode` for a more dataflow clarity. It also avoids duplicated calls of `ir_to_codeinf!`. Note that external `AbstractInterpreter`s can still extend their lifetimes to cache additional information, as described by this newly added documentation of `finish!`: > finish!(interp::AbstractInterpreter, > opt::OptimizationState, ir::IRCode, caller::InferenceResult) > > Runs post-Julia-level optimization process and caches information for later uses: > - computes "purity" (i.e. side-effect-freeness) of the optimized frame > - computes inlining cost and cache the inlineability in `opt.src.inlineable` > - stores the result of optimization in `caller.src` > * by default, `caller.src` will be an optimized `CodeInfo` object transformed from `ir` > * in a case when this frame has been proven pure, `ConstAPI` object wrapping the constant > value will be kept in `caller.src` instead, so that the runtime system will use > the constant calling convention > > !!! note > The lifetimes of `opt` and `ir` end by the end of this process. > Still external `AbstractInterpreter` can override this method as necessary to cache them. > Note that `transform_result_for_cache` should be overloaded also in such cases, > otherwise the default `transform_result_for_cache` implmentation will discard any information > other than `CodeInfo`, `Vector{UInt8}` or `ConstAPI`. This commit also adds a new overload `infresult_iterator` so that external interpreters can tweak the behavior of post processings of `_typeinf`. Especially, this change is motivated by the need for JET, whose post-optimization processing needs references of `InferenceState`.
1 parent 89a613b commit 8eb7175

File tree

3 files changed

+165
-147
lines changed

3 files changed

+165
-147
lines changed

base/compiler/optimize.jl

+125-104
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,9 @@ end
8484

8585
include("compiler/ssair/driver.jl")
8686

87-
mutable struct OptimizationState
87+
struct OptimizationState
8888
linfo::MethodInstance
8989
src::CodeInfo
90-
ir::Union{Nothing, IRCode}
9190
stmt_info::Vector{Any}
9291
mod::Module
9392
sptypes::Vector{Any} # static parameters
@@ -99,8 +98,7 @@ mutable struct OptimizationState
9998
EdgeTracker(s_edges, frame.valid_worlds),
10099
WorldView(code_cache(interp), frame.world),
101100
interp)
102-
return new(frame.linfo,
103-
frame.src, nothing, frame.stmt_info, frame.mod,
101+
return new(frame.linfo, frame.src, frame.stmt_info, frame.mod,
104102
frame.sptypes, frame.slottypes, inlining)
105103
end
106104
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
@@ -127,8 +125,7 @@ mutable struct OptimizationState
127125
nothing,
128126
WorldView(code_cache(interp), get_world_counter()),
129127
interp)
130-
return new(linfo,
131-
src, nothing, stmt_info, mod,
128+
return new(linfo, src, stmt_info, mod,
132129
sptypes_from_meth_instance(linfo), slottypes, inlining)
133130
end
134131
end
@@ -139,11 +136,10 @@ function OptimizationState(linfo::MethodInstance, params::OptimizationParams, in
139136
return OptimizationState(linfo, src, params, interp)
140137
end
141138

142-
function ir_to_codeinf!(opt::OptimizationState)
139+
function ir_to_codeinf!(opt::OptimizationState, ir::IRCode)
143140
(; linfo, src) = opt
144141
optdef = linfo.def
145-
replace_code_newstyle!(src, opt.ir::IRCode, isa(optdef, Method) ? Int(optdef.nargs) : 0)
146-
opt.ir = nothing
142+
replace_code_newstyle!(src, ir, isa(optdef, Method) ? Int(optdef.nargs) : 0)
147143
widen_all_consts!(src)
148144
src.inferred = true
149145
# finish updating the result struct
@@ -380,130 +376,155 @@ struct ConstAPI
380376
end
381377

382378
"""
383-
finish(interp::AbstractInterpreter, opt::OptimizationState,
384-
params::OptimizationParams, ir::IRCode, caller::InferenceResult) -> analyzed::Union{Nothing,ConstAPI}
385-
386-
Post process information derived by Julia-level optimizations for later uses:
387-
- computes "purity", i.e. side-effect-freeness
388-
- computes inlining cost
389-
390-
In a case when the purity is proven, `finish` can return `ConstAPI` object wrapping the constant
391-
value so that the runtime system will use the constant calling convention for the method calls.
379+
finish!(interp::AbstractInterpreter,
380+
opt::OptimizationState, ir::IRCode, caller::InferenceResult)
381+
382+
Runs post-Julia-level optimization process and caches information for later uses:
383+
- computes "purity" (i.e. side-effect-freeness) of the optimized frame
384+
- computes inlining cost and cache the inlineability in `opt.src.inlineable`
385+
- stores the result of optimization in `caller.src`
386+
* by default, `caller.src` will be an optimized `CodeInfo` object transformed from `ir`
387+
* in a case when this frame has been proven pure, `ConstAPI` object wrapping the constant
388+
value will be kept in `caller.src` instead, so that the runtime system will use
389+
the constant calling convention
390+
391+
!!! note
392+
The lifetimes of `opt` and `ir` end by the end of this process.
393+
Still external `AbstractInterpreter` can override `transform_optresult_for_cache`
394+
as necessary to cache them. Note that `transform_result_for_cache` should be overloaded
395+
also in such cases, otherwise the default implmentation of `transform_result_for_cache`
396+
will discard any information other than `CodeInfo`, `Vector{UInt8}` or `ConstAPI`.
392397
"""
393-
function finish(interp::AbstractInterpreter, opt::OptimizationState,
394-
params::OptimizationParams, ir::IRCode, caller::InferenceResult)
395-
(; src, linfo) = opt
396-
(; def, specTypes) = linfo
397-
398-
analyzed = nothing # `ConstAPI` if this call can use constant calling convention
399-
force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)
398+
function finish!(interp::AbstractInterpreter,
399+
opt::OptimizationState, ir::IRCode, caller::InferenceResult)
400+
src = opt.src
400401

401-
# compute inlining and other related optimizations
402402
result = caller.result
403403
@assert !(result isa LimitedAccuracy)
404404
result = isa(result, InterConditional) ? widenconditional(result) : result
405-
if (isa(result, Const) || isconstType(result))
406-
proven_pure = false
407-
# must be proven pure to use constant calling convention;
408-
# otherwise we might skip throwing errors (issue #20704)
409-
# TODO: Improve this analysis; if a function is marked @pure we should really
410-
# only care about certain errors (e.g. method errors and type errors).
411-
if length(ir.stmts) < 15
412-
proven_pure = true
413-
for i in 1:length(ir.stmts)
414-
node = ir.stmts[i]
415-
stmt = node[:inst]
416-
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, node[:type], ir)
417-
proven_pure = false
418-
break
419-
end
420-
end
421-
if proven_pure
422-
for fl in src.slotflags
423-
if (fl & SLOT_USEDUNDEF) != 0
424-
proven_pure = false
425-
break
426-
end
427-
end
428-
end
429-
end
430405

431-
if proven_pure
432-
# use constant calling convention
433-
# Do not emit `jl_fptr_const_return` if coverage is enabled
434-
# so that we don't need to add coverage support
435-
# to the `jl_call_method_internal` fast path
436-
# Still set pure flag to make sure `inference` tests pass
437-
# and to possibly enable more optimization in the future
438-
src.pure = true
406+
newresult = nothing # ConstAPI if this call can use constant calling convention
407+
if isa(result, Const) || isconstType(result)
408+
# computes "purity" (i.e. side-effect-freeness)
409+
if compute_purity(ir, src)
410+
src.inlineable = src.pure = true
411+
412+
# must be proven pure to use constant calling convention;
413+
# otherwise we might skip throwing errors (issue #20704)
439414
if isa(result, Const)
440415
val = result.val
441416
if is_inlineable_constant(val)
442-
analyzed = ConstAPI(val)
417+
newresult = ConstAPI(val)
443418
end
444419
else
445420
@assert isconstType(result)
446-
analyzed = ConstAPI(result.parameters[1])
421+
newresult = ConstAPI(result.parameters[1])
447422
end
448-
force_noinline || (src.inlineable = true)
449423
end
450424
end
451425

452-
opt.ir = ir
453-
454426
# determine and cache inlineability
455-
union_penalties = false
456-
if !force_noinline
457-
sig = unwrap_unionall(specTypes)
458-
if isa(sig, DataType) && sig.name === Tuple.name
459-
for P in sig.parameters
460-
P = unwrap_unionall(P)
461-
if isa(P, Union)
462-
union_penalties = true
463-
break
464-
end
427+
src.inlineable = compute_inlineability(ir, opt, result, src.inlineable)
428+
429+
caller.valid_worlds = (opt.inlining.et::EdgeTracker).valid_worlds[]
430+
431+
caller.src = transform_optresult_for_cache(interp, opt, ir, newresult)
432+
433+
return nothing
434+
end
435+
436+
function compute_purity(ir::IRCode, src::CodeInfo)
437+
# TODO: Improve this analysis; if a function is marked @pure we should really
438+
# only care about certain errors (e.g. method errors and type errors).
439+
if length(ir.stmts) < 15
440+
for i in 1:length(ir.stmts)
441+
node = ir.stmts[i]
442+
stmt = node[:inst]
443+
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, node[:type], ir)
444+
return false
465445
end
466-
else
467-
force_noinline = true
468446
end
469-
if !src.inlineable && result === Bottom
470-
force_noinline = true
447+
for flag in src.slotflags
448+
if (flag & SLOT_USEDUNDEF) != 0
449+
return false
450+
end
471451
end
452+
return true
472453
end
473-
if force_noinline
474-
src.inlineable = false
475-
elseif isa(def, Method)
476-
if src.inlineable && isdispatchtuple(specTypes)
477-
# obey @inline declaration if a dispatch barrier would not help
478-
else
479-
# compute the cost (size) of inlining this code
480-
cost_threshold = default = params.inline_cost_threshold
481-
if result Tuple && !isconcretetype(widenconst(result))
482-
cost_threshold += params.inline_tupleret_bonus
483-
end
484-
# if the method is declared as `@inline`, increase the cost threshold 20x
485-
if src.inlineable
486-
cost_threshold += 19*default
487-
end
488-
# a few functions get special treatment
489-
if def.module === _topmod(def.module)
490-
name = def.name
491-
if name === :iterate || name === :unsafe_convert || name === :cconvert
492-
cost_threshold += 4*default
493-
end
454+
return false
455+
end
456+
457+
function compute_inlineability(ir::IRCode, opt::OptimizationState, @nospecialize(result),
458+
declared_inlineability::Bool)
459+
(; def, specTypes) = opt.linfo
460+
force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)
461+
force_noinline && return false
462+
union_penalties = false
463+
sig = unwrap_unionall(specTypes)
464+
if isa(sig, DataType) && sig.name === Tuple.name
465+
for P in sig.parameters
466+
P = unwrap_unionall(P)
467+
if isa(P, Union)
468+
union_penalties = true
469+
break
494470
end
495-
src.inlineable = inline_worthy(ir, params, union_penalties, cost_threshold)
496471
end
472+
else
473+
return false
474+
end
475+
if !declared_inlineability && result === Bottom
476+
return false
477+
end
478+
isa(def, Method) || return declared_inlineability
479+
if declared_inlineability && isdispatchtuple(specTypes)
480+
# obey @inline declaration if a dispatch barrier would not help
481+
return true
482+
end
483+
# compute the cost (size) of inlining this code
484+
params = opt.inlining.params
485+
cost_threshold = default = params.inline_cost_threshold
486+
if result Tuple && !isconcretetype(widenconst(result))
487+
cost_threshold += params.inline_tupleret_bonus
497488
end
489+
# if the method is declared as `@inline`, increase the cost threshold 20x
490+
if declared_inlineability
491+
cost_threshold += 19*default
492+
end
493+
# a few functions get special treatment
494+
if def.module === _topmod(def.module)
495+
name = def.name
496+
if name === :iterate || name === :unsafe_convert || name === :cconvert
497+
cost_threshold += 4*default
498+
end
499+
end
500+
return inline_worthy(ir, params, union_penalties, cost_threshold)
501+
end
498502

499-
return analyzed
503+
function transform_optresult_for_cache(::AbstractInterpreter,
504+
opt::OptimizationState, ir::IRCode, @nospecialize(newresult))
505+
if isa(newresult, ConstAPI)
506+
# use constant calling convention
507+
# Do not emit `jl_fptr_const_return` if coverage is enabled
508+
# so that we don't need to add coverage support
509+
# to the `jl_call_method_internal` fast path
510+
# Still set pure flag to make sure `inference` tests pass
511+
# and to possibly enable more optimization in the future
512+
513+
# XXX: The work in ir_to_codeinf! is essentially wasted. The only reason
514+
# we're doing it is so that code_llvm can return the code
515+
# for the `return ...::Const` (which never runs anyway). We should do this
516+
# as a post processing step instead.
517+
ir_to_codeinf!(opt, ir)
518+
return newresult
519+
end
520+
return ir_to_codeinf!(opt, ir)
500521
end
501522

502523
# run the optimization work
503-
function optimize(interp::AbstractInterpreter, opt::OptimizationState,
504-
params::OptimizationParams, caller::InferenceResult)
524+
function optimize!(interp::AbstractInterpreter,
525+
opt::OptimizationState, caller::InferenceResult)
505526
@timeit "optimizer" ir = run_passes(opt.src, opt, caller)
506-
return finish(interp, opt, params, ir, caller)
527+
@timeit "finish!" finish!(interp, opt, ir, caller)
507528
end
508529

509530
using .EscapeAnalysis

0 commit comments

Comments
 (0)