Skip to content

Commit

Permalink
some refactoring of abstract_call (#33774)
Browse files Browse the repository at this point in the history
- rename abstract_call to abstract_call_known
- rename abstract_eval_call to abstract_call
- check for pure_eval_call later, to try to avoid redundant work/code
- don't generate argument type lists with Vararg before the end; then
  we don't need to check for them
- reuse abstract_call more instead of duplicating some of its code
- remove some unused code
  • Loading branch information
JeffBezanson authored and KristofferC committed Apr 11, 2020
1 parent e18ddfe commit d2d0ba3
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 103 deletions.
172 changes: 80 additions & 92 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
seen = 0 # number of signatures actually inferred
istoplevel = sv.linfo.def isa Module
multiple_matches = napplicable > 1

if f !== nothing && napplicable == 1 && is_method_pure(applicable[1][3], applicable[1][1], applicable[1][2])
val = pure_eval_call(f, argtypes)
if val !== false
return val
end
end

for i in 1:napplicable
match = applicable[i]::SimpleVector
method = match[3]::Method
Expand Down Expand Up @@ -506,7 +514,7 @@ function abstract_iteration(@nospecialize(itft), @nospecialize(itertype), vtypes
else
return Any[Vararg{Any}]
end
stateordonet = abstract_call(iteratef, nothing, Any[itft, itertype], vtypes, sv)
stateordonet = abstract_call_known(iteratef, nothing, Any[itft, itertype], vtypes, sv)
# Return Bottom if this is not an iterator.
# WARNING: Changes to the iteration protocol must be reflected here,
# this is not just an optimization.
Expand All @@ -525,7 +533,7 @@ function abstract_iteration(@nospecialize(itft), @nospecialize(itertype), vtypes
valtype = stateordonet.parameters[1]
statetype = stateordonet.parameters[2]
push!(ret, valtype)
stateordonet = abstract_call(iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
stateordonet = abstract_call_known(iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
stateordonet = widenconst(stateordonet)
end
if stateordonet === Nothing
Expand All @@ -542,7 +550,7 @@ function abstract_iteration(@nospecialize(itft), @nospecialize(itertype), vtypes
end
valtype = tmerge(valtype, nounion.parameters[1])
statetype = tmerge(statetype, nounion.parameters[2])
stateordonet = abstract_call(iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
stateordonet = abstract_call_known(iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
stateordonet = widenconst(stateordonet)
end
push!(ret, Vararg{valtype})
Expand Down Expand Up @@ -584,14 +592,16 @@ function abstract_apply(@nospecialize(itft), @nospecialize(aft), aargtypes::Vect
ctypes = ctypes´
end
for ct in ctypes
if isa(aft, Const)
rt = abstract_call(aft.val, nothing, ct, vtypes, sv, max_methods)
elseif isconstType(aft)
rt = abstract_call(aft.parameters[1], nothing, ct, vtypes, sv, max_methods)
else
astype = argtypes_to_type(ct)
rt = abstract_call_gf_by_type(nothing, ct, astype, sv, max_methods)
lct = length(ct)
# truncate argument list at the first Vararg
for i = 1:lct-1
if isvarargtype(ct[i])
ct[i] = tuple_tail_elem(ct[i], ct[(i+1):lct])
resize!(ct, i)
break
end
end
rt = abstract_call(nothing, ct, vtypes, sv, max_methods)
res = tmerge(res, rt)
if res === Any
break
Expand All @@ -600,33 +610,24 @@ function abstract_apply(@nospecialize(itft), @nospecialize(aft), aargtypes::Vect
return res
end

function pure_eval_call(@nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState)
for i = 2:length(argtypes)
a = widenconditional(argtypes[i])
if !(isa(a, Const) || isconstType(a))
return false
end
end

min_valid = UInt[typemin(UInt)]
max_valid = UInt[typemax(UInt)]
meth = _methods_by_ftype(atype, 1, sv.params.world, min_valid, max_valid)
if meth === false || length(meth) != 1
return false
end
meth = meth[1]::SimpleVector
sig = meth[1]::DataType
sparams = meth[2]::SimpleVector
method = meth[3]::Method

function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector)
if isdefined(method, :generator)
method.generator.expand_early || return false
mi = specialize_method(method, sig, sparams, false)
isa(mi, MethodInstance) || return false
staged = get_staged(mi)
(staged isa CodeInfo && (staged::CodeInfo).pure) || return false
elseif !method.pure
return false
return true
end
return method.pure
end

function pure_eval_call(@nospecialize(f), argtypes::Vector{Any})
for i = 2:length(argtypes)
a = widenconditional(argtypes[i])
if !(isa(a, Const) || isconstType(a))
return false
end
end

args = Any[ (a = widenconditional(argtypes[i]); isa(a, Const) ? a.val : a.parameters[1]) for i in 2:length(argtypes) ]
Expand Down Expand Up @@ -656,27 +657,21 @@ function argtype_tail(argtypes::Vector{Any}, i::Int)
return argtypes[i:n]
end

function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState, max_methods = sv.params.MAX_METHODS)
if f === _apply
ft = argtype_by_index(argtypes, 2)
ft === Bottom && return Bottom
return abstract_apply(nothing, ft, argtype_tail(argtypes, 3), vtypes, sv, max_methods)
elseif f === _apply_iterate
itft = argtype_by_index(argtypes, 2)
ft = argtype_by_index(argtypes, 3)
(itft === Bottom || ft === Bottom) && return Bottom
return abstract_apply(itft, ft, argtype_tail(argtypes, 4), vtypes, sv, max_methods)
end

# call where the function is known exactly
function abstract_call_known(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState, max_methods = sv.params.MAX_METHODS)
la = length(argtypes)
for i = 2:(la - 1)
if isvarargtype(argtypes[i])
return Any
end
end

if isa(f, Builtin) || isa(f, IntrinsicFunction)
if f === ifelse && fargs isa Vector{Any} && length(argtypes) == 4 && argtypes[2] isa Conditional
if isa(f, Builtin)
if f === _apply
ft = argtype_by_index(argtypes, 2)
ft === Bottom && return Bottom
return abstract_apply(nothing, ft, argtype_tail(argtypes, 3), vtypes, sv, max_methods)
elseif f === _apply_iterate
itft = argtype_by_index(argtypes, 2)
ft = argtype_by_index(argtypes, 3)
(itft === Bottom || ft === Bottom) && return Bottom
return abstract_apply(itft, ft, argtype_tail(argtypes, 4), vtypes, sv, max_methods)
elseif f === ifelse && fargs isa Vector{Any} && la == 4 && argtypes[2] isa Conditional
# try to simulate this as a real conditional (`cnd ? x : y`), so that the penalty for using `ifelse` instead isn't too high
cnd = argtypes[2]::Conditional
tx = argtypes[3]
Expand All @@ -692,7 +687,7 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt
return tmerge(tx, ty)
end
rt = builtin_tfunction(f, argtypes[2:end], sv)
if f === getfield && isa(fargs, Vector{Any}) && length(argtypes) == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] Tuple
if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] Tuple
cti = precise_container_type(nothing, argtypes[2], vtypes, sv)
idx = argtypes[3].val
if 1 <= idx <= length(cti)
Expand Down Expand Up @@ -767,7 +762,7 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt
end
return isa(rt, TypeVar) ? rt.ub : rt
elseif f === Core.kwfunc
if length(argtypes) == 2
if la == 2
ft = widenconst(argtypes[2])
if isa(ft, DataType) && isdefined(ft.name, :mt) && isdefined(ft.name.mt, :kwsorter)
return Const(ft.name.mt.kwsorter)
Expand All @@ -777,19 +772,19 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt
elseif f === TypeVar
# Manually look through the definition of TypeVar to
# make sure to be able to get `PartialTypeVar`s out.
(length(argtypes) < 2 || length(argtypes) > 4) && return Union{}
(la < 2 || la > 4) && return Union{}
n = argtypes[2]
ub_var = Const(Any)
lb_var = Const(Union{})
if length(argtypes) == 4
if la == 4
ub_var = argtypes[4]
lb_var = argtypes[3]
elseif length(argtypes) == 3
elseif la == 3
ub_var = argtypes[3]
end
return typevar_tfunc(n, lb_var, ub_var)
elseif f === UnionAll
if length(argtypes) == 3
if la == 3
canconst = true
if isa(argtypes[3], Const)
body = argtypes[3].val
Expand Down Expand Up @@ -826,23 +821,24 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt
if rt_rt !== nothing
return rt_rt
end
elseif length(argtypes) == 2 && istopfunction(f, :!)
return Type
elseif la == 2 && istopfunction(f, :!)
# handle Conditional propagation through !Bool
aty = argtypes[2]
if isa(aty, Conditional)
abstract_call_gf_by_type(f, Any[Const(f), Bool], Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
return Conditional(aty.var, aty.elsetype, aty.vtype)
end
elseif length(argtypes) == 3 && istopfunction(f, :!==)
elseif la == 3 && istopfunction(f, :!==)
# mark !== as exactly a negated call to ===
rty = abstract_call((===), fargs, argtypes, vtypes, sv)
rty = abstract_call_known((===), fargs, argtypes, vtypes, sv)
if isa(rty, Conditional)
return Conditional(rty.var, rty.elsetype, rty.vtype) # swap if-else
elseif isa(rty, Const)
return Const(rty.val === false)
end
return rty
elseif length(argtypes) == 3 && istopfunction(f, :(>:))
elseif la == 3 && istopfunction(f, :(>:))
# mark issupertype as a exact alias for issubtype
# swap T1 and T2 arguments and call <:
if length(fargs) == 3
Expand All @@ -851,42 +847,36 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt
fargs = nothing
end
argtypes = Any[typeof(<:), argtypes[3], argtypes[2]]
rty = abstract_call(<:, fargs, argtypes, vtypes, sv)
rty = abstract_call_known(<:, fargs, argtypes, vtypes, sv)
return rty
elseif length(argtypes) == 2 && isa(argtypes[2], Const) && isa(argtypes[2].val, SimpleVector) && istopfunction(f, :length)
elseif la == 2 && isa(argtypes[2], Const) && isa(argtypes[2].val, SimpleVector) && istopfunction(f, :length)
# mark length(::SimpleVector) as @pure
return Const(length(argtypes[2].val))
elseif length(argtypes) == 3 && isa(argtypes[2], Const) && isa(argtypes[3], Const) &&
elseif la == 3 && isa(argtypes[2], Const) && isa(argtypes[3], Const) &&
isa(argtypes[2].val, SimpleVector) && isa(argtypes[3].val, Int) && istopfunction(f, :getindex)
# mark getindex(::SimpleVector, i::Int) as @pure
svecval = argtypes[2].val::SimpleVector
idx = argtypes[3].val::Int
if 1 <= idx <= length(svecval) && isassigned(svecval, idx)
return Const(getindex(svecval, idx))
end
elseif length(argtypes) == 2 && istopfunction(f, :typename)
elseif la == 2 && istopfunction(f, :typename)
return typename_static(argtypes[2])
elseif max_methods > 1 && istopfunction(f, :copyto!)
max_methods = 1
elseif la == 3 && istopfunction(f, :typejoin)
val = pure_eval_call(f, argtypes)
return val === false ? Type : val
end

atype = argtypes_to_type(argtypes)
t = pure_eval_call(f, argtypes, atype, sv)
t !== false && return t

if istopfunction(f, :typejoin) || is_return_type(f)
return Type # don't try to infer these function edges directly -- it won't actually come up with anything useful
end

return abstract_call_gf_by_type(f, argtypes, atype, sv, max_methods)
end

# wrapper around `abstract_call` for first computing if `f` is available
function abstract_eval_call(fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
# call where the function is any lattice element
function abstract_call(fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState,
max_methods = sv.params.MAX_METHODS)
#print("call ", e.args[1], argtypes, "\n\n")
for x in argtypes
x === Bottom && return Bottom
end
ft = argtypes[1]
if isa(ft, Const)
f = ft.val
Expand All @@ -895,19 +885,14 @@ function abstract_eval_call(fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{
elseif isa(ft, DataType) && isdefined(ft, :instance)
f = ft.instance
else
for i = 2:(length(argtypes) - 1)
if isvarargtype(argtypes[i])
return Any
end
end
# non-constant function, but the number of arguments is known
# and the ft is not a Builtin or IntrinsicFunction
if typeintersect(widenconst(ft), Builtin) != Union{}
return Any
end
return abstract_call_gf_by_type(nothing, argtypes, argtypes_to_type(argtypes), sv)
return abstract_call_gf_by_type(nothing, argtypes, argtypes_to_type(argtypes), sv, max_methods)
end
return abstract_call(f, fargs, argtypes, vtypes, sv)
return abstract_call_known(f, fargs, argtypes, vtypes, sv, max_methods)
end

function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
Expand Down Expand Up @@ -956,7 +941,7 @@ function abstract_eval_cfunction(e::Expr, vtypes::VarTable, sv::InferenceState)
# this may be the wrong world for the call,
# but some of the result is likely to be valid anyways
# and that may help generate better codegen
abstract_eval_call(nothing, at, vtypes, sv)
abstract_call(nothing, at, vtypes, sv)
nothing
end

Expand All @@ -976,8 +961,17 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
end
e = e::Expr
if e.head === :call
argtypes = Any[ abstract_eval(a, vtypes, sv) for a in e.args ]
t = abstract_eval_call(e.args, argtypes, vtypes, sv)
ea = e.args
n = length(ea)
argtypes = Vector{Any}(undef, n)
@inbounds for i = 1:n
ai = abstract_eval(ea[i], vtypes, sv)
if ai === Bottom
return Bottom
end
argtypes[i] = ai
end
t = abstract_call(ea, argtypes, vtypes, sv)
elseif e.head === :new
t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1]
if isconcretetype(t) && !t.mutable
Expand Down Expand Up @@ -1103,12 +1097,6 @@ function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo)
return typ
end

# determine whether `ex` abstractly evals to constant `c`
function abstract_evals_to_constant(@nospecialize(ex), @nospecialize(c), vtypes::VarTable, sv::InferenceState)
av = abstract_eval(ex, vtypes, sv)
return isa(av, Const) && av.val === c
end

# make as much progress on `frame` as possible (without handling cycles)
function typeinf_local(frame::InferenceState)
@assert !frame.inferred
Expand Down
9 changes: 1 addition & 8 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1441,14 +1441,7 @@ function return_type_tfunc(argtypes::Vector{Any}, vtypes::VarTable, sv::Inferenc
if contains_is(argtypes_vec, Union{})
return Const(Union{})
end
astype = argtypes_to_type(argtypes_vec)
if isa(aft, Const)
rt = abstract_call(aft.val, nothing, argtypes_vec, vtypes, sv, -1)
elseif isconstType(aft)
rt = abstract_call(aft.parameters[1], nothing, argtypes_vec, vtypes, sv, -1)
else
rt = abstract_call_gf_by_type(nothing, argtypes_vec, astype, sv, -1)
end
rt = abstract_call(nothing, argtypes_vec, vtypes, sv, -1)
if isa(rt, Const)
# output was computed to be constant
return Const(typeof(rt.val))
Expand Down
5 changes: 2 additions & 3 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,10 @@ _typename(union::UnionAll) = _typename(union.body)
_typename(a::DataType) = Const(a.name)

function tuple_tail_elem(@nospecialize(init), ct::Vector{Any})
# FIXME: this is broken: it violates subtyping relations and creates invalid types with free typevars
tmerge_maybe_vararg(@nospecialize(a), @nospecialize(b)) = tmerge(a, tvar_extent(unwrapva(b)))
t = init
for x in ct
t = tmerge_maybe_vararg(t, x)
# FIXME: this is broken: it violates subtyping relations and creates invalid types with free typevars
t = tmerge(t, tvar_extent(unwrapva(x)))
end
return Vararg{widenconst(t)}
end
Expand Down

0 comments on commit d2d0ba3

Please sign in to comment.