Skip to content

effects: taint overlay-ed method's :nonoverlayed effect bit #51078

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 11 additions & 20 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# At this point we are guaranteed to end up throwing on this path,
# which is all that's required for :consistent-cy. Of course, we don't
# know anything else about this statement.
effects = Effects(; consistent=ALWAYS_TRUE, nonoverlayed=!isoverlayed(method_table(interp)))
effects = Effects(; consistent=ALWAYS_TRUE)
return CallMeta(Any, effects, NoCallInfo())
end

Expand All @@ -28,7 +28,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
return CallMeta(Any, Effects(), NoCallInfo())
end

(; valid_worlds, applicable, info, nonoverlayed) = matches
(; valid_worlds, applicable, info) = matches
update_valid_age!(sv, valid_worlds)
napplicable = length(applicable)
rettype = Bottom
Expand All @@ -39,7 +39,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
const_results = Union{Nothing,ConstResult}[]
multiple_matches = napplicable > 1
fargs = arginfo.fargs
all_effects = Effects(EFFECTS_TOTAL; nonoverlayed)
all_effects = EFFECTS_TOTAL

𝕃ₚ = ipo_lattice(interp)
for i in 1:napplicable
Expand Down Expand Up @@ -205,7 +205,6 @@ struct MethodMatches
valid_worlds::WorldRange
mt::MethodTable
fullmatch::Bool
nonoverlayed::Bool
end
any_ambig(info::MethodMatchInfo) = info.results.ambig
any_ambig(m::MethodMatches) = any_ambig(m.info)
Expand All @@ -217,7 +216,6 @@ struct UnionSplitMethodMatches
valid_worlds::WorldRange
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
nonoverlayed::Bool
end
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)

Expand All @@ -233,19 +231,16 @@ function find_matching_methods(𝕃::AbstractLattice,
valid_worlds = WorldRange()
mts = MethodTable[]
fullmatches = Bool[]
nonoverlayed = true
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
mt = mt::MethodTable
result = findall(sig_n, method_table; limit = max_methods)
if result === nothing
matches = findall(sig_n, method_table; limit = max_methods)
if matches === nothing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
(; matches, overlayed) = result
nonoverlayed &= !overlayed
push!(infos, MethodMatchInfo(matches))
for m in matches
push!(applicable, m)
Expand All @@ -271,28 +266,25 @@ function find_matching_methods(𝕃::AbstractLattice,
UnionSplitInfo(infos),
valid_worlds,
mts,
fullmatches,
nonoverlayed)
fullmatches)
else
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
if mt === nothing
return FailedMethodMatch("Could not identify method table for call")
end
mt = mt::MethodTable
result = findall(atype, method_table; limit = max_methods)
if result === nothing
matches = findall(atype, method_table; limit = max_methods)
if matches === nothing
# this means too many methods matched
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
(; matches, overlayed) = result
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
return MethodMatches(matches.matches,
MethodMatchInfo(matches),
matches.valid_worlds,
mt,
fullmatch,
!overlayed)
fullmatch)
end
end

Expand Down Expand Up @@ -862,7 +854,7 @@ function concrete_eval_eligible(interp::AbstractInterpreter,
mi = result.edge
if mi !== nothing && is_foldable(effects)
if f !== nothing && is_all_const_arg(arginfo, #=start=#2)
if is_nonoverlayed(mi.def::Method) && (!isoverlayed(method_table(interp)) || is_nonoverlayed(effects))
if is_nonoverlayed(interp) || is_nonoverlayed(effects)
return :concrete_eval
end
# disable concrete-evaluation if this function call is tainted by some overlayed
Expand Down Expand Up @@ -1924,7 +1916,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
lookupsig = rewrap_unionall(Tuple{ft, unwrapped.parameters...}, types)::Type
nargtype = Tuple{ft, nargtype.parameters...}
argtype = Tuple{ft, argtype.parameters...}
match, valid_worlds, overlayed = findsup(lookupsig, method_table(interp))
match, valid_worlds = findsup(lookupsig, method_table(interp))
match === nothing && return CallMeta(Any, Effects(), NoCallInfo())
update_valid_age!(sv, valid_worlds)
method = match.method
Expand Down Expand Up @@ -1955,7 +1947,6 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
end
end
rt = from_interprocedural!(interp, rt, sv, arginfo, sig)
effects = Effects(effects; nonoverlayed = !overlayed)
info = InvokeCallInfo(match, const_result)
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
return CallMeta(rt, effects, info)
Expand Down
13 changes: 12 additions & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,11 @@ mutable struct InferenceState
ipo_effects = Effects(ipo_effects; effect_free = ALWAYS_FALSE)
end

restrict_abstract_call_sites = isa(linfo.def, Module)
if def isa Method
ipo_effects = Effects(ipo_effects; nonoverlayed=is_nonoverlayed(def))
end

restrict_abstract_call_sites = isa(def, Module)
@assert cache === :no || cache === :local || cache === :global
cached = cache === :global

Expand All @@ -314,6 +318,13 @@ mutable struct InferenceState
end
end

is_nonoverlayed(m::Method) = !isdefined(m, :external_mt)
is_nonoverlayed(interp::AbstractInterpreter) = !isoverlayed(method_table(interp))
isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
isoverlayed(::InternalMethodTable) = false
isoverlayed(::OverlayMethodTable) = true
isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)

is_inferred(sv::InferenceState) = is_inferred(sv.result)
is_inferred(result::InferenceResult) = result.result !== nothing

Expand Down
51 changes: 16 additions & 35 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ function iterate(result::MethodLookupResult, args...)
end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch

struct MethodMatchResult
matches::MethodLookupResult
overlayed::Bool
end

"""
struct InternalMethodTable <: MethodTableView

Expand Down Expand Up @@ -55,47 +50,42 @@ Overlays another method table view with an additional local fast path cache that
can respond to repeated, identical queries faster than the original method table.
"""
struct CachedMethodTable{T<:MethodTableView} <: MethodTableView
cache::IdDict{MethodMatchKey, Union{Nothing,MethodMatchResult}}
cache::IdDict{MethodMatchKey, Union{Nothing,MethodLookupResult}}
table::T
end
CachedMethodTable(table::T) where T = CachedMethodTable{T}(IdDict{MethodMatchKey, Union{Nothing,MethodMatchResult}}(), table)
CachedMethodTable(table::T) where T = CachedMethodTable{T}(IdDict{MethodMatchKey, Union{Nothing,MethodLookupResult}}(), table)

"""
findall(sig::Type, view::MethodTableView; limit::Int=-1) ->
MethodMatchResult(matches::MethodLookupResult, overlayed::Bool) or nothing
matches::MethodLookupResult or nothing

Find all methods in the given method table `view` that are applicable to the given signature `sig`.
If no applicable methods are found, an empty result is returned.
If the number of applicable methods exceeded the specified `limit`, `nothing` is returned.
Note that the default setting `limit=-1` does not limit the number of applicable methods.
`overlayed` indicates if any of the matching methods comes from an overlayed method table.
"""
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=-1)
result = _findall(sig, nothing, table.world, limit)
result === nothing && return nothing
return MethodMatchResult(result, false)
end
findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=-1) =
_findall(sig, nothing, table.world, limit)

function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=-1)
result = _findall(sig, table.mt, table.world, limit)
result === nothing && return nothing
nr = length(result)
if nr ≥ 1 && result[nr].fully_covers
# no need to fall back to the internal method table
return MethodMatchResult(result, true)
return result
end
# fall back to the internal method table
fallback_result = _findall(sig, nothing, table.world, limit)
fallback_result === nothing && return nothing
# merge the fallback match results with the internal method table
return MethodMatchResult(
MethodLookupResult(
vcat(result.matches, fallback_result.matches),
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig),
!isempty(result))
return MethodLookupResult(
vcat(result.matches, fallback_result.matches),
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig)
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt, limit::Int)
Expand Down Expand Up @@ -138,21 +128,19 @@ In both cases `nothing` is returned.

`overlayed` indicates if any of the matching methods comes from an overlayed method table.
"""
function findsup(@nospecialize(sig::Type), table::InternalMethodTable)
return (_findsup(sig, nothing, table.world)..., false)
end
findsup(@nospecialize(sig::Type), table::InternalMethodTable) =
_findsup(sig, nothing, table.world)

function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
match, valid_worlds = _findsup(sig, table.mt, table.world)
match !== nothing && return match, valid_worlds, true
match !== nothing && return match, valid_worlds
# fall back to the internal method table
fallback_match, fallback_valid_worlds = _findsup(sig, nothing, table.world)
return (
fallback_match,
WorldRange(
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
min(valid_worlds.max_world, fallback_valid_worlds.max_world)),
false)
min(valid_worlds.max_world, fallback_valid_worlds.max_world)))
end

function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt)
Expand All @@ -166,10 +154,3 @@ end

# This query is not cached
findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table)

isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
isoverlayed(::InternalMethodTable) = false
isoverlayed(::OverlayMethodTable) = true
isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)
isoverlayed(m::Method) = isdefined(m, :external_mt)
is_nonoverlayed(m::Method) = !isoverlayed(m)
2 changes: 1 addition & 1 deletion base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function concrete_eval_invoke(interp::AbstractInterpreter,
argtypes === nothing && return Pair{Any,Bool}(Bottom, false)
effects = decode_effects(code.ipo_purity_bits)
if (is_foldable(effects) && is_all_const_arg(argtypes, #=start=#1) &&
is_nonoverlayed(effects) && is_nonoverlayed(mi.def::Method))
(is_nonoverlayed(interp) || is_nonoverlayed(effects)))
args = collect_const_args(argtypes, #=start=#1)
value = let world = get_world_counter(interp)
try
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2763,7 +2763,7 @@ function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv
if !isa(mt, MethodTable)
return CallMeta(Bool, EFFECTS_THROWS, NoCallInfo())
end
match, valid_worlds, overlayed = findsup(types, method_table(interp))
match, valid_worlds = findsup(types, method_table(interp))
update_valid_age!(sv, valid_worlds)
if match === nothing
rt = Const(false)
Expand Down
5 changes: 2 additions & 3 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1633,13 +1633,12 @@ function infer_effects(@nospecialize(f), @nospecialize(types=default_tt(f));
Core.Compiler.ArgInfo(nothing, argtypes), rt)
end
tt = signature_type(f, types)
result = Core.Compiler.findall(tt, Core.Compiler.method_table(interp))
if result === nothing
matches = Core.Compiler.findall(tt, Core.Compiler.method_table(interp))
if matches === nothing
# unanalyzable call, i.e. the interpreter world might be newer than the world where
# the `f` is defined, return the unknown effects
return Core.Compiler.Effects()
end
(; matches) = result
effects = Core.Compiler.EFFECTS_TOTAL
if matches.ambig || !any(match::Core.MethodMatch->match.fully_covers, matches.matches)
# account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
Expand Down
3 changes: 2 additions & 1 deletion test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ callstrange(::Float64) = strangesin(x)
callstrange(::Nothing) = Core.compilerbarrier(:type, nothing) # trigger inference bail out
callstrange_entry(x) = callstrange(x) # needs to be defined here because of world age
let interp = MTOverlayInterp(Set{Any}())
matches = Core.Compiler.findall(Tuple{typeof(callstrange),Any}, Core.Compiler.method_table(interp)).matches
matches = Core.Compiler.findall(Tuple{typeof(callstrange),Any}, Core.Compiler.method_table(interp))
@test matches !== nothing
@test Core.Compiler.length(matches) == 2
if Core.Compiler.getindex(matches, 1).method == which(callstrange, (Nothing,))
@test Base.infer_effects(callstrange_entry, (Any,); interp) |> !Core.Compiler.is_nonoverlayed
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/datastructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Test
sig = Tuple{typeof(*), Any, Any}
result1 = Core.Compiler.findall(sig, table; limit=-1)
result2 = Core.Compiler.findall(sig, table; limit=Core.Compiler.InferenceParams().max_methods)
@test result1 !== nothing && !Core.Compiler.isempty(result1.matches)
@test result1 !== nothing && !Core.Compiler.isempty(result1)
@test result2 === nothing
end

Expand Down