Skip to content

Commit c3da8e6

Browse files
author
Ian Atol
committed
Allow inlining methods with unmatched type parameters
1 parent 3cff21e commit c3da8e6

File tree

8 files changed

+255
-73
lines changed

8 files changed

+255
-73
lines changed

base/compiler/ssair/inlining.jl

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -359,11 +359,17 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
359359
boundscheck = :off
360360
end
361361
end
362+
if !validate_sparams(sparam_vals)
363+
sparam_vals = insert_node_here!(compact,
364+
effect_free(NewInstruction(Expr(:call, Core._compute_sparams, item.mi.def, argexprs...), SimpleVector, topline)))
365+
end
362366
# If the iterator already moved on to the next basic block,
363367
# temporarily re-open in again.
364368
local return_value
365369
sig = def.sig
366370
# Special case inlining that maintains the current basic block if there's only one BB in the target
371+
new_new_offset = length(compact.new_new_nodes)
372+
late_fixup_offset = length(compact.late_fixup)
367373
if spec.linear_inline_eligible
368374
#compact[idx] = nothing
369375
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
@@ -372,7 +378,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
372378
# face of rename_arguments! mutating in place - should figure out
373379
# something better eventually.
374380
inline_compact[idx′] = nothing
375-
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
381+
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
376382
if isa(stmt′, ReturnNode)
377383
val = stmt′.val
378384
return_value = SSAValue(idx′)
@@ -383,7 +389,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
383389
end
384390
inline_compact[idx′] = stmt′
385391
end
386-
just_fixup!(inline_compact)
392+
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
387393
compact.result_idx = inline_compact.result_idx
388394
else
389395
bb_offset, post_bb_id = popfirst!(todo_bbs)
@@ -397,7 +403,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
397403
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
398404
for ((_, idx′), stmt′) in inline_compact
399405
inline_compact[idx′] = nothing
400-
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
406+
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
401407
if isa(stmt′, ReturnNode)
402408
if isdefined(stmt′, :val)
403409
val = stmt′.val
@@ -428,7 +434,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
428434
end
429435
inline_compact[idx′] = stmt′
430436
end
431-
just_fixup!(inline_compact)
437+
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
432438
compact.result_idx = inline_compact.result_idx
433439
compact.active_result_bb = inline_compact.active_result_bb
434440
if length(pn.edges) == 1
@@ -896,8 +902,7 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
896902
end
897903
end
898904

899-
# Bail out if any static parameters are left as TypeVar
900-
validate_sparams(match.sparams) || return nothing
905+
#validate_sparams(match.sparams) || return nothing
901906

902907
et = state.et
903908

@@ -1104,7 +1109,7 @@ function inline_invoke!(
11041109
argtypes = invoke_rewrite(sig.argtypes)
11051110
if isa(result, ConstPropResult)
11061111
(; mi) = item = InliningTodo(result.result, argtypes)
1107-
validate_sparams(mi.sparam_vals) || return nothing
1112+
# validate_sparams(mi.sparam_vals) || return nothing
11081113
if argtypes_to_type(argtypes) <: mi.def.sig
11091114
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
11101115
handle_single_case!(ir, idx, stmt, item, todo, state.params, true)
@@ -1327,7 +1332,7 @@ function handle_const_prop_result!(
13271332
(; mi) = item = InliningTodo(result.result, argtypes)
13281333
spec_types = mi.specTypes
13291334
allow_abstract || isdispatchtuple(spec_types) || return false
1330-
validate_sparams(mi.sparam_vals) || return false
1335+
#validate_sparams(mi.sparam_vals) || return false
13311336
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
13321337
item === nothing && return false
13331338
push!(cases, InliningCase(spec_types, item))
@@ -1365,7 +1370,6 @@ function handle_const_opaque_closure_call!(
13651370
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
13661371
item = InliningTodo(result.result, sig.argtypes)
13671372
isdispatchtuple(item.mi.specTypes) || return
1368-
validate_sparams(item.mi.sparam_vals) || return
13691373
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
13701374
handle_single_case!(ir, idx, stmt, item, todo, state.params)
13711375
return nothing
@@ -1545,38 +1549,49 @@ function late_inline_special_case!(
15451549
end
15461550

15471551
function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{Any},
1548-
@nospecialize(spsig), spvals::SimpleVector,
1552+
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue},
15491553
linetable_offset::Int32, boundscheck::Symbol, compact::IncrementalCompact)
15501554
compact.result[idx][:flag] &= ~IR_FLAG_INBOUNDS
15511555
compact.result[idx][:line] += linetable_offset
1552-
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck)
1556+
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck, compact, idx)
15531557
end
15541558

15551559
function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
1556-
@nospecialize(spsig), spvals::SimpleVector, boundscheck::Symbol)
1560+
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, boundscheck::Symbol,
1561+
compact::IncrementalCompact, idx::Int)
15571562
if isa(val, Argument)
15581563
return arg_replacements[val.n]
15591564
end
15601565
if isa(val, Expr)
15611566
e = val::Expr
15621567
head = e.head
15631568
if head === :static_parameter
1564-
return quoted(spvals[e.args[1]::Int])
1569+
if isa(spvals, SimpleVector)
1570+
return quoted(spvals[e.args[1]::Int])
1571+
else
1572+
ret = insert_node!(compact, SSAValue(idx),
1573+
effect_free(NewInstruction(Expr(:call, Core._svec_ref, false, spvals, e.args[1]), Any)))
1574+
return ret
1575+
end
15651576
elseif head === :cfunction
1566-
@assert !isa(spsig, UnionAll) || !isempty(spvals)
1567-
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
1568-
e.args[4] = svec(Any[
1569-
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
1570-
for argt in e.args[4]::SimpleVector ]...)
1577+
if isa(spvals, SimpleVector)
1578+
@assert !isa(spsig, UnionAll) || !isempty(spvals)
1579+
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
1580+
e.args[4] = svec(Any[
1581+
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
1582+
for argt in e.args[4]::SimpleVector ]...)
1583+
end
15711584
elseif head === :foreigncall
1572-
@assert !isa(spsig, UnionAll) || !isempty(spvals)
1573-
for i = 1:length(e.args)
1574-
if i == 2
1575-
e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals)
1576-
elseif i == 3
1577-
e.args[3] = svec(Any[
1578-
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
1579-
for argt in e.args[3]::SimpleVector ]...)
1585+
if isa(spvals, SimpleVector)
1586+
@assert !isa(spsig, UnionAll) || !isempty(spvals)
1587+
for i = 1:length(e.args)
1588+
if i == 2
1589+
e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals)
1590+
elseif i == 3
1591+
e.args[3] = svec(Any[
1592+
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
1593+
for argt in e.args[3]::SimpleVector ]...)
1594+
end
15801595
end
15811596
end
15821597
elseif head === :boundscheck
@@ -1591,7 +1606,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
15911606
end
15921607
urs = userefs(val)
15931608
for op in urs
1594-
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck)
1609+
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck, compact, idx)
15951610
end
15961611
return urs[]
15971612
end

base/compiler/ssair/ir.jl

Lines changed: 77 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -631,16 +631,13 @@ mutable struct IncrementalCompact
631631
perm = my_sortperm(Int[code.new_nodes.info[i].pos for i in 1:length(code.new_nodes)])
632632
new_len = length(code.stmts) + length(code.new_nodes)
633633
ssa_rename = Any[SSAValue(i) for i = 1:new_len]
634-
new_new_used_ssas = Vector{Int}()
635-
late_fixup = Vector{Int}()
636634
bb_rename = Vector{Int}()
637-
new_new_nodes = NewNodeStream()
638635
pending_nodes = NewNodeStream()
639636
pending_perm = Int[]
640637
return new(code, parent.result,
641638
parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas,
642-
late_fixup, perm, 1,
643-
new_new_nodes, new_new_used_ssas, pending_nodes, pending_perm,
639+
parent.late_fixup, perm, 1,
640+
parent.new_new_nodes, parent.new_new_used_ssas, pending_nodes, pending_perm,
644641
1, result_offset, parent.active_result_bb, false, false, false)
645642
end
646643
end
@@ -1469,62 +1466,104 @@ function maybe_erase_unused!(
14691466
return false
14701467
end
14711468

1472-
function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any})
1469+
struct FixedNode
1470+
node::Any
1471+
needs_fixup::Bool
1472+
FixedNode(@nospecialize(node), needs_fixup::Bool) = new(node, needs_fixup)
1473+
end
1474+
1475+
function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any}, reify_new_nodes::Bool)
14731476
values = Vector{Any}(undef, length(old_values))
1477+
needs_fixup = false
14741478
for i = 1:length(old_values)
14751479
isassigned(old_values, i) || continue
14761480
val = old_values[i]
1477-
if isa(val, Union{OldSSAValue, NewSSAValue})
1478-
val = fixup_node(compact, val)
1481+
if isa(val, OldSSAValue)
1482+
val = compact.ssa_rename[val.id]
1483+
if isa(val, SSAValue)
1484+
compact.used_ssas[val.id] += 1
1485+
end
1486+
elseif isa(val, NewSSAValue)
1487+
if reify_new_nodes
1488+
val = SSAValue(length(compact.result) + val.id)
1489+
else
1490+
needs_fixup = true
1491+
end
14791492
end
14801493
values[i] = val
14811494
end
1482-
values
1495+
return FixedNode(values, needs_fixup)
14831496
end
14841497

1485-
function fixup_node(compact::IncrementalCompact, @nospecialize(stmt))
1498+
function fixup_node(compact::IncrementalCompact, @nospecialize(stmt), reify_new_nodes::Bool)
14861499
if isa(stmt, PhiNode)
1487-
return PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values))
1500+
(;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
1501+
return FixedNode(PhiNode(stmt.edges, node), needs_fixup)
14881502
elseif isa(stmt, PhiCNode)
1489-
return PhiCNode(fixup_phinode_values!(compact, stmt.values))
1503+
(;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
1504+
return FixedNode(PhiCNode(node), needs_fixup)
14901505
elseif isa(stmt, NewSSAValue)
1491-
return SSAValue(length(compact.result) + stmt.id)
1492-
elseif isa(stmt, OldSSAValue)
1493-
val = compact.ssa_rename[stmt.id]
1494-
if isa(val, SSAValue)
1495-
# If `val.id` is greater than the length of `compact.result` or
1496-
# `compact.used_ssas`, this SSA value is in `new_new_nodes`, so
1497-
# don't count the use
1498-
compact.used_ssas[val.id] += 1
1506+
if reify_new_nodes
1507+
return FixedNode(SSAValue(length(compact.result) + stmt.id), false)
1508+
else
1509+
return FixedNode(stmt, true)
14991510
end
1500-
return val
1511+
elseif isa(stmt, OldSSAValue)
1512+
return FixedNode(compact.ssa_rename[stmt.id], false)
15011513
else
15021514
urs = userefs(stmt)
1515+
needs_fixup = false
15031516
for ur in urs
15041517
val = ur[]
1505-
if isa(val, Union{NewSSAValue, OldSSAValue})
1506-
ur[] = fixup_node(compact, val)
1518+
if isa(val, NewSSAValue)
1519+
if reify_new_nodes
1520+
val = SSAValue(length(compact.result) + val.id)
1521+
else
1522+
needs_fixup = true
1523+
end
1524+
elseif isa(val, OldSSAValue)
1525+
val = compact.ssa_rename[val.id]
15071526
end
1527+
if isa(val, SSAValue) && val.id <= length(compact.used_ssas)
1528+
# If `val.id` is greater than the length of `compact.result` or
1529+
# `compact.used_ssas`, this SSA value is in `new_new_nodes`, so
1530+
# don't count the use
1531+
compact.used_ssas[val.id] += 1
1532+
end
1533+
ur[] = val
15081534
end
1509-
return urs[]
1535+
return FixedNode(urs[], needs_fixup)
15101536
end
15111537
end
15121538

1513-
function just_fixup!(compact::IncrementalCompact)
1514-
resize!(compact.used_ssas, length(compact.result))
1515-
append!(compact.used_ssas, compact.new_new_used_ssas)
1516-
empty!(compact.new_new_used_ssas)
1517-
for idx in compact.late_fixup
1539+
function just_fixup!(compact::IncrementalCompact, new_new_nodes_offset::Union{Int, Nothing} = nothing, late_fixup_offset::Union{Int, Nothing}=nothing)
1540+
if new_new_nodes_offset === late_fixup_offset === nothing # only do this appending in non_dce_finish!
1541+
resize!(compact.used_ssas, length(compact.result))
1542+
append!(compact.used_ssas, compact.new_new_used_ssas)
1543+
empty!(compact.new_new_used_ssas)
1544+
end
1545+
off = late_fixup_offset === nothing ? 1 : (late_fixup_offset+1)
1546+
set_off = off
1547+
for i in off:length(compact.late_fixup)
1548+
idx = compact.late_fixup[i]
15181549
stmt = compact.result[idx][:inst]
1519-
new_stmt = fixup_node(compact, stmt)
1520-
(stmt === new_stmt) || (compact.result[idx][:inst] = new_stmt)
1521-
end
1522-
for idx in 1:length(compact.new_new_nodes)
1523-
node = compact.new_new_nodes.stmts[idx]
1524-
stmt = node[:inst]
1525-
new_stmt = fixup_node(compact, stmt)
1526-
if new_stmt !== stmt
1527-
node[:inst] = new_stmt
1550+
(;node, needs_fixup) = fixup_node(compact, stmt, late_fixup_offset === nothing)
1551+
(stmt === node) || (compact.result[idx][:inst] = node)
1552+
if needs_fixup
1553+
compact.late_fixup[set_off] = idx
1554+
set_off += 1
1555+
end
1556+
end
1557+
if late_fixup_offset !== nothing
1558+
resize!(compact.late_fixup, set_off-1)
1559+
end
1560+
off = new_new_nodes_offset === nothing ? 1 : (new_new_nodes_offset+1)
1561+
for idx in off:length(compact.new_new_nodes)
1562+
new_node = compact.new_new_nodes.stmts[idx]
1563+
stmt = new_node[:inst]
1564+
(;node) = fixup_node(compact, stmt, late_fixup_offset === nothing)
1565+
if node !== stmt
1566+
new_node[:inst] = node
15281567
end
15291568
end
15301569
end

0 commit comments

Comments
 (0)