Skip to content

Commit

Permalink
optimizer: fully support inlining of union-split, partially constant-…
Browse files Browse the repository at this point in the history
…prop' callsite (#43347)

Makes full use of constant-propagation, by addressing this [TODO](https://github.com/JuliaLang/julia/blob/00734c5fd045316a00d287ca2c0ec1a2eef6e4d1/base/compiler/ssair/inlining.jl#L1212).
Here is a performance improvement from #43287:
```julia
ulia> using BenchmarkTools

julia> X = rand(ComplexF32, 64, 64);

julia> dst = reinterpret(reshape, Float32, X);

julia> src = copy(dst);

julia> @Btime copyto!($dst, $src);
  50.819 μs (1 allocation: 32 bytes) # v1.6.4
  41.081 μs (0 allocations: 0 bytes) # this commit
```

fixes #43287
  • Loading branch information
aviatesk authored Jan 5, 2022
1 parent 85a6990 commit 1b600f0
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 91 deletions.
1 change: 1 addition & 0 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# by constant analysis, but let's create `ConstCallInfo` if there has been any successful
# constant propagation happened since other consumers may be interested in this
if any_const_result && seen == napplicable
@assert napplicable == nmatches(info) == length(const_results)
info = ConstCallInfo(info, const_results)
end

Expand Down
159 changes: 84 additions & 75 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -689,19 +689,16 @@ function rewrite_apply_exprargs!(
new_sig = with_atype(call_sig(ir, new_stmt)::Signature)
new_info = call.info
if isa(new_info, ConstCallInfo)
maybe_handle_const_call!(
handle_const_call!(
ir, state1.id, new_stmt, new_info, flag,
new_sig, istate, todo) && @goto analyzed
new_info = new_info.call # cascade to the non-constant handling
end
if isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo)
new_sig, istate, todo)
elseif isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo)
new_infos = isa(new_info, MethodMatchInfo) ? MethodMatchInfo[new_info] : new_info.matches
# See if we can inline this call to `iterate`
analyze_single_call!(
ir, state1.id, new_stmt, new_infos, flag,
new_sig, istate, todo)
end
@label analyzed
if i != length(thisarginfo.each)
valT = getfield_tfunc(call.rt, Const(1))
val_extracted = insert_node!(ir, idx, NewInstruction(
Expand Down Expand Up @@ -1136,139 +1133,150 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
return stmt, sig
end

# TODO inline non-`isdispatchtuple`, union-split callsites
# TODO inline non-`isdispatchtuple`, union-split callsites?
function analyze_single_call!(
ir::IRCode, idx::Int, stmt::Expr, infos::Vector{MethodMatchInfo}, flag::UInt8,
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
(; argtypes, atype) = sig
cases = InliningCase[]
local signature_union = Bottom
local only_method = nothing # keep track of whether there is one matching method
local meth
local meth::MethodLookupResult
local fully_covered = true
for i in 1:length(infos)
info = infos[i]
meth = info.results
meth = infos[i].results
if meth.ambig
# Too many applicable methods
# Or there is a (partial?) ambiguity
return
return nothing
elseif length(meth) == 0
# No applicable methods; try next union split
continue
elseif length(meth) == 1 && only_method !== false
if only_method === nothing
only_method = meth[1].method
elseif only_method !== meth[1].method
else
if length(meth) == 1 && only_method !== false
if only_method === nothing
only_method = meth[1].method
elseif only_method !== meth[1].method
only_method = false
end
else
only_method = false
end
else
only_method = false
end
for match in meth
spec_types = match.spec_types
signature_union = Union{signature_union, spec_types}
if !isdispatchtuple(spec_types)
fully_covered = false
continue
end
item = analyze_method!(match, argtypes, flag, state)
if item === nothing
fully_covered = false
continue
elseif _any(case->case.sig === spec_types, cases)
continue
end
push!(cases, InliningCase(spec_types, item))
signature_union = Union{signature_union, match.spec_types}
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
end
end

# if the signature is fully or mostly covered and there is only one applicable method,
# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even if the signature is not a dispatch tuple
if length(cases) == 0 && only_method isa Method
if length(infos) > 1
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
atype, only_method.sig)::SimpleVector
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
else
meth = meth::MethodLookupResult
@assert length(meth) == 1
match = meth[1]
end
item = analyze_method!(match, argtypes, flag, state)
item === nothing && return
item === nothing && return nothing
push!(cases, InliningCase(match.spec_types, item))
fully_covered = match.fully_covers
else
fully_covered &= atype <: signature_union
end

# If we only have one case and that case is fully covered, we may either
# be able to do the inlining now (for constant cases), or push it directly
# onto the todo list
if fully_covered && length(cases) == 1
handle_single_case!(ir, idx, stmt, cases[1].item, todo)
elseif length(cases) > 0
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
end
return nothing
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
end

# try to create `InliningCase`s using constant-prop'ed results
# currently it works only when constant-prop' succeeded for all (union-split) signatures
# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later
# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out
function maybe_handle_const_call!(
ir::IRCode, idx::Int, stmt::Expr, info::ConstCallInfo, flag::UInt8,
# similar to `analyze_single_call!`, but with constant results
function handle_const_call!(
ir::IRCode, idx::Int, stmt::Expr, cinfo::ConstCallInfo, flag::UInt8,
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
(; argtypes, atype) = sig
results = info.results
cases = InliningCase[] # TODO avoid this allocation for single cases ?
(; call, results) = cinfo
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
cases = InliningCase[]
local fully_covered = true
local signature_union = Bottom
for result in results
isa(result, InferenceResult) || return false
(; mi) = item = InliningTodo(result, argtypes)
spec_types = mi.specTypes
signature_union = Union{signature_union, spec_types}
if !isdispatchtuple(spec_types)
fully_covered = false
continue
end
if !validate_sparams(mi.sparam_vals)
fully_covered = false
local j = 0
for i in 1:length(infos)
meth = infos[i].results
if meth.ambig
# Too many applicable methods
# Or there is a (partial?) ambiguity
return nothing
elseif length(meth) == 0
# No applicable methods; try next union split
continue
end
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
if item === nothing
fully_covered = false
continue
for match in meth
j += 1
result = results[j]
if result === nothing
signature_union = Union{signature_union, match.spec_types}
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
else
signature_union = Union{signature_union, result.linfo.specTypes}
fully_covered &= handle_const_result!(result, argtypes, flag, state, cases)
end
end
push!(cases, InliningCase(spec_types, item))
end

# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even if the signature is not a dispatch tuple
if length(cases) == 0 && length(results) == 1
(; mi) = item = InliningTodo(results[1]::InferenceResult, argtypes)
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
validate_sparams(mi.sparam_vals) || return true
item === nothing && return true
validate_sparams(mi.sparam_vals) || return nothing
item === nothing && return nothing
push!(cases, InliningCase(mi.specTypes, item))
fully_covered = atype <: mi.specTypes
else
fully_covered &= atype <: signature_union
end

handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
end

function handle_match!(
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase})
spec_types = match.spec_types
isdispatchtuple(spec_types) || return false
item = analyze_method!(match, argtypes, flag, state)
item === nothing && return false
_any(case->case.sig === spec_types, cases) && return true
push!(cases, InliningCase(spec_types, item))
return true
end

function handle_const_result!(
result::InferenceResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase})
(; mi) = item = InliningTodo(result, argtypes)
spec_types = mi.specTypes
isdispatchtuple(spec_types) || return false
validate_sparams(mi.sparam_vals) || return false
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
return true
end

function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, sig::Signature,
cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}})
# If we only have one case and that case is fully covered, we may either
# be able to do the inlining now (for constant cases), or push it directly
# onto the todo list
if fully_covered && length(cases) == 1
handle_single_case!(ir, idx, stmt, cases[1].item, todo)
elseif length(cases) > 0
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases))
end
return true
return nothing
end

function handle_const_opaque_closure_call!(
Expand Down Expand Up @@ -1302,7 +1310,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
end
ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE
info = info.info
elseif info === false
end
if info === false
# Inference determined this couldn't be analyzed. Don't question it.
continue
end
Expand Down Expand Up @@ -1333,10 +1342,10 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
# if inference arrived here with constant-prop'ed result(s),
# we can perform a specialized analysis for just this case
if isa(info, ConstCallInfo)
maybe_handle_const_call!(
handle_const_call!(
ir, idx, stmt, info, flag,
sig, state, todo) && continue
info = info.call # cascade to the non-constant handling
sig, state, todo)
continue
end

# Ok, now figure out what method to call
Expand Down
33 changes: 21 additions & 12 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,27 @@ struct UnionSplitInfo
matches::Vector{MethodMatchInfo}
end

nmatches(info::MethodMatchInfo) = length(info.results)
function nmatches(info::UnionSplitInfo)
n = 0
for mminfo in info.matches
n += nmatches(mminfo)
end
return n
end

"""
info::ConstCallInfo
The precision of this call was improved using constant information.
In addition to the original call information `info.call`, this info also keeps
the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`.
"""
struct ConstCallInfo
call::Union{MethodMatchInfo,UnionSplitInfo}
results::Vector{Union{Nothing,InferenceResult}}
end

"""
info::MethodResultPure
Expand Down Expand Up @@ -92,18 +113,6 @@ struct UnionSplitApplyCallInfo
infos::Vector{ApplyCallInfo}
end

"""
info::ConstCallInfo
The precision of this call was improved using constant information.
In addition to the original call information `info.call`, this info also keeps
the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`.
"""
struct ConstCallInfo
call::Union{MethodMatchInfo,UnionSplitInfo}
results::Vector{Union{Nothing,InferenceResult}}
end

"""
info::InvokeCallInfo
Expand Down
13 changes: 9 additions & 4 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -759,13 +759,18 @@ end
import Base: @constprop

# test union-split callsite with successful and unsuccessful constant-prop' results
@constprop :aggressive @inline f42840(xs, a::Int) = xs[a] # should be successful, and inlined
@constprop :none @noinline f42840(xs::AbstractVector, a::Int) = xs[a] # should be unsuccessful, but still statically resolved
# (also for https://github.com/JuliaLang/julia/issues/43287)
@constprop :aggressive @inline f42840(cond::Bool, xs::Tuple, a::Int) = # should be successful, and inlined with constant prop' result
cond ? xs[a] : @noinline(length(xs))
@constprop :none @noinline f42840(::Bool, xs::AbstractVector, a::Int) = # should be unsuccessful, but still statically resolved
xs[a]
let src = code_typed((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs
f42840(xs, 2)
f42840(true, xs, 2)
end |> only |> first
# `(xs::Tuple{Int,Int,Int})[a::Const(2)]` => `getfield(xs, 2)`
# `f43287(true, xs::Tuple{Int,Int,Int}, 2)` => `getfield(xs, 2)`
# `f43287(true, xs::Vector{Int}, 2)` => `:invoke f43287(true, xs, 2)`
@test count(iscall((src, getfield)), src.code) == 1
@test count(isinvoke(:length), src.code) == 0
@test count(isinvoke(:f42840), src.code) == 1
end
# a bit weird, but should handle this kind of case as well
Expand Down

0 comments on commit 1b600f0

Please sign in to comment.