Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify caching logic #441

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/abstractinterpret/abstractanalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -536,23 +536,27 @@ get_reports(analyzer::AbstractAnalyzer, result::InferenceResult) = (analyzer[res
get_cached_reports(analyzer::AbstractAnalyzer, result::InferenceResult) = (analyzer[result]::JETCachedResult).reports
get_any_reports(analyzer::AbstractAnalyzer, result::InferenceResult) = (analyzer[result]::AnyJETResult).reports

# HACK to avoid runtime dispatch
@inline push_report!(reports::Vector{InferenceErrorReport}, @nospecialize(report::InferenceErrorReport)) =
@invoke push!(reports::Vector, report::InferenceErrorReport)

"""
add_new_report!(analyzer::AbstractAnalyzer, result::InferenceResult, report::InferenceErrorReport)

Adds new [`report::InferenceErrorReport`](@ref InferenceErrorReport) associated with `result::InferenceResult`.
"""
function add_new_report!(analyzer::AbstractAnalyzer, result::InferenceResult, @nospecialize(report::InferenceErrorReport))
push!(get_reports(analyzer, result), report)
push_report!(get_reports(analyzer, result), report)
return report
end

function add_cached_report!(analyzer::AbstractAnalyzer, caller::InferenceResult, @nospecialize(cached::InferenceErrorReport))
cached = copy_report′(cached)
push!(get_reports(analyzer, caller), cached)
push_report!(get_reports(analyzer, caller), cached)
return cached
end

add_caller_cache!(analyzer::AbstractAnalyzer, @nospecialize(report::InferenceErrorReport)) = push!(get_caller_cache(analyzer), report)
add_caller_cache!(analyzer::AbstractAnalyzer, @nospecialize(report::InferenceErrorReport)) = push_report!(get_caller_cache(analyzer), report)
add_caller_cache!(analyzer::AbstractAnalyzer, reports::Vector{InferenceErrorReport}) = append!(get_caller_cache(analyzer), reports)

# AbstractInterpreter
Expand Down
112 changes: 18 additions & 94 deletions src/abstractinterpret/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ end
# cache
# =====

cache_report!(cache, @nospecialize(report::InferenceErrorReport)) =
push!(cache, copy_report′(report)::InferenceErrorReport)
cache_report!(cache::Vector{InferenceErrorReport}, @nospecialize(report::InferenceErrorReport)) =
push_report!(cache, copy_report′(report)::InferenceErrorReport)

struct AbstractAnalyzerView{Analyzer<:AbstractAnalyzer}
analyzer::Analyzer
Expand Down Expand Up @@ -340,6 +340,7 @@ end # @static if hasmethod(CC.transform_result_for_cache, (...))

function CC.transform_result_for_cache(analyzer::AbstractAnalyzer,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
istoplevel(linfo) && return nothing
cache = InferenceErrorReport[]
for report in get_reports(analyzer, result)
@static if JET_DEV_MODE
Expand Down Expand Up @@ -543,104 +544,27 @@ function filter_lineages!(analyzer::AbstractAnalyzer, caller::InferenceResult, c
filter!(!islineage(caller.linfo, current), get_reports(analyzer, caller))
end

# in this overload we can work on `frame.src::CodeInfo` (and also `frame::InferenceState`)
# where type inference (and also optimization if applied) already ran on
function CC._typeinf(analyzer::AbstractAnalyzer, frame::InferenceState)
CC.typeinf_nocycle(analyzer, frame) || return false # frame is now part of a higher cycle
# with no active ip's, frame is done
frames = frame.callers_in_cycle
isempty(frames) && push!(frames, frame)
valid_worlds = WorldRange()
for caller in frames
@assert !(caller.dont_work_on_me)
caller.dont_work_on_me = true
# might might not fully intersect these earlier, so do that now
valid_worlds = CC.intersect(caller.valid_worlds, valid_worlds)
end
for caller in frames
caller.valid_worlds = valid_worlds
CC.finish(caller, analyzer)
# finalize and record the linfo result
caller.inferred = true
end
# NOTE we don't discard `InferenceState`s here so that some analyzers can use them in `finish!`
# # collect results for the new expanded frame
# results = Tuple{InferenceResult, Vector{Any}, Bool}[
# ( frames[i].result,
# frames[i].stmt_edges[1]::Vector{Any},
# frames[i].cached )
# for i in 1:length(frames) ]
# empty!(frames)
for frame in frames
caller = frame.result
opt = caller.src
if (@static VERSION ≥ v"1.9.0-DEV.1636" ?
(opt isa OptimizationState{typeof(analyzer)}) :
(opt isa OptimizationState))
CC.optimize(analyzer, opt, OptimizationParams(analyzer), caller)
# # COMBAK we may want to enable inlining ?
# if opt.const_api
# # XXX: The work in ir_to_codeinf! is essentially wasted. The only reason
# # we're doing it is so that code_llvm can return the code
# # for the `return ...::Const` (which never runs anyway). We should do this
# # as a post processing step instead.
# CC.ir_to_codeinf!(opt)
# if result_type isa Const
# caller.src = result_type
# else
# @assert CC.isconstType(result_type)
# caller.src = Const(result_type.parameters[1])
# end
# end
caller.valid_worlds = CC.getindex((opt.inlining.et::CC.EdgeTracker).valid_worlds)
end
end
function CC.finish!(analyzer::AbstractAnalyzer, caller::InferenceResult)
reports = get_reports(analyzer, caller)

for frame in frames
caller = frame.result
edges = frame.stmt_edges[1]::Vector{Any}
cached = frame.cached
valid_worlds = caller.valid_worlds
if CC.last(valid_worlds) >= get_world_counter()
# if we aren't cached, we don't need this edge
# but our caller might, so let's just make it anyways
CC.store_backedges(caller, edges)
end
CC.finish!(analyzer, frame)

reports = get_reports(analyzer, caller)
# XXX this is a dirty fix for performance problem, we need more "proper" fix
# https://github.com/aviatesk/JET.jl/issues/75
unique!(aggregation_policy(analyzer), reports)

# XXX this is a dirty fix for performance problem, we need more "proper" fix
# https://github.com/aviatesk/JET.jl/issues/75
unique!(aggregation_policy(analyzer), reports)
if get_entry(analyzer) !== caller.linfo
# inter-procedural handling: get back to the caller what we got from these results
add_caller_cache!(analyzer, reports)

# global cache management
if cached && !istoplevel(frame)
CC.cache_result!(analyzer, caller)
end

if frame.parent !== nothing
# inter-procedural handling: get back to the caller what we got from these results
add_caller_cache!(analyzer, reports)

# local cache management
# TODO there are duplicated work here and `transform_result_for_cache`
cache = InferenceErrorReport[]
for report in reports
cache_report!(cache, report)
end
set_cached_result!(analyzer, caller, cache)
# local cache management
# TODO there are duplicated work here and `transform_result_for_cache`
cache = InferenceErrorReport[]
for report in reports
cache_report!(cache, report)
end
set_cached_result!(analyzer, caller, cache)
end

return true
end

# by default, this overload just is forwarded to the AbstractInterpreter's implementation
# but the only reason we have this overload is that some analyzers (like `JETAnalyzer`)
# can further overload this to generate `InferenceErrorReport` with an access to `frame`
function CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState)
return CC.finish!(analyzer, frame.result)
return @invoke CC.finish!(analyzer::AbstractInterpreter, caller::InferenceResult)
end

# top-level bridge
Expand Down
84 changes: 44 additions & 40 deletions src/analyzers/jetanalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,17 @@ function CC.InferenceState(result::InferenceResult, cache::Symbol, analyzer::JET
return frame
end

function CC.finish!(analyzer::JETAnalyzer, frame::InferenceState)
src = @invoke CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState)

if isnothing(src)
# caught in cycle, similar error should have been reported where the source is available
return src
else
code = (src::CodeInfo).code
function CC.finish!(analyzer::JETAnalyzer, caller::InferenceResult)
src = @invoke CC.finish!(analyzer::AbstractInterpreter, caller::InferenceResult)
if src isa CodeInfo
# report pass for uncaught `throw` calls
ReportPass(analyzer)(UncaughtExceptionReport, analyzer, frame, code)
return src
ReportPass(analyzer)(UncaughtExceptionReport, analyzer, caller, src)
else
# very much optimized (nothing to report), or very much unoptimized:
# in a case of the latter, similar error should have been reported
# where the source is available
end
return @invoke CC.finish!(analyzer::AbstractAnalyzer, caller::InferenceResult)
end

let # overload `abstract_call_gf_by_type`
Expand Down Expand Up @@ -487,56 +486,60 @@ end
Represents general `throw` calls traced during inference.
This is reported only when it's not caught by control flow.
"""
@jetreport struct UncaughtExceptionReport <: InferenceErrorReport
throw_calls::Vector{Tuple{Int,Expr}} # (pc, call)
end
function UncaughtExceptionReport(sv::InferenceState, throw_calls::Vector{Tuple{Int,Expr}})
vf = get_virtual_frame(sv.linfo)
sig = Any[]
ncalls = length(throw_calls)
for (i, (pc, call)) in enumerate(throw_calls)
call_sig = get_sig_nowrap((sv, pc), call)
append!(sig, call_sig)
i ≠ ncalls && push!(sig, ", ")
end
return UncaughtExceptionReport([vf], Signature(sig), throw_calls)
end
function print_report_message(io::IO, (; throw_calls)::UncaughtExceptionReport)
msg = length(throw_calls) == 1 ? "may throw" : "may throw either of"
print(io, msg)
end
@jetreport struct UncaughtExceptionReport <: InferenceErrorReport end
print_report_message(io::IO, ::UncaughtExceptionReport) = print(io, "may throw")
print_signature(::UncaughtExceptionReport) = false

# @jetreport struct UncaughtExceptionReport <: InferenceErrorReport
# throw_calls::Vector{Tuple{Int,Expr}} # (pc, call)
# end
# function UncaughtExceptionReport(caller::InferenceResult, throw_calls::Vector{Tuple{Int,Expr}})
# vf = get_virtual_frame(caller.linfo)
# sig = Any[]
# ncalls = length(throw_calls)
# for (i, (pc, call)) in enumerate(throw_calls)
# call_sig = get_sig_nowrap((caller.src::CodeInfo, pc), call)
# append!(sig, call_sig)
# i ≠ ncalls && push!(sig, ", ")
# end
# return UncaughtExceptionReport([vf], Signature(sig), throw_calls)
# end
# function print_report_message(io::IO, (; throw_calls)::UncaughtExceptionReport)
# msg = length(throw_calls) == 1 ? "may throw" : "may throw either of"
# print(io, msg)
# end

# report `throw` calls "appropriately"
# this error report pass is very special, since 1.) it's tightly bound to the report pass of
# `SeriousExceptionReport` and 2.) it involves "report filtering" on its own
function (::BasicPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, frame::InferenceState, stmts::Vector{Any})
if frame.bestguess === Bottom
report_uncaught_exceptions!(analyzer, frame, stmts)
function (::BasicPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, caller::InferenceResult, src::CodeInfo)
if caller.result === Bottom
report_uncaught_exceptions!(analyzer, caller, src)
return true
else
# the non-`Bottom` result may mean `throw` calls from the children frames
# (if exists) are caught and not propagated here
# we don't want to cache the caught `UncaughtExceptionReport`s for this frame and
# its parents, and just filter them away now
filter!(get_reports(analyzer, frame.result)) do @nospecialize(report::InferenceErrorReport)
filter!(get_reports(analyzer, caller)) do @nospecialize(report::InferenceErrorReport)
return !isa(report, UncaughtExceptionReport)
end
end
return false
end
(::SoundPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, frame::InferenceState, stmts::Vector{Any}) =
report_uncaught_exceptions!(analyzer, frame, stmts) # yes, you want tons of false positives !
function report_uncaught_exceptions!(analyzer::JETAnalyzer, frame::InferenceState, stmts::Vector{Any})
(::SoundPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, caller::InferenceResult, src::CodeInfo) =
report_uncaught_exceptions!(analyzer, caller, src) # yes, you want tons of false positives !
function report_uncaught_exceptions!(analyzer::JETAnalyzer, caller::InferenceResult, src::CodeInfo)
# if the return type here is `Bottom` annotated, this _may_ mean there're uncaught
# `throw` calls
# XXX it's possible that the `throw` calls within them are all caught but the other
# critical errors still make the return type `Bottom`
# NOTE to reduce the false positive cases described above, we count `throw` calls
# after optimization, since it may have eliminated "unreachable" `throw` calls
codelocs = frame.src.codelocs
linetable = frame.src.linetable::LineTable
codelocs = src.codelocs
linetable = src.linetable::LineTable
reported_locs = nothing
for report in get_reports(analyzer, frame.result)
for report in get_reports(analyzer, caller)
if isa(report, SeriousExceptionReport)
if isnothing(reported_locs)
reported_locs = LineInfoNode[]
Expand All @@ -545,7 +548,7 @@ function report_uncaught_exceptions!(analyzer::JETAnalyzer, frame::InferenceStat
end
end
throw_calls = nothing
for (pc, stmt) in enumerate(stmts)
for (pc, stmt) in enumerate(src.code)
isa(stmt, Expr) || continue
is_throw_call(stmt) || continue
# if this `throw` is already reported, don't duplciate
Expand All @@ -558,7 +561,8 @@ function report_uncaught_exceptions!(analyzer::JETAnalyzer, frame::InferenceStat
push!(throw_calls, (pc, stmt))
end
if !isnothing(throw_calls) && !isempty(throw_calls)
add_new_report!(analyzer, frame.result, UncaughtExceptionReport(frame, throw_calls))
# TODO add_new_report!(analyzer, caller, UncaughtExceptionReport(caller, throw_calls))
add_new_report!(analyzer, caller, UncaughtExceptionReport(caller))
return true
end
return false
Expand Down
14 changes: 4 additions & 10 deletions src/analyzers/optanalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ struct OptAnalysisPass <: ReportPass end

optanalyzer_function_filter(@nospecialize ft) = true

# TODO better to work only `finish!`
# TODO better to work only `finish!`, i.e. only work on `CodeInfo` (with static parameters)
function CC.finish(frame::InferenceState, analyzer::OptAnalyzer)
ret = @invoke CC.finish(frame::InferenceState, analyzer::AbstractAnalyzer)

Expand Down Expand Up @@ -272,20 +272,15 @@ function (::OptAnalysisPass)(::Type{CapturedVariableReport}, analyzer::OptAnalyz
return reported
end

function CC.finish!(analyzer::OptAnalyzer, frame::InferenceState)
caller = frame.result

function CC.finish!(analyzer::OptAnalyzer, caller::InferenceResult)
# get the source before running `finish!` to keep the reference to `OptimizationState`
src = caller.src

ret = @invoke CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState)

if popfirst!(analyzer.__analyze_frame)
ReportPass(analyzer)(OptimizationFailureReport, analyzer, caller)

if (@static VERSION ≥ v"1.9.0-DEV.1636" ?
(src isa OptimizationState{typeof(analyzer)}) :
(src isa OptimizationState)) # the compiler optimized it, analyze it
src.ir === nothing || CC.ir_to_codeinf!(src)
ReportPass(analyzer)(RuntimeDispatchReport, analyzer, caller, src)
elseif (@static JET_DEV_MODE ? true : false)
if isa(src, CC.ConstAPI)
Expand All @@ -298,8 +293,7 @@ function CC.finish!(analyzer::OptAnalyzer, frame::InferenceState)
end
end
end

return ret
return @invoke CC.finish!(analyzer::AbstractAnalyzer, caller::InferenceResult)
end

# report optimization failure due to recursive calls, etc.
Expand Down
2 changes: 1 addition & 1 deletion test/abstractinterpret/test_inferenceerrorreport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ end
result = report_call(m.foo, (String,))
r = only(get_reports_with_test(result))
@test isa(r, UncaughtExceptionReport)
@test Any['(', 's', String, ')', ArgumentError] ⫇ r.sig._sig
@test_broken Any['(', 's', String, ')', ArgumentError] ⫇ r.sig._sig
end

sparams1(::Type{T}) where T = zero(T)
Expand Down