Skip to content

Commit 7569d68

Browse files
committed
AbstractInterpreter: refactor the lifetimes of OptimizationState and IRCode
This commit limits the lifetimes of `OptimizationState` and `IRCode` for better dataflow clarity. It also avoids duplicated calls of `ir_to_codeinf!`. Note that external `AbstractInterpreter`s can still extend their lifetimes to cache additional information, as described by this newly added documentation of `finish!`: > finish!(interp::AbstractInterpreter, > opt::OptimizationState, ir::IRCode, caller::InferenceResult) > > Runs post-Julia-level optimization process and caches information for later uses: > - computes "purity" (i.e. side-effect-freeness) of the optimized frame > - computes inlining cost and caches the inlineability in `opt.src.inlineable` > - stores the result of optimization in `caller.src` > * by default, `caller.src` will be an optimized `CodeInfo` object transformed from `ir` > * in a case when this frame has been proven pure, `ConstAPI` object wrapping the constant > value will be kept in `caller.src` instead, so that the runtime system will use > the constant calling convention > > !!! note > The lifetimes of `opt` and `ir` end by the end of this process. > Still external `AbstractInterpreter` can override this method as necessary to cache them. > Note that `transform_result_for_cache` should be overloaded also in such cases, > otherwise the default `transform_result_for_cache` implementation will discard any information > other than `CodeInfo`, `Vector{UInt8}` or `ConstAPI`. This commit also adds a new overload `infresult_iterator` so that external interpreters can tweak the post-processing behavior of `_typeinf`. In particular, this change is motivated by the needs of JET, whose post-optimization processing needs references to `InferenceState`.
1 parent a7beb93 commit 7569d68

File tree

2 files changed

+77
-64
lines changed

2 files changed

+77
-64
lines changed

base/compiler/optimize.jl

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,9 @@ end
5454

5555
include("compiler/ssair/driver.jl")
5656

57-
mutable struct OptimizationState
57+
struct OptimizationState
5858
linfo::MethodInstance
5959
src::CodeInfo
60-
ir::Union{Nothing, IRCode}
6160
stmt_info::Vector{Any}
6261
mod::Module
6362
sptypes::Vector{Any} # static parameters
@@ -69,8 +68,7 @@ mutable struct OptimizationState
6968
EdgeTracker(s_edges, frame.valid_worlds),
7069
WorldView(code_cache(interp), frame.world),
7170
interp)
72-
return new(frame.linfo,
73-
frame.src, nothing, frame.stmt_info, frame.mod,
71+
return new(frame.linfo, frame.src, frame.stmt_info, frame.mod,
7472
frame.sptypes, frame.slottypes, inlining)
7573
end
7674
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
@@ -97,8 +95,7 @@ mutable struct OptimizationState
9795
nothing,
9896
WorldView(code_cache(interp), get_world_counter()),
9997
interp)
100-
return new(linfo,
101-
src, nothing, stmt_info, mod,
98+
return new(linfo, src, stmt_info, mod,
10299
sptypes_from_meth_instance(linfo), slottypes, inlining)
103100
end
104101
end
@@ -109,11 +106,10 @@ function OptimizationState(linfo::MethodInstance, params::OptimizationParams, in
109106
return OptimizationState(linfo, src, params, interp)
110107
end
111108

112-
function ir_to_codeinf!(opt::OptimizationState)
109+
function ir_to_codeinf!(opt::OptimizationState, ir::IRCode)
113110
(; linfo, src) = opt
114111
optdef = linfo.def
115-
replace_code_newstyle!(src, opt.ir::IRCode, isa(optdef, Method) ? Int(optdef.nargs) : 0)
116-
opt.ir = nothing
112+
replace_code_newstyle!(src, ir, isa(optdef, Method) ? Int(optdef.nargs) : 0)
117113
widen_all_consts!(src)
118114
src.inferred = true
119115
# finish updating the result struct
@@ -383,18 +379,27 @@ struct ConstAPI
383379
end
384380

385381
"""
386-
finish(interp::AbstractInterpreter, opt::OptimizationState,
387-
params::OptimizationParams, ir::IRCode, caller::InferenceResult) -> analyzed::Union{Nothing,ConstAPI}
388-
389-
Post process information derived by Julia-level optimizations for later uses:
390-
- computes "purity", i.e. side-effect-freeness
391-
- computes inlining cost
392-
393-
In a case when the purity is proven, `finish` can return `ConstAPI` object wrapping the constant
394-
value so that the runtime system will use the constant calling convention for the method calls.
382+
finish!(interp::AbstractInterpreter,
383+
opt::OptimizationState, ir::IRCode, caller::InferenceResult)
384+
385+
Runs post-Julia-level optimization process and caches information for later uses:
386+
- computes "purity" (i.e. side-effect-freeness) of the optimized frame
387+
- computes inlining cost and caches the inlineability in `opt.src.inlineable`
388+
- stores the result of optimization in `caller.src`
389+
* by default, `caller.src` will be an optimized `CodeInfo` object transformed from `ir`
390+
* in a case when this frame has been proven pure, `ConstAPI` object wrapping the constant
391+
value will be kept in `caller.src` instead, so that the runtime system will use
392+
the constant calling convention
393+
394+
!!! note
395+
The lifetimes of `opt` and `ir` end by the end of this process.
396+
External `AbstractInterpreter`s can still override this method as necessary to cache them.
397+
Note that `transform_result_for_cache` should be overloaded also in such cases,
398+
otherwise the default `transform_result_for_cache` implementation will discard any information
399+
other than `CodeInfo`, `Vector{UInt8}` or `ConstAPI`.
395400
"""
396-
function finish(interp::AbstractInterpreter, opt::OptimizationState,
397-
params::OptimizationParams, ir::IRCode, caller::InferenceResult)
401+
function finish!(interp::AbstractInterpreter,
402+
opt::OptimizationState, ir::IRCode, caller::InferenceResult)
398403
(; src, linfo) = opt
399404
(; def, specTypes) = linfo
400405

@@ -452,8 +457,6 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
452457
end
453458
end
454459

455-
opt.ir = ir
456-
457460
# determine and cache inlineability
458461
union_penalties = false
459462
if !force_noinline
@@ -480,6 +483,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
480483
# obey @inline declaration if a dispatch barrier would not help
481484
else
482485
# compute the cost (size) of inlining this code
486+
params = opt.inlining.params
483487
cost_threshold = default = params.inline_cost_threshold
484488
if result ⊑ Tuple && !isconcretetype(widenconst(result))
485489
cost_threshold += params.inline_tupleret_bonus
@@ -499,14 +503,27 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
499503
end
500504
end
501505

502-
return analyzed
506+
caller.valid_worlds = (opt.inlining.et::EdgeTracker).valid_worlds[]
507+
508+
if isa(analyzed, ConstAPI)
509+
# XXX: The work in ir_to_codeinf! is essentially wasted. The only reason
510+
# we're doing it is so that code_llvm can return the code
511+
# for the `return ...::Const` (which never runs anyway). We should do this
512+
# as a post processing step instead.
513+
ir_to_codeinf!(opt, ir)
514+
caller.src = analyzed
515+
else
516+
caller.src = ir_to_codeinf!(opt, ir)
517+
end
518+
519+
return nothing
503520
end
504521

505522
# run the optimization work
506-
function optimize(interp::AbstractInterpreter, opt::OptimizationState,
507-
params::OptimizationParams, caller::InferenceResult)
523+
function optimize!(interp::AbstractInterpreter,
524+
opt::OptimizationState, caller::InferenceResult)
508525
@timeit "optimizer" ir = run_passes(opt.src, opt)
509-
return finish(interp, opt, params, ir, caller)
526+
@timeit "finish!" finish!(interp, opt, ir, caller)
510527
end
511528

512529
function run_passes(ci::CodeInfo, sv::OptimizationState)

base/compiler/typeinfer.jl

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -210,23 +210,20 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
210210
end
211211
end
212212

213-
function finish!(interp::AbstractInterpreter, caller::InferenceResult)
214-
# If we didn't transform the src for caching, we may have to transform
215-
# it anyway for users like typeinf_ext. Do that here.
216-
opt = caller.src
217-
if opt isa OptimizationState # implies `may_optimize(interp) === true`
218-
if opt.ir !== nothing
219-
caller.src = ir_to_codeinf!(opt)
220-
end
221-
end
222-
return caller.src
223-
end
224-
225213
function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
226214
typeinf_nocycle(interp, frame) || return false # frame is now part of a higher cycle
227215
# with no active ip's, frame is done
228216
frames = frame.callers_in_cycle
229217
isempty(frames) && push!(frames, frame)
218+
finish_infstates!(interp, frames)
219+
# collect results for the new expanded frame
220+
results = infresult_iterator(interp, frames)
221+
optimize!(interp, results)
222+
cache_results!(interp, results)
223+
return true
224+
end
225+
226+
function finish_infstates!(interp::AbstractInterpreter, frames::Vector{InferenceState})
230227
valid_worlds = WorldRange()
231228
for caller in frames
232229
@assert !(caller.dont_work_on_me)
@@ -240,29 +237,35 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
240237
# finalize and record the linfo result
241238
caller.inferred = true
242239
end
243-
# collect results for the new expanded frame
244-
results = Tuple{InferenceResult, Vector{Any}, Bool}[
245-
( frames[i].result,
246-
frames[i].stmt_edges[1]::Vector{Any},
247-
frames[i].cached )
248-
for i in 1:length(frames) ]
249-
empty!(frames)
250-
for (caller, _, _) in results
240+
end
241+
242+
struct InfResultInfo
243+
caller::InferenceResult
244+
edges::Vector{Any}
245+
cached::Bool
246+
end
247+
248+
# returns the iterator that `optimize!` and `cache_results!` work on
249+
function infresult_iterator(_::AbstractInterpreter, frames::Vector{InferenceState})
250+
results = InfResultInfo[ InfResultInfo(
251+
frames[i].result,
252+
frames[i].stmt_edges[1]::Vector{Any},
253+
frames[i].cached ) for i in 1:length(frames) ]
254+
empty!(frames) # discard `InferenceState` now
255+
return results
256+
end
257+
258+
function optimize!(interp::AbstractInterpreter, results::Vector{InfResultInfo})
259+
for (; caller) in results
251260
opt = caller.src
252261
if opt isa OptimizationState # implies `may_optimize(interp) === true`
253-
analyzed = optimize(interp, opt, OptimizationParams(interp), caller)
254-
if isa(analyzed, ConstAPI)
255-
# XXX: The work in ir_to_codeinf! is essentially wasted. The only reason
256-
# we're doing it is so that code_llvm can return the code
257-
# for the `return ...::Const` (which never runs anyway). We should do this
258-
# as a post processing step instead.
259-
ir_to_codeinf!(opt)
260-
caller.src = analyzed
261-
end
262-
caller.valid_worlds = (opt.inlining.et::EdgeTracker).valid_worlds[]
262+
optimize!(interp, opt, caller)
263263
end
264264
end
265-
for (caller, edges, cached) in results
265+
end
266+
267+
function cache_results!(interp::AbstractInterpreter, results::Vector{InfResultInfo})
268+
for (; caller, edges, cached) in results
266269
valid_worlds = caller.valid_worlds
267270
if last(valid_worlds) >= get_world_counter()
268271
# if we aren't cached, we don't need this edge
@@ -272,9 +275,7 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
272275
if cached
273276
cache_result!(interp, caller)
274277
end
275-
finish!(interp, caller)
276278
end
277-
return true
278279
end
279280

280281
function CodeInstance(result::InferenceResult, @nospecialize(inferred_result),
@@ -349,11 +350,6 @@ end
349350

350351
function transform_result_for_cache(interp::AbstractInterpreter, linfo::MethodInstance,
351352
valid_worlds::WorldRange, @nospecialize(inferred_result))
352-
# If we decided not to optimize, drop the OptimizationState now.
353-
# External interpreters can override as necessary to cache additional information
354-
if inferred_result isa OptimizationState
355-
inferred_result = ir_to_codeinf!(inferred_result)
356-
end
357353
if inferred_result isa CodeInfo
358354
inferred_result.min_world = first(valid_worlds)
359355
inferred_result.max_world = last(valid_worlds)

0 commit comments

Comments
 (0)