Skip to content

wip: make Cthulhu cache CodeInstance-based #614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 83 additions & 89 deletions src/Cthulhu.jl

Large diffs are not rendered by default.

102 changes: 48 additions & 54 deletions src/callsite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,30 @@ using Unicode
abstract type CallInfo end

# Call could be resolved to a singular MI
struct MICallInfo <: CallInfo
mi::MethodInstance
struct EdgeCallInfo <: CallInfo
ci::CodeInstance
rt
effects::Effects
exct
function MICallInfo(mi::MethodInstance, @nospecialize(rt), effects, @nospecialize(exct=nothing))
function EdgeCallInfo(ci::CodeInstance, @nospecialize(rt), effects::Effects, @nospecialize(exct=nothing))
if isa(rt, LimitedAccuracy)
return LimitedCallInfo(new(mi, ignorelimited(rt), effects, exct))
return LimitedCallInfo(new(ci, ignorelimited(rt), effects, exct))
else
return new(mi, rt, effects, exct)
return new(ci, rt, effects, exct)
end
end
end
get_mi(ci::MICallInfo) = ci.mi
get_rt(ci::MICallInfo) = ci.rt
get_effects(ci::MICallInfo) = ci.effects
get_exct(ci::MICallInfo) = ci.exct
get_ci(ci::EdgeCallInfo) = ci.ci
get_rt(ci::EdgeCallInfo) = ci.rt
get_effects(ci::EdgeCallInfo) = ci.effects
get_exct(ci::EdgeCallInfo) = ci.exct

abstract type WrappedCallInfo <: CallInfo end

get_wrapped(ci::WrappedCallInfo) = ci.wrapped
ignorewrappers(ci::CallInfo) = ci
ignorewrappers(ci::WrappedCallInfo) = ignorewrappers(get_wrapped(ci))
get_mi(ci::WrappedCallInfo) = get_mi(ignorewrappers(ci))
get_ci(ci::WrappedCallInfo) = get_ci(ignorewrappers(ci))
get_rt(ci::WrappedCallInfo) = get_rt(ignorewrappers(ci))
get_effects(ci::WrappedCallInfo) = get_effects(ignorewrappers(ci))
get_exct(ci::WrappedCallInfo) = get_exct(ignorewrappers(ci))
Expand All @@ -44,22 +44,17 @@ struct RTCallInfo <: CallInfo
exct
end
get_rt(ci::RTCallInfo) = ci.rt
get_mi(ci::RTCallInfo) = nothing
get_ci(ci::RTCallInfo) = nothing
get_effects(ci::RTCallInfo) = Effects()
get_exct(ci::RTCallInfo) = ci.exct

# uncached callsite, we can't recurse into this call
struct UncachedCallInfo <: WrappedCallInfo
wrapped::CallInfo
end

struct PureCallInfo <: CallInfo
argtypes::Vector{Any}
rt
PureCallInfo(argtypes::Vector{Any}, @nospecialize(rt)) =
new(argtypes, rt)
end
get_mi(::PureCallInfo) = nothing
get_ci(::PureCallInfo) = nothing
get_rt(pci::PureCallInfo) = pci.rt
get_effects(::PureCallInfo) = EFFECTS_TOTAL
get_exct(::PureCallInfo) = Union{}
Expand All @@ -69,7 +64,7 @@ struct FailedCallInfo <: CallInfo
sig
rt
end
get_mi(ci::FailedCallInfo) = fail(ci)
get_ci(ci::FailedCallInfo) = fail(ci)
get_rt(ci::FailedCallInfo) = fail(ci)
get_effects(ci::FailedCallInfo) = fail(ci)
get_exct(ci::FailedCallInfo) = fail(ci)
Expand All @@ -83,7 +78,7 @@ struct GeneratedCallInfo <: CallInfo
sig
rt
end
get_mi(genci::GeneratedCallInfo) = fail(genci)
get_ci(genci::GeneratedCallInfo) = fail(genci)
get_rt(genci::GeneratedCallInfo) = fail(genci)
get_effects(genci::GeneratedCallInfo) = fail(genci)
get_exct(genci::GeneratedCallInfo) = fail(genci)
Expand All @@ -101,15 +96,15 @@ struct MultiCallInfo <: CallInfo
@nospecialize(exct=nothing)) =
new(sig, rt, exct, callinfos)
end
get_mi(ci::MultiCallInfo) = error("Can't extract MethodInstance from multiple call informations")
get_ci(ci::MultiCallInfo) = error("Can't extract MethodInstance from multiple call informations")
get_rt(ci::MultiCallInfo) = ci.rt
get_effects(mci::MultiCallInfo) = mapreduce(get_effects, CC.merge_effects, mci.callinfos)
get_exct(ci::MultiCallInfo) = ci.exct

struct TaskCallInfo <: CallInfo
ci::CallInfo
end
get_mi(tci::TaskCallInfo) = get_mi(tci.ci)
get_ci(tci::TaskCallInfo) = get_ci(tci.ci)
get_rt(tci::TaskCallInfo) = get_rt(tci.ci)
get_effects(tci::TaskCallInfo) = get_effects(tci.ci)
get_exct(tci::TaskCallInfo) = get_exct(tci.ci)
Expand All @@ -118,7 +113,7 @@ struct InvokeCallInfo <: CallInfo
ci::CallInfo
InvokeCallInfo(@nospecialize ci::CallInfo) = new(ci)
end
get_mi(ici::InvokeCallInfo) = get_mi(ici.ci)
get_ci(ici::InvokeCallInfo) = get_ci(ici.ci)
get_rt(ici::InvokeCallInfo) = get_rt(ici.ci)
get_effects(ici::InvokeCallInfo) = get_effects(ici.ci)
get_exct(ici::InvokeCallInfo) = get_exct(ici.ci)
Expand All @@ -128,7 +123,7 @@ struct OCCallInfo <: CallInfo
ci::CallInfo
OCCallInfo(@nospecialize ci::CallInfo) = new(ci)
end
get_mi(occi::OCCallInfo) = get_mi(occi.ci)
get_ci(occi::OCCallInfo) = get_ci(occi.ci)
get_rt(occi::OCCallInfo) = get_rt(occi.ci)
get_effects(occi::OCCallInfo) = get_effects(occi.ci)
get_exct(occi::OCCallInfo) = get_exct(occi.ci)
Expand All @@ -137,52 +132,52 @@ get_exct(occi::OCCallInfo) = get_exct(occi.ci)
struct ReturnTypeCallInfo <: CallInfo
vmi::CallInfo # virtualized method call
end
get_mi((; vmi)::ReturnTypeCallInfo) = isa(vmi, FailedCallInfo) ? nothing : get_mi(vmi)
get_ci((; vmi)::ReturnTypeCallInfo) = isa(vmi, FailedCallInfo) ? nothing : get_ci(vmi)
get_rt((; vmi)::ReturnTypeCallInfo) = Type{isa(vmi, FailedCallInfo) ? Union{} : widenconst(get_rt(vmi))}
get_effects(::ReturnTypeCallInfo) = EFFECTS_TOTAL
get_exct(::ReturnTypeCallInfo) = Union{} # FIXME

struct ConstPropCallInfo <: CallInfo
mi::CallInfo
ci::CallInfo
result::InferenceResult
end
get_mi(cpci::ConstPropCallInfo) = cpci.result.linfo
get_rt(cpci::ConstPropCallInfo) = get_rt(cpci.mi)
get_ci(cpci::ConstPropCallInfo) = get_ci(cpci.ci)
get_rt(cpci::ConstPropCallInfo) = get_rt(cpci.ci)
get_effects(cpci::ConstPropCallInfo) = get_effects(cpci.result)
get_exct(cpci::ConstPropCallInfo) = get_exct(cpci.mi)
get_exct(cpci::ConstPropCallInfo) = get_exct(cpci.ci)

struct ConcreteCallInfo <: CallInfo
mi::CallInfo
ci::CallInfo
argtypes::ArgTypes
end
get_mi(ceci::ConcreteCallInfo) = get_mi(ceci.mi)
get_rt(ceci::ConcreteCallInfo) = get_rt(ceci.mi)
get_effects(ceci::ConcreteCallInfo) = get_effects(ceci.mi)
get_exct(cici::ConcreteCallInfo) = get_exct(ceci.mi)
get_ci(ceci::ConcreteCallInfo) = get_ci(ceci.ci)
get_rt(ceci::ConcreteCallInfo) = get_rt(ceci.ci)
get_effects(ceci::ConcreteCallInfo) = get_effects(ceci.ci)
get_exct(cici::ConcreteCallInfo) = get_exct(ceci.ci)

struct SemiConcreteCallInfo <: CallInfo
mi::CallInfo
ci::CallInfo
ir::IRCode
end
get_mi(scci::SemiConcreteCallInfo) = get_mi(scci.mi)
get_rt(scci::SemiConcreteCallInfo) = get_rt(scci.mi)
get_effects(scci::SemiConcreteCallInfo) = get_effects(scci.mi)
get_exct(scci::SemiConcreteCallInfo) = get_exct(scci.mi)
get_ci(scci::SemiConcreteCallInfo) = get_ci(scci.ci)
get_rt(scci::SemiConcreteCallInfo) = get_rt(scci.ci)
get_effects(scci::SemiConcreteCallInfo) = get_effects(scci.ci)
get_exct(scci::SemiConcreteCallInfo) = get_exct(scci.ci)

# CUDA callsite
struct CuCallInfo <: CallInfo
cumi::MICallInfo
ci::EdgeCallInfo
end
get_mi(gci::CuCallInfo) = get_mi(gci.cumi)
get_rt(gci::CuCallInfo) = get_rt(gci.cumi)
get_effects(gci::CuCallInfo) = get_effects(gci.cumi)
get_ci(gci::CuCallInfo) = get_ci(gci.ci)
get_rt(gci::CuCallInfo) = get_rt(gci.ci)
get_effects(gci::CuCallInfo) = get_effects(gci.ci)

struct Callsite
id::Int # ssa-id
info::CallInfo
head::Symbol
end
get_mi(c::Callsite) = get_mi(c.info)
get_ci(c::Callsite) = get_ci(c.info)
get_effects(c::Callsite) = get_effects(c.info)

# Callsite printing
Expand Down Expand Up @@ -277,17 +272,17 @@ function Base.show(io::IO, (;exct)::ExctWrapper)
printstyled(io, "(↑::", exct, ")"; color)
end

function show_callinfo(limiter, mici::MICallInfo)
mi = mici.mi
function show_callinfo(limiter, ci::EdgeCallInfo)
mi = ci.ci.def
tt = (Base.unwrap_unionall(mi.specTypes)::DataType).parameters[2:end]
if !isa(mi.def, Method)
name = ":toplevel"
else
name = mi.def.name
end
rt = get_rt(mici)
exct = get_exct(mici)
__show_limited(limiter, name, tt, rt, get_effects(mici), exct)
rt = get_rt(ci)
exct = get_exct(ci)
__show_limited(limiter, name, tt, rt, get_effects(ci), exct)
end

function show_callinfo(limiter, ci::Union{MultiCallInfo, FailedCallInfo, GeneratedCallInfo})
Expand Down Expand Up @@ -317,20 +312,20 @@ function show_callinfo(limiter, ci::ConstPropCallInfo)
# XXX: The first argument could be const-overriden too
name = ci.result.linfo.def.name
tt = ci.result.argtypes[2:end]
ci = ignorewrappers(ci.mi)::MICallInfo
ci = ignorewrappers(ci.ci)::EdgeCallInfo
__show_limited(limiter, name, tt, get_rt(ci), get_effects(ci))
end

function show_callinfo(limiter, ci::SemiConcreteCallInfo)
# XXX: The first argument could be const-overriden too
name = get_mi(ci).def.name
name = get_ci(ci).def.def.name
tt = ci.ir.argtypes[2:end]
__show_limited(limiter, name, tt, get_rt(ci), get_effects(ci))
end

function show_callinfo(limiter, ci::ConcreteCallInfo)
# XXX: The first argument could be const-overriden too
name = get_mi(ci).def.name
name = get_ci(ci).def.def.name
tt = ci.argtypes[2:end]
__show_limited(limiter, name, tt, get_rt(ci), get_effects(ci))
end
Expand Down Expand Up @@ -435,7 +430,7 @@ function Base.show(io::IO, c::Callsite)
limiter = TextWidthLimiter(io, cols)
limiter.width += 1 # for the '%' character
print(limiter, string(c.id))
if isa(info, MICallInfo)
if isa(info, EdgeCallInfo)
print(limiter, optimize ? string(" = ", c.head, ' ') : " = ")
show_callinfo(limiter, info)
else
Expand All @@ -457,7 +452,6 @@ function wrapped_callinfo(limiter, ci::WrappedCallInfo)
print(limiter, " > ")
end
_wrapped_callinfo(limiter, ::LimitedCallInfo) = print(limiter, "limited")
_wrapped_callinfo(limiter, ::UncachedCallInfo) = print(limiter, "uncached")

# is_callsite returns true if `call` dispatches to `callee`
# See also `maybe_callsite` below
Expand Down Expand Up @@ -527,7 +521,7 @@ function maybe_callsite(info::RTCallInfo, @nospecialize(tt::Type))
end
return true
end
function maybe_callsite(info::MICallInfo, @nospecialize(tt::Type))
function maybe_callsite(info::EdgeCallInfo, @nospecialize(tt::Type))
return tt <: info.mi.specTypes
end

Expand Down
22 changes: 8 additions & 14 deletions src/codeview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,29 +117,21 @@ is_type_unstable(@nospecialize(type)) = type isa Type && (!Base.isdispatchelem(t

cthulhu_warntype(args...; kwargs...) = cthulhu_warntype(stdout::IO, args...; kwargs...)
function cthulhu_warntype(io::IO, debuginfo::AnyDebugInfo,
src::Union{CodeInfo,IRCode}, @nospecialize(rt), effects::Effects, mi::Union{Nothing,MethodInstance}=nothing;
src::Union{CodeInfo,IRCode}, @nospecialize(rt), effects::Effects, codeinst::Union{Nothing,CodeInstance}=nothing;
hide_type_stable::Bool=false, inline_cost::Bool=false, optimize::Bool=false,
interp::CthulhuInterpreter=CthulhuInterpreter())
if inline_cost
isa(mi, MethodInstance) || error("Need a MethodInstance to show inlining costs. Call `cthulhu_typed` directly instead.")
end
cthulhu_typed(io, debuginfo, src, rt, nothing, effects, mi; iswarn=true, optimize, hide_type_stable, inline_cost, interp)
cthulhu_typed(io, debuginfo, src, rt, nothing, effects, codeinst; iswarn=true, optimize, hide_type_stable, inline_cost, interp)
return nothing
end

# # for API consistency with the others
# function cthulhu_typed(io::IO, mi::MethodInstance, optimize, debuginfo, params, config::CthulhuConfig)
# interp = mkinterp(mi)
# (; src, rt, infos, slottypes) = lookup(interp, mi, optimize)
# ci = Cthulhu.preprocess_ci!(src, mi, optimize, config)
# cthulhu_typed(io, debuginfo, src, rt, mi)
# end

cthulhu_typed(io::IO, debuginfo::DebugInfo, args...; kwargs...) =
cthulhu_typed(io, Symbol(debuginfo), args...; kwargs...)
function cthulhu_typed(io::IO, debuginfo::Symbol,
src::Union{CodeInfo,IRCode}, @nospecialize(rt), @nospecialize(exct),
effects::Effects, mi::Union{Nothing,MethodInstance};
effects::Effects, codeinst::Union{Nothing,CodeInstance};
iswarn::Bool=false, hide_type_stable::Bool=false, optimize::Bool=true,
pc2remarks::Union{Nothing,PC2Remarks}=nothing,
pc2effects::Union{Nothing,PC2Effects}=nothing,
Expand All @@ -148,6 +140,8 @@ function cthulhu_typed(io::IO, debuginfo::Symbol,
inlay_types_vscode::Bool=false, diagnostics_vscode::Bool=false, jump_always::Bool=false,
interp::AbstractInterpreter=CthulhuInterpreter())

mi = codeinst === nothing ? nothing : codeinst.def

debuginfo = IRShow.debuginfo(debuginfo)
lineprinter = __debuginfo[debuginfo]
rettype = ignorelimited(rt)
Expand Down Expand Up @@ -316,11 +310,11 @@ function cthulhu_typed(io::IO, debuginfo::Symbol,
end
println(lambda_io)
else
isa(mi, MethodInstance) || throw("`mi::MethodInstance` is required")
isa(codeinst, CodeInstance) || throw("`codeinst::CodeInstance` is required")
cfg = src isa IRCode ? src.cfg : CC.compute_basic_blocks(src.code)
max_bb_idx_size = length(string(length(cfg.blocks)))
str = irshow_config.line_info_preprinter(lambda_io, " "^(max_bb_idx_size + 2), -1)
callsite = Callsite(0, MICallInfo(mi, rettype, effects, exct), :invoke)
callsite = Callsite(0, EdgeCallInfo(codeinst, rettype, effects, exct), :invoke)
println(lambda_io, "∘ ", "─"^(max_bb_idx_size), str, " ", callsite)
end

Expand Down Expand Up @@ -459,7 +453,7 @@ function Base.show(
(; interp, mi) = b
(; effects) = lookup(interp, mi, optimize)
if get(io, :typeinfo, Any) === Bookmark # a hack to check if in Vector etc.
print(io, Callsite(-1, MICallInfo(b.mi, rt, Effects()), :invoke))
print(io, Callsite(-1, EdgeCallInfo(b.mi, rt, Effects()), :invoke))
print(io, " (world: ", world, ")")
return
end
Expand Down
Loading
Loading