Skip to content

Commit

Permalink
AbstractInterpreter: enable selective pure/concrete eval for extern…
Browse files Browse the repository at this point in the history
…al `AbstractInterpreter` with overlayed method table

Built on top of #44511 and #44561, and solves <JuliaGPU/GPUCompiler.jl#309>.
This commit allows external `AbstractInterpreter` to selectively use
pure/concrete evals even if it uses an overlayed method table.
More specifically, such `AbstractInterpreter` can use pure/concrete evals
as far as any callees used in a call in question doesn't come from the
overlayed method table:
```julia
@test Base.return_types((), MTOverlayInterp()) do
    isbitstype(Int) ? nothing : missing
end == Any[Nothing]
Base.@assume_effects :terminates_globally function issue41694(x)
    res = 1
    1 < x < 20 || throw("bad")
    while x > 1
        res *= x
        x -= 1
    end
    return res
end
@test Base.return_types((), MTOverlayInterp()) do
    issue41694(3) == 6 ? nothing : missing
end == Any[Nothing]
```

In order to check if a call is tainted by any overlayed call, our effect
system now additionally tracks `overlayed::Bool` property. This effect
property is required to prevents concrete-eval in the following kind of situation:
```julia
strangesin(x) = sin(x)
@overlay OverlayedMT strangesin(x::Float64) = iszero(x) ? nothing : cos(x)
Base.@assume_effects :total totalcall(f, args...) = f(args...)
@test Base.return_types(; interp=MTOverlayInterp()) do
    # we need to disable partial pure/concrete evaluation when tainted by any overlayed call
    if totalcall(strangesin, 1.0) == cos(1.0)
        return nothing
    else
        return missing
    end
end |> only === Nothing
```
  • Loading branch information
aviatesk committed Mar 14, 2022
1 parent f5f7445 commit 7ae3491
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 85 deletions.
74 changes: 45 additions & 29 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# At this point we are guaranteed to end up throwing on this path,
# which is all that's required for :consistent-cy. Of course, we don't
# know anything else about this statement.
tristate_merge!(sv, Effects(Effects(), consistent=ALWAYS_TRUE))
tristate_merge!(sv, Effects(Effects(true), consistent=ALWAYS_TRUE))
return CallMeta(Any, false)
end

argtypes = arginfo.argtypes
matches = find_matching_methods(argtypes, atype, method_table(interp), InferenceParams(interp).MAX_UNION_SPLITTING, max_methods)
if isa(matches, FailedMethodMatch)
add_remark!(interp, sv, matches.reason)
tristate_merge!(sv, Effects())
tristate_merge!(sv, Effects(true))
return CallMeta(Any, false)
end

Expand All @@ -72,6 +72,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
any_const_result = false
const_results = Union{InferenceResult,Nothing,ConstResult}[]
multiple_matches = napplicable > 1
if matches.overlayed
# currently we don't have a good way to execute the overlayed method definition,
# so we should give up pure/concrete eval when any of the matched methods is overlayed
f = nothing
tristate_merge!(sv, Effects(true))
end

val = pure_eval_call(interp, f, applicable, arginfo, sv)
val !== nothing && return CallMeta(val, MethodResultPure(info)) # TODO: add some sort of edge(s)
Expand Down Expand Up @@ -102,7 +108,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv)
const_call_result = abstract_call_method_with_const_args(interp, result,
f, this_arginfo, match, sv)
effects = result.edge_effects
const_result = nothing
if const_call_result !== nothing
Expand Down Expand Up @@ -144,7 +151,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv)
const_call_result = abstract_call_method_with_const_args(interp, result,
f, this_arginfo, match, sv)
effects = result.edge_effects
const_result = nothing
if const_call_result !== nothing
Expand Down Expand Up @@ -189,7 +197,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end

if seen != napplicable
tristate_merge!(sv, Effects())
tristate_merge!(sv, Effects(true))
elseif isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) :
(!_all(b->b, matches.fullmatches) || any_ambig(matches))
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
Expand Down Expand Up @@ -228,6 +236,7 @@ struct MethodMatches
valid_worlds::WorldRange
mt::Core.MethodTable
fullmatch::Bool
overlayed::Bool
end
any_ambig(info::MethodMatchInfo) = info.results.ambig
any_ambig(m::MethodMatches) = any_ambig(m.info)
Expand All @@ -239,6 +248,7 @@ struct UnionSplitMethodMatches
valid_worlds::WorldRange
mts::Vector{Core.MethodTable}
fullmatches::Vector{Bool}
overlayed::Bool
end
any_ambig(m::UnionSplitMethodMatches) = _any(any_ambig, m.info.matches)

Expand All @@ -253,16 +263,19 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
valid_worlds = WorldRange()
mts = Core.MethodTable[]
fullmatches = Bool[]
overlayed = false
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
mt = mt::Core.MethodTable
matches = findall(sig_n, method_table; limit = max_methods)
if matches === missing
result = findall(sig_n, method_table; limit = max_methods)
if result === missing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
matches, overlayedᵢ = result
overlayed |= overlayedᵢ
push!(infos, MethodMatchInfo(matches))
for m in matches
push!(applicable, m)
Expand All @@ -288,25 +301,28 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
UnionSplitInfo(infos),
valid_worlds,
mts,
fullmatches)
fullmatches,
overlayed)
else
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
if mt === nothing
return FailedMethodMatch("Could not identify method table for call")
end
mt = mt::Core.MethodTable
matches = findall(atype, method_table; limit = max_methods)
if matches === missing
result = findall(atype, method_table; limit = max_methods)
if result === missing
# this means too many methods matched
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
matches, overlayed = result
fullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
return MethodMatches(matches.matches,
MethodMatchInfo(matches),
matches.valid_worlds,
mt,
fullmatch)
fullmatch,
overlayed)
end
end

Expand Down Expand Up @@ -446,7 +462,7 @@ const RECURSION_MSG = "Bounded recursion detected. Call was widened to force con
function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base
add_remark!(interp, sv, "Refusing to infer into `depwarn`")
return MethodCallResult(Any, false, false, nothing, Effects())
return MethodCallResult(Any, false, false, nothing, Effects(true))
end
topmost = nothing
# Limit argument type tuple growth of functions:
Expand Down Expand Up @@ -515,7 +531,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
# we have a self-cycle in the call-graph, but not in the inference graph (typically):
# break this edge now (before we record it) by returning early
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
return MethodCallResult(Any, true, true, nothing, Effects())
return MethodCallResult(Any, true, true, nothing, Effects(true))
end
topmost = nothing
edgecycle = true
Expand Down Expand Up @@ -564,7 +580,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
# since it's very unlikely that we'll try to inline this,
# or want make an invoke edge to its calling convention return type.
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
return MethodCallResult(Any, true, true, nothing, Effects())
return MethodCallResult(Any, true, true, nothing, Effects(true))
end
add_remark!(interp, sv, RECURSION_MSG)
topmost = topmost::InferenceState
Expand Down Expand Up @@ -640,8 +656,7 @@ end

function pure_eval_eligible(interp::AbstractInterpreter,
@nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo, sv::InferenceState)
return !isoverlayed(method_table(interp)) &&
f !== nothing &&
return f !== nothing &&
length(applicable) == 1 &&
is_method_pure(applicable[1]::MethodMatch) &&
is_all_const_arg(arginfo)
Expand Down Expand Up @@ -677,10 +692,10 @@ end

function concrete_eval_eligible(interp::AbstractInterpreter,
@nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState)
return !isoverlayed(method_table(interp)) &&
f !== nothing &&
return f !== nothing &&
result.edge !== nothing &&
is_total_or_error(result.edge_effects) &&
!result.edge_effects.overlayed &&
is_all_const_arg(arginfo)
end

Expand Down Expand Up @@ -1200,7 +1215,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::
if !isa(aft, Const) && !isa(aft, PartialOpaque) && (!isType(aftw) || has_free_typevars(aftw))
if !isconcretetype(aftw) || (aftw <: Builtin)
add_remark!(interp, sv, "Core._apply_iterate called on a function of a non-concrete type")
tristate_merge!(sv, Effects())
tristate_merge!(sv, Effects(true))
# bail now, since it seems unlikely that abstract_call will be able to do any better after splitting
# this also ensures we don't call abstract_call_gf_by_type below on an IntrinsicFunction or Builtin
return CallMeta(Any, false)
Expand Down Expand Up @@ -1477,7 +1492,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)::Type
nargtype = Tuple{ft, nargtype.parameters...}
argtype = Tuple{ft, argtype.parameters...}
match, valid_worlds = findsup(types, method_table(interp))
match, valid_worlds, overlayed = findsup(types, method_table(interp))
match === nothing && return CallMeta(Any, false)
update_valid_age!(sv, valid_worlds)
method = match.method
Expand All @@ -1495,7 +1510,8 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
# t, a = ti.parameters[i], argtypes′[i]
# argtypes′[i] = t ⊑ a ? t : a
# end
const_call_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), arginfo, match, sv)
const_call_result = abstract_call_method_with_const_args(interp, result,
overlayed ? nothing : singleton_type(ft′), arginfo, match, sv)
const_result = nothing
if const_call_result !== nothing
if const_call_result.rt rt
Expand Down Expand Up @@ -1528,20 +1544,20 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
if call.rt === Bottom
tristate_merge!(sv, Effects(EFFECTS_TOTAL, nothrow=ALWAYS_FALSE))
else
tristate_merge!(sv, Effects())
tristate_merge!(sv, Effects(true))
end
end
return call
elseif f === modifyfield!
tristate_merge!(sv, Effects()) # TODO
tristate_merge!(sv, Effects(true)) # TODO
return abstract_modifyfield!(interp, argtypes, sv)
end
rt = abstract_call_builtin(interp, f, arginfo, sv, max_methods)
tristate_merge!(sv, builtin_effects(f, argtypes, rt))
return CallMeta(rt, false)
elseif isa(f, Core.OpaqueClosure)
# calling an OpaqueClosure about which we have no information returns no information
tristate_merge!(sv, Effects())
tristate_merge!(sv, Effects(true))
return CallMeta(Any, false)
elseif f === Core.kwfunc
if la == 2
Expand Down Expand Up @@ -1643,8 +1659,8 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
const_result = nothing
if !result.edgecycle
const_call_result = abstract_call_method_with_const_args(interp, result, nothing,
arginfo, match, sv)
const_call_result = abstract_call_method_with_const_args(interp, result,
nothing, arginfo, match, sv)
if const_call_result !== nothing
if const_call_result.rt rt
(; rt, const_result) = const_call_result
Expand Down Expand Up @@ -1674,16 +1690,16 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo,
if isa(ft, PartialOpaque)
newargtypes = copy(argtypes)
newargtypes[1] = ft.env
tristate_merge!(sv, Effects()) # TODO
tristate_merge!(sv, Effects(true)) # TODO
return abstract_call_opaque_closure(interp, ft, ArgInfo(arginfo.fargs, newargtypes), sv)
elseif (uft = unwrap_unionall(widenconst(ft)); isa(uft, DataType) && uft.name === typename(Core.OpaqueClosure))
tristate_merge!(sv, Effects()) # TODO
tristate_merge!(sv, Effects(true)) # TODO
return CallMeta(rewrap_unionall((uft::DataType).parameters[2], widenconst(ft)), false)
elseif f === nothing
# non-constant function, but the number of arguments is known
# and the ft is not a Builtin or IntrinsicFunction
if hasintersect(widenconst(ft), Union{Builtin, Core.OpaqueClosure})
tristate_merge!(sv, Effects())
tristate_merge!(sv, Effects(true))
add_remark!(interp, sv, "Could not identify method table for call")
return CallMeta(Any, false)
end
Expand Down
64 changes: 35 additions & 29 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,18 @@ end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch

"""
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) -> MethodLookupResult or missing
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) ->
(matches::MethodLookupResult, overlayed::Bool) or missing
Find all methods in the given method table `view` that are applicable to the
given signature `sig`. If no applicable methods are found, an empty result is
returned. If the number of applicable methods exceeded the specified limit,
`missing` is returned.
Find all methods in the given method table `view` that are applicable to the given signature `sig`.
If no applicable methods are found, an empty result is returned.
If the number of applicable methods exceeded the specified limit, `missing` is returned.
`overlayed` indicates if any matching method is defined in an overlayed method table.
"""
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=Int(typemax(Int32)))
return _findall(sig, nothing, table.world, limit)
result = _findall(sig, nothing, table.world, limit)
result === missing && return missing
return result, false
end

function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=Int(typemax(Int32)))
Expand All @@ -57,7 +60,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
nr = length(result)
if nr 1 && result[nr].fully_covers
# no need to fall back to the internal method table
return result
return result, true
end
# fall back to the internal method table
fallback_result = _findall(sig, nothing, table.world, limit)
Expand All @@ -68,7 +71,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig)
result.ambig | fallback_result.ambig), !isempty(result)
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
Expand All @@ -83,31 +86,38 @@ function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
end

"""
findsup(sig::Type, view::MethodTableView) -> Tuple{MethodMatch, WorldRange} or nothing
Find the (unique) method `m` such that `sig <: m.sig`, while being more
specific than any other method with the same property. In other words, find
the method which is the least upper bound (supremum) under the specificity/subtype
relation of the queried `signature`. If `sig` is concrete, this is equivalent to
asking for the method that will be called given arguments whose types match the
given signature. This query is also used to implement `invoke`.
Such a method `m` need not exist. It is possible that no method is an
upper bound of `sig`, or it is possible that among the upper bounds, there
is no least element. In both cases `nothing` is returned.
findsup(sig::Type, view::MethodTableView) ->
(match::MethodMatch, valid_worlds::WorldRange, overlayed::Bool) or nothing
Find the (unique) method such that `sig <: match.method.sig`, while being more
specific than any other method with the same property. In other words, find the method
which is the least upper bound (supremum) under the specificity/subtype relation of
the queried `sig`nature. If `sig` is concrete, this is equivalent to asking for the method
that will be called given arguments whose types match the given signature.
Note that this query is also used to implement `invoke`.
Such a matching method `match` doesn't necessarily exist.
It is possible that no method is an upper bound of `sig`, or
it is possible that among the upper bounds, there is no least element.
In both cases `nothing` is returned.
`overlayed` indicates if the matching method is defined in an overlayed method table.
"""
function findsup(@nospecialize(sig::Type), table::InternalMethodTable)
return _findsup(sig, nothing, table.world)
return (_findsup(sig, nothing, table.world)..., false)
end

function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
match, valid_worlds = _findsup(sig, table.mt, table.world)
match !== nothing && return match, valid_worlds
match !== nothing && return match, valid_worlds, true
# fall back to the internal method table
fallback_match, fallback_valid_worlds = _findsup(sig, nothing, table.world)
return fallback_match, WorldRange(
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
min(valid_worlds.max_world, fallback_valid_worlds.max_world))
return (
fallback_match,
WorldRange(
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
min(valid_worlds.max_world, fallback_valid_worlds.max_world)),
false)
end

function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt)
Expand All @@ -118,7 +128,3 @@ function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
valid_worlds = WorldRange(min_valid[], max_valid[])
return match, valid_worlds
end

isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
isoverlayed(::InternalMethodTable) = false
isoverlayed(::OverlayMethodTable) = true
Loading

0 comments on commit 7ae3491

Please sign in to comment.