diff --git a/src/abstractinterpret/abstractanalyzer.jl b/src/abstractinterpret/abstractanalyzer.jl index 24cf96415..1a44e063a 100644 --- a/src/abstractinterpret/abstractanalyzer.jl +++ b/src/abstractinterpret/abstractanalyzer.jl @@ -130,7 +130,7 @@ mutable struct AnalyzerState # the temporal stash to keep track of the context of caller inference/optimization and # the caller itself, to which reconstructed cached reports will be appended - cache_target::Union{Nothing,Pair{Symbol,InferenceResult}} + cache_target::(@static VERSION ≥ v"1.11.0-DEV.1552" ? Nothing : Union{Nothing,Pair{Symbol,InferenceResult}}) ## abstract toplevel execution ## @@ -417,12 +417,12 @@ struct AnalysisCache end AnalysisCache() = AnalysisCache(IdDict{MethodInstance,CodeInstance}()) -Base.haskey(analysis_cache::AnalysisCache, mi::MethodInstance) = haskey(analysis_cache.cache, mi) -Base.get(analysis_cache::AnalysisCache, mi::MethodInstance, default) = get(analysis_cache.cache, mi, default) -Base.getindex(analysis_cache::AnalysisCache, mi::MethodInstance) = getindex(analysis_cache.cache, mi) -Base.setindex!(analysis_cache::AnalysisCache, ci::CodeInstance, mi::MethodInstance) = setindex!(analysis_cache.cache, ci, mi) -Base.delete!(analysis_cache::AnalysisCache, mi::MethodInstance) = delete!(analysis_cache.cache, mi) -Base.show(io::IO, analysis_cache::AnalysisCache) = print(io, typeof(analysis_cache), "(", length(analysis_cache.cache), " entries)") +# Base.haskey(analysis_cache::AnalysisCache, mi::MethodInstance) = haskey(analysis_cache.cache, mi) +# Base.get(analysis_cache::AnalysisCache, mi::MethodInstance, default) = get(analysis_cache.cache, mi, default) +# Base.getindex(analysis_cache::AnalysisCache, mi::MethodInstance) = getindex(analysis_cache.cache, mi) +# Base.setindex!(analysis_cache::AnalysisCache, ci::CodeInstance, mi::MethodInstance) = setindex!(analysis_cache.cache, ci, mi) +# Base.delete!(analysis_cache::AnalysisCache, mi::MethodInstance) = delete!(analysis_cache.cache, mi) +# Base.show(io::IO, analysis_cache::AnalysisCache) = print(io, typeof(analysis_cache), "(", length(analysis_cache.cache), " entries)") """ AnalysisCache(analyzer::AbstractAnalyzer) -> analysis_cache::AnalysisCache diff --git a/src/abstractinterpret/typeinfer.jl b/src/abstractinterpret/typeinfer.jl index c55c1ba0b..d092d1907 100644 --- a/src/abstractinterpret/typeinfer.jl +++ b/src/abstractinterpret/typeinfer.jl @@ -26,11 +26,15 @@ end function CC.const_prop_call(analyzer::AbstractAnalyzer, mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState, concrete_eval_result::Union{Nothing,CC.ConstCallResults}) + @static if VERSION < v"1.11.0-DEV.1552" set_cache_target!(analyzer, :const_prop_call => sv.result) + end const_result = @invoke CC.const_prop_call(analyzer::AbstractInterpreter, mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState, concrete_eval_result::Union{Nothing,CC.ConstCallResults}) + @static if VERSION < v"1.11.0-DEV.1552" @assert get_cache_target(analyzer) === nothing "invalid JET analysis state" + end if const_result !== nothing # successful constant prop', we need to update reports collect_callee_reports!(analyzer, sv) @@ -151,9 +155,25 @@ end # ------ @static if VERSION ≥ v"1.11.0-DEV.1552" + CC.cache_owner(analyzer::AbstractAnalyzer) = AnalysisCache(analyzer) + +function CC.return_cached_result(analyzer::AbstractAnalyzer, codeinst::CodeInstance, caller::InferenceState) + # cache hit, now we need to append cached reports associated with this `MethodInstance` + inferred = @atomic :monotonic codeinst.inferred + for cached in (inferred::CachedAnalysisResult).reports + restored = add_cached_report!(analyzer, caller.result, cached) + @static if JET_DEV_MODE + actual, expected = first(restored.vst).linfo, codeinst.def + @assert actual === expected "invalid global cache restoration, expected $expected but got $actual" + end + stash_report!(analyzer, restored) # should be updated in `abstract_call` (after exiting `typeinf_edge`) + end + return @invoke CC.return_cached_result(analyzer::AbstractInterpreter, codeinst::CodeInstance, caller::InferenceState) end +else # if VERSION ≥ v"1.11.0-DEV.1552" + function CC.code_cache(analyzer::AbstractAnalyzer) view = AbstractAnalyzerView(analyzer) worlds = WorldRange(get_inference_world(analyzer)) @@ -212,21 +232,6 @@ function CC.getindex(wvc::WorldView{<:AbstractAnalyzerView}, mi::MethodInstance) return codeinst::CodeInstance end -function CC.transform_result_for_cache(analyzer::AbstractAnalyzer, - linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) - cache = InferenceErrorReport[] - for report in get_any_reports(analyzer, result) - @static if JET_DEV_MODE - actual, expected = first(report.vst).linfo, linfo - @assert actual === expected "invalid global caching detected, expected $expected but got $actual" - end - cache_report!(cache, report) - end - inferred_result = @invoke transform_result_for_cache(analyzer::AbstractInterpreter, - linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) - return CachedAnalysisResult(inferred_result, cache) -end - function CC.setindex!(wvc::WorldView{<:AbstractAnalyzerView}, codeinst::CodeInstance, mi::MethodInstance) analysis_cache = AnalysisCache(wvc) @static if VERSION < v"1.11.0-DEV.1552" @@ -247,7 +252,7 @@ end function (callback::JETCallback)(replaced::MethodInstance, max_world::UInt32) delete!(callback.analysis_cache, replaced) end -else +else # if VERSION ≥ v"1.11.0-DEV.798" function add_jet_callback!(mi::MethodInstance, analysis_cache::AnalysisCache) callback = JETCallback(analysis_cache) if !isdefined(mi, :callbacks) @@ -274,11 +279,36 @@ function (callback::JETCallback)(replaced::MethodInstance, max_world::UInt32, end return nothing end -end +end # if VERSION ≥ v"1.11.0-DEV.798" + +end # if VERSION ≥ v"1.11.0-DEV.1552" # local # ----- +@static if VERSION ≥ v"1.11.0-DEV.1552" + +CC.get_inference_cache(analyzer::AbstractAnalyzer) = get_inf_cache(analyzer) + +function CC.return_cached_result(analyzer::AbstractAnalyzer, inf_result::InferenceResult, caller::InferenceState) + # as the analyzer uses the reports that are cached by the abstract-interpretation + # with the extended lattice elements, here we should throw-away the error reports + # that are collected during the previous non-constant abstract-interpretation + # (see the `CC.typeinf(::AbstractAnalyzer, ::InferenceState)` overload) + filter_lineages!(analyzer, caller.result, inf_result.linfo) + for cached in get_cached_reports(analyzer, inf_result) + restored = add_cached_report!(analyzer, caller.result, cached) + @static if JET_DEV_MODE + actual, expected = first(restored.vst).linfo, inf_result.linfo + @assert actual === expected "invalid local cache restoration, expected $expected but got $actual" + end + stash_report!(analyzer, restored) # should be updated in `abstract_call_method_with_const_args` + end + return @invoke CC.return_cached_result(analyzer::AbstractInterpreter, inf_result::InferenceResult, caller::InferenceState) +end + +else # if VERSION ≥ v"1.11.0-DEV.1552" + CC.get_inference_cache(analyzer::AbstractAnalyzer) = AbstractAnalyzerView(analyzer) function CC.cache_lookup(𝕃ᵢ::CC.AbstractLattice, mi::MethodInstance, given_argtypes::Argtypes, view::AbstractAnalyzerView) @@ -322,6 +352,8 @@ end CC.push!(view::AbstractAnalyzerView, inf_result::InferenceResult) = CC.push!(get_inf_cache(view.analyzer), inf_result) +end # if VERSION ≥ v"1.11.0-DEV.1552" + # main driver # =========== @@ -545,6 +577,21 @@ function CC.cache_result!(analyzer::AbstractAnalyzer, caller::InferenceResult) @invoke CC.cache_result!(analyzer::AbstractInterpreter, caller::InferenceResult) end +function CC.transform_result_for_cache(analyzer::AbstractAnalyzer, + linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) + cache = InferenceErrorReport[] + for report in get_any_reports(analyzer, result) + @static if JET_DEV_MODE + actual, expected = first(report.vst).linfo, linfo + @assert actual === expected "invalid global caching detected, expected $expected but got $actual" + end + cache_report!(cache, report) + end + inferred_result = @invoke transform_result_for_cache(analyzer::AbstractInterpreter, + linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) + return CachedAnalysisResult(inferred_result, cache) +end + # top-level bridge # ================