Skip to content

Commit 53799df

Browse files
committed
make IRInterpretationState mutable
1 parent 8dd93f9 commit 53799df

File tree

3 files changed

+56
-59
lines changed

3 files changed

+56
-59
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -989,8 +989,9 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
989989
mi_cache = WorldView(code_cache(interp), world)
990990
code = get(mi_cache, mi, nothing)
991991
if code !== nothing
992-
irsv = IRInterpretationState(interp, code, mi, arginfo.argtypes, world, sv)
992+
irsv = IRInterpretationState(interp, code, mi, arginfo.argtypes, world)
993993
if irsv !== nothing
994+
irsv.parent = sv
994995
irinterp = isa(interp, NativeInterpreter) ? NativeInterpreter(interp; irinterp=true) : interp
995996
rt, nothrow = ir_abstract_constant_propagation(irinterp, irsv)
996997
@assert !(rt isa Conditional || rt isa MustAlias) "invalid lattice element returned from irinterp"

base/compiler/inferencestate.jl

Lines changed: 46 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -634,52 +634,47 @@ end
634634
# =====================
635635

636636
# TODO add `result::InferenceResult` and put the irinterp result into the inference cache?
637-
struct IRInterpretationState
638-
method_info::MethodInfo
639-
ir::IRCode
640-
mi::MethodInstance
641-
world::UInt
642-
curridx::RefValue{Int}
643-
argtypes_refined::Vector{Bool}
644-
sptypes::Vector{VarState}
645-
tpdum::TwoPhaseDefUseMap
646-
ssa_refined::BitSet
647-
lazydomtree::LazyDomtree
648-
valid_worlds::RefValue{WorldRange}
649-
edges::Vector{Any}
650-
parent # ::AbsIntState
651-
end
652-
653-
# AbsIntState
654-
# ===========
655-
656-
const AbsIntState = Union{InferenceState,IRInterpretationState}
637+
mutable struct IRInterpretationState
638+
const method_info::MethodInfo
639+
const ir::IRCode
640+
const mi::MethodInstance
641+
const world::UInt
642+
curridx::Int
643+
const argtypes_refined::Vector{Bool}
644+
const sptypes::Vector{VarState}
645+
const tpdum::TwoPhaseDefUseMap
646+
const ssa_refined::BitSet
647+
const lazydomtree::LazyDomtree
648+
valid_worlds::WorldRange
649+
const edges::Vector{Any}
650+
parent # ::Union{Nothing,AbsIntState}
657651

658-
function IRInterpretationState(interp::AbstractInterpreter,
659-
method_info::MethodInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
660-
world::UInt, min_world::UInt, max_world::UInt, parent::AbsIntState)
661-
curridx = RefValue(1)
662-
given_argtypes = Vector{Any}(undef, length(argtypes))
663-
for i = 1:length(given_argtypes)
664-
given_argtypes[i] = widenslotwrapper(argtypes[i])
652+
function IRInterpretationState(interp::AbstractInterpreter,
653+
method_info::MethodInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
654+
world::UInt, min_world::UInt, max_world::UInt)
655+
curridx = 1
656+
given_argtypes = Vector{Any}(undef, length(argtypes))
657+
for i = 1:length(given_argtypes)
658+
given_argtypes[i] = widenslotwrapper(argtypes[i])
659+
end
660+
given_argtypes = va_process_argtypes(optimizer_lattice(interp), given_argtypes, mi)
661+
argtypes_refined = Bool[!(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i])
662+
for i = 1:length(given_argtypes)]
663+
empty!(ir.argtypes)
664+
append!(ir.argtypes, given_argtypes)
665+
tpdum = TwoPhaseDefUseMap(length(ir.stmts))
666+
ssa_refined = BitSet()
667+
lazydomtree = LazyDomtree(ir)
668+
valid_worlds = WorldRange(min_world, max_world == typemax(UInt) ? get_world_counter() : max_world)
669+
edges = Any[]
670+
parent = nothing
671+
return new(method_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum,
672+
ssa_refined, lazydomtree, valid_worlds, edges, parent)
665673
end
666-
given_argtypes = va_process_argtypes(optimizer_lattice(interp), given_argtypes, mi)
667-
argtypes_refined = Bool[!(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i])
668-
for i = 1:length(given_argtypes)]
669-
empty!(ir.argtypes)
670-
append!(ir.argtypes, given_argtypes)
671-
tpdum = TwoPhaseDefUseMap(length(ir.stmts))
672-
ssa_refined = BitSet()
673-
lazydomtree = LazyDomtree(ir)
674-
valid_worlds = RefValue(WorldRange(min_world, max_world == typemax(UInt) ? get_world_counter() : max_world))
675-
edges = Any[]
676-
return IRInterpretationState(method_info, ir, mi, world, curridx, argtypes_refined,
677-
ir.sptypes, tpdum, ssa_refined, lazydomtree,
678-
valid_worlds, edges, parent)
679674
end
680675

681676
function IRInterpretationState(interp::AbstractInterpreter,
682-
code::CodeInstance, mi::MethodInstance, argtypes::Vector{Any}, world::UInt, parent::AbsIntState)
677+
code::CodeInstance, mi::MethodInstance, argtypes::Vector{Any}, world::UInt)
683678
@assert code.def === mi
684679
src = @atomic :monotonic code.inferred
685680
if isa(src, Vector{UInt8})
@@ -690,9 +685,14 @@ function IRInterpretationState(interp::AbstractInterpreter,
690685
method_info = MethodInfo(src)
691686
ir = inflate_ir(src, mi)
692687
return IRInterpretationState(interp, method_info, ir, mi, argtypes, world,
693-
src.min_world, src.max_world, parent)
688+
src.min_world, src.max_world)
694689
end
695690

691+
# AbsIntState
692+
# ===========
693+
694+
const AbsIntState = Union{InferenceState,IRInterpretationState}
695+
696696
frame_instance(sv::InferenceState) = sv.linfo
697697
frame_instance(sv::IRInterpretationState) = sv.mi
698698

@@ -736,16 +736,11 @@ has_conditional(𝕃::AbstractLattice, ::InferenceState) = has_conditional(𝕃)
736736
has_conditional(::AbstractLattice, ::IRInterpretationState) = false
737737

738738
# work towards converging the valid age range for sv
739-
function update_valid_age!(sv::InferenceState, valid_worlds::WorldRange)
739+
function update_valid_age!(sv::AbsIntState, valid_worlds::WorldRange)
740740
valid_worlds = sv.valid_worlds = intersect(valid_worlds, sv.valid_worlds)
741741
@assert sv.world in valid_worlds "invalid age range update"
742742
return valid_worlds
743743
end
744-
function update_valid_age!(irsv::IRInterpretationState, valid_worlds::WorldRange)
745-
valid_worlds = irsv.valid_worlds[] = intersect(valid_worlds, irsv.valid_worlds[])
746-
@assert irsv.world in valid_worlds "invalid age range update"
747-
return valid_worlds
748-
end
749744

750745
"""
751746
AbsIntStackUnwind(sv::AbsIntState)
@@ -799,13 +794,13 @@ function add_mt_backedge!(irsv::IRInterpretationState, mt::MethodTable, @nospeci
799794
end
800795

801796
get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc]
802-
get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx[]][:flag]
797+
get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx][:flag]
803798

804799
add_curr_ssaflag!(sv::InferenceState, flag::UInt8) = sv.src.ssaflags[sv.currpc] |= flag
805-
add_curr_ssaflag!(sv::IRInterpretationState, flag::UInt8) = sv.ir.stmts[sv.curridx[]][:flag] |= flag
800+
add_curr_ssaflag!(sv::IRInterpretationState, flag::UInt8) = sv.ir.stmts[sv.curridx][:flag] |= flag
806801

807802
sub_curr_ssaflag!(sv::InferenceState, flag::UInt8) = sv.src.ssaflags[sv.currpc] &= ~flag
808-
sub_curr_ssaflag!(sv::IRInterpretationState, flag::UInt8) = sv.ir.stmts[sv.curridx[]][:flag] &= ~flag
803+
sub_curr_ssaflag!(sv::IRInterpretationState, flag::UInt8) = sv.ir.stmts[sv.curridx][:flag] &= ~flag
809804

810805
merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects) =
811806
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)

base/compiler/ssair/irinterp.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ function concrete_eval_invoke(interp::AbstractInterpreter,
2626
if is_constprop_edge_recursed(mi, irsv)
2727
return Pair{Any,Bool}(nothing, is_nothrow(effects))
2828
end
29-
newirsv = IRInterpretationState(interp, code, mi, argtypes, world, irsv)
29+
newirsv = IRInterpretationState(interp, code, mi, argtypes, world)
3030
if newirsv !== nothing
31+
newirsv.parent = irsv
3132
return _ir_abstract_constant_propagation(interp, newirsv)
3233
end
3334
return Pair{Any,Bool}(nothing, is_nothrow(effects))
@@ -51,7 +52,7 @@ end
5152
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRInterpretationState)
5253
si = StmtInfo(true) # TODO better job here?
5354
(; rt, effects, info) = abstract_call(interp, arginfo, si, irsv)
54-
irsv.ir.stmts[irsv.curridx[]][:info] = info
55+
irsv.ir.stmts[irsv.curridx][:info] = info
5556
return RTEffects(rt, effects)
5657
end
5758

@@ -199,7 +200,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
199200
stmts = bbs[bb].stmts
200201
lstmt = last(stmts)
201202
for idx = stmts
202-
irsv.curridx[] = idx
203+
irsv.curridx = idx
203204
inst = ir.stmts[idx][:inst]
204205
typ = ir.stmts[idx][:type]
205206
any_refined = false
@@ -247,7 +248,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
247248
stmts = bbs[bb].stmts
248249
lstmt = last(stmts)
249250
for idx = stmts
250-
irsv.curridx[] = idx
251+
irsv.curridx = idx
251252
inst = ir.stmts[idx][:inst]
252253
for ur in userefs(inst)
253254
val = ur[]
@@ -271,7 +272,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
271272
stmts = bbs[bb].stmts
272273
lstmt = last(stmts)
273274
for idx = stmts
274-
irsv.curridx[] = idx
275+
irsv.curridx = idx
275276
inst = ir.stmts[idx][:inst]
276277
for ur in userefs(inst)
277278
val = ur[]
@@ -291,7 +292,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
291292
end
292293
while !isempty(stmt_ip)
293294
idx = popfirst!(stmt_ip)
294-
irsv.curridx[] = idx
295+
irsv.curridx = idx
295296
inst = ir.stmts[idx][:inst]
296297
typ = ir.stmts[idx][:type]
297298
if reprocess_instruction!(interp,
@@ -323,7 +324,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
323324
end
324325
end
325326

326-
if last(irsv.valid_worlds[]) >= get_world_counter()
327+
if last(irsv.valid_worlds) >= get_world_counter()
327328
# if we aren't cached, we don't need this edge
328329
# but our caller might, so let's just make it anyways
329330
store_backedges(frame_instance(irsv), irsv.edges)

0 commit comments

Comments
 (0)