Skip to content

Commit 1e20c9c

Browse files
authored
sroa: Handle looking through chains of KeyValue instances (#52369)
Addresses an outstanding todo from the KeyValue PR and allows (once all the PRs are merged), optimization when multiple ScopedValues are used `with(a=>1, b=>2)`, etc. To facilitate this, in addition to the sroa adjustment, the ScopedValue code is adjusted to unroll the PersistentDict creation so that the optimizer can see the full chain (we do not support loops in the optimizer).
1 parent 727142a commit 1e20c9c

File tree

3 files changed

+48
-34
lines changed

3 files changed

+48
-34
lines changed

base/compiler/ssair/passes.jl

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -866,36 +866,38 @@ function lift_leaves_keyvalue(compact::IncrementalCompact, @nospecialize(key),
866866
for i = 1:length(leaves)
867867
leaf = leaves[i]
868868
cache_key = leaf
869-
if isa(leaf, AnySSAValue)
870-
(def, leaf) = walk_to_def(compact, leaf)
871-
if is_known_invoke_or_call(def, Core.OptimizedGenerics.KeyValue.set, compact)
872-
@assert isexpr(def, :invoke)
873-
if length(def.args) in (5, 6)
874-
collection = def.args[end-2]
875-
set_key = def.args[end-1]
876-
set_val_idx = length(def.args)
877-
elseif length(def.args) == 4
878-
collection = def.args[end-1]
879-
# Key is deleted
880-
# TODO: Model this
881-
return nothing
882-
elseif length(def.args) == 3
883-
collection = def.args[end]
884-
# The whole collection is deleted
885-
# TODO: Model this
886-
return nothing
887-
else
888-
return nothing
889-
end
890-
if set_key === key || (egal_tfunc(𝕃ₒ, argextype(key, compact), argextype(set_key, compact)) == Const(true))
891-
lift_arg!(compact, leaf, cache_key, def, set_val_idx, lifted_leaves)
869+
while true
870+
if isa(leaf, AnySSAValue)
871+
(def, leaf) = walk_to_def(compact, leaf)
872+
if is_known_invoke_or_call(def, Core.OptimizedGenerics.KeyValue.set, compact)
873+
@assert isexpr(def, :invoke)
874+
if length(def.args) in (5, 6)
875+
collection = def.args[end-2]
876+
set_key = def.args[end-1]
877+
set_val_idx = length(def.args)
878+
elseif length(def.args) == 4
879+
collection = def.args[end-1]
880+
# Key is deleted
881+
# TODO: Model this
882+
return nothing
883+
elseif length(def.args) == 3
884+
collection = def.args[end]
885+
# The whole collection is deleted
886+
# TODO: Model this
887+
return nothing
888+
else
889+
return nothing
890+
end
891+
if set_key === key || (egal_tfunc(𝕃ₒ, argextype(key, compact), argextype(set_key, compact)) == Const(true))
892+
lift_arg!(compact, leaf, cache_key, def, set_val_idx, lifted_leaves)
893+
break
894+
end
895+
leaf = collection
892896
continue
893897
end
894-
# TODO: Continue walking the chain
895-
return nothing
896898
end
899+
return nothing
897900
end
898-
return nothing
899901
end
900902
return lifted_leaves
901903
end
@@ -919,11 +921,11 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
919921
(lifted_val, nest) = perform_lifting!(compact,
920922
visited_philikes, key, result_t, lifted_leaves, collection, nothing)
921923

922-
compact[idx] = lifted_val === nothing ? nothing : Expr(:call, Core.tuple, lifted_val.val)
924+
compact[idx] = lifted_val === nothing ? nothing : Expr(:call, GlobalRef(Core, :tuple), lifted_val.val)
923925
finish_phi_nest!(compact, nest)
924926
if lifted_val !== nothing
925-
if !(𝕃ₒ, compact[SSAValue(idx)][:type], result_t)
926-
compact[SSAValue(idx)][:flag] |= IR_FLAG_REFINED
927+
if !(𝕃ₒ, compact[SSAValue(idx)][:type], tuple_tfunc(𝕃ₒ, Any[result_t]))
928+
add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
927929
end
928930
end
929931

base/scopedvalues.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,14 @@ function Scope(parent::Union{Nothing, Scope}, key::ScopedValue{T}, value) where
7070
return Scope(ScopeStorage(parent.values, key=>val))
7171
end
7272

73-
function Scope(scope, pairs::Pair{<:ScopedValue}...)
74-
for pair in pairs
75-
scope = Scope(scope, pair...)
76-
end
77-
return scope::Scope
73+
function Scope(scope, pair::Pair{<:ScopedValue})
74+
return Scope(scope, pair...)
75+
end
76+
77+
function Scope(scope, pair1::Pair{<:ScopedValue}, pair2::Pair{<:ScopedValue}, pairs::Pair{<:ScopedValue}...)
78+
# Unroll this loop through recursion to make sure that
79+
# our compiler optimization support works
80+
return Scope(Scope(scope, pair1...), pair2, pairs...)
7881
end
7982
Scope(::Nothing) = nothing
8083

test/compiler/irpasses.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1549,11 +1549,20 @@ function persistent_dict_elim()
15491549
a = Base.PersistentDict(:a => 1)
15501550
return a[:a]
15511551
end
1552+
15521553
# Ideally we would be able to fully eliminate this,
15531554
# but currently this would require an extra round of constprop
15541555
@test_broken fully_eliminated(persistent_dict_elim)
15551556
@test code_typed(persistent_dict_elim)[1][1].code[end] == Core.ReturnNode(1)
15561557

1558+
function persistent_dict_elim_multiple()
1559+
a = Base.PersistentDict(:a => 1)
1560+
b = Base.PersistentDict(a, :b => 2)
1561+
return b[:a]
1562+
end
1563+
@test_broken fully_eliminated(persistent_dict_elim_multiple)
1564+
@test code_typed(persistent_dict_elim_multiple)[1][1].code[end] == Core.ReturnNode(1)
1565+
15571566
# Test CFG simplify with try/catch blocks
15581567
let code = Any[
15591568
# Block 1

0 commit comments

Comments
 (0)