Skip to content

Commit d2d0ba3

Browse files
JeffBezansonKristofferC
authored andcommitted
some refactoring of abstract_call (#33774)
- rename abstract_call to abstract_call_known - rename abstract_eval_call to abstract_call - check for pure_eval_call later, to try to avoid redundant work/code - don't generate argument type lists with Vararg before the end; then we don't need to check for them - reuse abstract_call more instead of duplicating some of its code - remove some unused code
1 parent e18ddfe commit d2d0ba3

File tree

3 files changed

+83
-103
lines changed

3 files changed

+83
-103
lines changed

base/compiler/abstractinterpretation.jl

+80-92
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
6464
seen = 0 # number of signatures actually inferred
6565
istoplevel = sv.linfo.def isa Module
6666
multiple_matches = napplicable > 1
67+
68+
if f !== nothing && napplicable == 1 && is_method_pure(applicable[1][3], applicable[1][1], applicable[1][2])
69+
val = pure_eval_call(f, argtypes)
70+
if val !== false
71+
return val
72+
end
73+
end
74+
6775
for i in 1:napplicable
6876
match = applicable[i]::SimpleVector
6977
method = match[3]::Method
@@ -506,7 +514,7 @@ function abstract_iteration(@nospecialize(itft), @nospecialize(itertype), vtypes
506514
else
507515
return Any[Vararg{Any}]
508516
end
509-
stateordonet = abstract_call(iteratef, nothing, Any[itft, itertype], vtypes, sv)
517+
stateordonet = abstract_call_known(iteratef, nothing, Any[itft, itertype], vtypes, sv)
510518
# Return Bottom if this is not an iterator.
511519
# WARNING: Changes to the iteration protocol must be reflected here,
512520
# this is not just an optimization.
@@ -525,7 +533,7 @@ function abstract_iteration(@nospecialize(itft), @nospecialize(itertype), vtypes
525533
valtype = stateordonet.parameters[1]
526534
statetype = stateordonet.parameters[2]
527535
push!(ret, valtype)
528-
stateordonet = abstract_call(iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
536+
stateordonet = abstract_call_known(iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
529537
stateordonet = widenconst(stateordonet)
530538
end
531539
if stateordonet === Nothing
@@ -542,7 +550,7 @@ function abstract_iteration(@nospecialize(itft), @nospecialize(itertype), vtypes
542550
end
543551
valtype = tmerge(valtype, nounion.parameters[1])
544552
statetype = tmerge(statetype, nounion.parameters[2])
545-
stateordonet = abstract_call(iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
553+
stateordonet = abstract_call_known(iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
546554
stateordonet = widenconst(stateordonet)
547555
end
548556
push!(ret, Vararg{valtype})
@@ -584,14 +592,16 @@ function abstract_apply(@nospecialize(itft), @nospecialize(aft), aargtypes::Vect
584592
ctypes = ctypes´
585593
end
586594
for ct in ctypes
587-
if isa(aft, Const)
588-
rt = abstract_call(aft.val, nothing, ct, vtypes, sv, max_methods)
589-
elseif isconstType(aft)
590-
rt = abstract_call(aft.parameters[1], nothing, ct, vtypes, sv, max_methods)
591-
else
592-
astype = argtypes_to_type(ct)
593-
rt = abstract_call_gf_by_type(nothing, ct, astype, sv, max_methods)
595+
lct = length(ct)
596+
# truncate argument list at the first Vararg
597+
for i = 1:lct-1
598+
if isvarargtype(ct[i])
599+
ct[i] = tuple_tail_elem(ct[i], ct[(i+1):lct])
600+
resize!(ct, i)
601+
break
602+
end
594603
end
604+
rt = abstract_call(nothing, ct, vtypes, sv, max_methods)
595605
res = tmerge(res, rt)
596606
if res === Any
597607
break
@@ -600,33 +610,24 @@ function abstract_apply(@nospecialize(itft), @nospecialize(aft), aargtypes::Vect
600610
return res
601611
end
602612

603-
function pure_eval_call(@nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState)
604-
for i = 2:length(argtypes)
605-
a = widenconditional(argtypes[i])
606-
if !(isa(a, Const) || isconstType(a))
607-
return false
608-
end
609-
end
610-
611-
min_valid = UInt[typemin(UInt)]
612-
max_valid = UInt[typemax(UInt)]
613-
meth = _methods_by_ftype(atype, 1, sv.params.world, min_valid, max_valid)
614-
if meth === false || length(meth) != 1
615-
return false
616-
end
617-
meth = meth[1]::SimpleVector
618-
sig = meth[1]::DataType
619-
sparams = meth[2]::SimpleVector
620-
method = meth[3]::Method
621-
613+
function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector)
622614
if isdefined(method, :generator)
623615
method.generator.expand_early || return false
624616
mi = specialize_method(method, sig, sparams, false)
625617
isa(mi, MethodInstance) || return false
626618
staged = get_staged(mi)
627619
(staged isa CodeInfo && (staged::CodeInfo).pure) || return false
628-
elseif !method.pure
629-
return false
620+
return true
621+
end
622+
return method.pure
623+
end
624+
625+
function pure_eval_call(@nospecialize(f), argtypes::Vector{Any})
626+
for i = 2:length(argtypes)
627+
a = widenconditional(argtypes[i])
628+
if !(isa(a, Const) || isconstType(a))
629+
return false
630+
end
630631
end
631632

632633
args = Any[ (a = widenconditional(argtypes[i]); isa(a, Const) ? a.val : a.parameters[1]) for i in 2:length(argtypes) ]
@@ -656,27 +657,21 @@ function argtype_tail(argtypes::Vector{Any}, i::Int)
656657
return argtypes[i:n]
657658
end
658659

659-
function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState, max_methods = sv.params.MAX_METHODS)
660-
if f === _apply
661-
ft = argtype_by_index(argtypes, 2)
662-
ft === Bottom && return Bottom
663-
return abstract_apply(nothing, ft, argtype_tail(argtypes, 3), vtypes, sv, max_methods)
664-
elseif f === _apply_iterate
665-
itft = argtype_by_index(argtypes, 2)
666-
ft = argtype_by_index(argtypes, 3)
667-
(itft === Bottom || ft === Bottom) && return Bottom
668-
return abstract_apply(itft, ft, argtype_tail(argtypes, 4), vtypes, sv, max_methods)
669-
end
670-
660+
# call where the function is known exactly
661+
function abstract_call_known(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState, max_methods = sv.params.MAX_METHODS)
671662
la = length(argtypes)
672-
for i = 2:(la - 1)
673-
if isvarargtype(argtypes[i])
674-
return Any
675-
end
676-
end
677663

678-
if isa(f, Builtin) || isa(f, IntrinsicFunction)
679-
if f === ifelse && fargs isa Vector{Any} && length(argtypes) == 4 && argtypes[2] isa Conditional
664+
if isa(f, Builtin)
665+
if f === _apply
666+
ft = argtype_by_index(argtypes, 2)
667+
ft === Bottom && return Bottom
668+
return abstract_apply(nothing, ft, argtype_tail(argtypes, 3), vtypes, sv, max_methods)
669+
elseif f === _apply_iterate
670+
itft = argtype_by_index(argtypes, 2)
671+
ft = argtype_by_index(argtypes, 3)
672+
(itft === Bottom || ft === Bottom) && return Bottom
673+
return abstract_apply(itft, ft, argtype_tail(argtypes, 4), vtypes, sv, max_methods)
674+
elseif f === ifelse && fargs isa Vector{Any} && la == 4 && argtypes[2] isa Conditional
680675
# try to simulate this as a real conditional (`cnd ? x : y`), so that the penalty for using `ifelse` instead isn't too high
681676
cnd = argtypes[2]::Conditional
682677
tx = argtypes[3]
@@ -692,7 +687,7 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt
692687
return tmerge(tx, ty)
693688
end
694689
rt = builtin_tfunction(f, argtypes[2:end], sv)
695-
if f === getfield && isa(fargs, Vector{Any}) && length(argtypes) == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] Tuple
690+
if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] Tuple
696691
cti = precise_container_type(nothing, argtypes[2], vtypes, sv)
697692
idx = argtypes[3].val
698693
if 1 <= idx <= length(cti)
@@ -767,7 +762,7 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt
767762
end
768763
return isa(rt, TypeVar) ? rt.ub : rt
769764
elseif f === Core.kwfunc
770-
if length(argtypes) == 2
765+
if la == 2
771766
ft = widenconst(argtypes[2])
772767
if isa(ft, DataType) && isdefined(ft.name, :mt) && isdefined(ft.name.mt, :kwsorter)
773768
return Const(ft.name.mt.kwsorter)
@@ -777,19 +772,19 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt
777772
elseif f === TypeVar
778773
# Manually look through the definition of TypeVar to
779774
# make sure to be able to get `PartialTypeVar`s out.
780-
(length(argtypes) < 2 || length(argtypes) > 4) && return Union{}
775+
(la < 2 || la > 4) && return Union{}
781776
n = argtypes[2]
782777
ub_var = Const(Any)
783778
lb_var = Const(Union{})
784-
if length(argtypes) == 4
779+
if la == 4
785780
ub_var = argtypes[4]
786781
lb_var = argtypes[3]
787-
elseif length(argtypes) == 3
782+
elseif la == 3
788783
ub_var = argtypes[3]
789784
end
790785
return typevar_tfunc(n, lb_var, ub_var)
791786
elseif f === UnionAll
792-
if length(argtypes) == 3
787+
if la == 3
793788
canconst = true
794789
if isa(argtypes[3], Const)
795790
body = argtypes[3].val
@@ -826,23 +821,24 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt
826821
if rt_rt !== nothing
827822
return rt_rt
828823
end
829-
elseif length(argtypes) == 2 && istopfunction(f, :!)
824+
return Type
825+
elseif la == 2 && istopfunction(f, :!)
830826
# handle Conditional propagation through !Bool
831827
aty = argtypes[2]
832828
if isa(aty, Conditional)
833829
abstract_call_gf_by_type(f, Any[Const(f), Bool], Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
834830
return Conditional(aty.var, aty.elsetype, aty.vtype)
835831
end
836-
elseif length(argtypes) == 3 && istopfunction(f, :!==)
832+
elseif la == 3 && istopfunction(f, :!==)
837833
# mark !== as exactly a negated call to ===
838-
rty = abstract_call((===), fargs, argtypes, vtypes, sv)
834+
rty = abstract_call_known((===), fargs, argtypes, vtypes, sv)
839835
if isa(rty, Conditional)
840836
return Conditional(rty.var, rty.elsetype, rty.vtype) # swap if-else
841837
elseif isa(rty, Const)
842838
return Const(rty.val === false)
843839
end
844840
return rty
845-
elseif length(argtypes) == 3 && istopfunction(f, :(>:))
841+
elseif la == 3 && istopfunction(f, :(>:))
846842
# mark issupertype as a exact alias for issubtype
847843
# swap T1 and T2 arguments and call <:
848844
if length(fargs) == 3
@@ -851,42 +847,36 @@ function abstract_call(@nospecialize(f), fargs::Union{Nothing,Vector{Any}}, argt
851847
fargs = nothing
852848
end
853849
argtypes = Any[typeof(<:), argtypes[3], argtypes[2]]
854-
rty = abstract_call(<:, fargs, argtypes, vtypes, sv)
850+
rty = abstract_call_known(<:, fargs, argtypes, vtypes, sv)
855851
return rty
856-
elseif length(argtypes) == 2 && isa(argtypes[2], Const) && isa(argtypes[2].val, SimpleVector) && istopfunction(f, :length)
852+
elseif la == 2 && isa(argtypes[2], Const) && isa(argtypes[2].val, SimpleVector) && istopfunction(f, :length)
857853
# mark length(::SimpleVector) as @pure
858854
return Const(length(argtypes[2].val))
859-
elseif length(argtypes) == 3 && isa(argtypes[2], Const) && isa(argtypes[3], Const) &&
855+
elseif la == 3 && isa(argtypes[2], Const) && isa(argtypes[3], Const) &&
860856
isa(argtypes[2].val, SimpleVector) && isa(argtypes[3].val, Int) && istopfunction(f, :getindex)
861857
# mark getindex(::SimpleVector, i::Int) as @pure
862858
svecval = argtypes[2].val::SimpleVector
863859
idx = argtypes[3].val::Int
864860
if 1 <= idx <= length(svecval) && isassigned(svecval, idx)
865861
return Const(getindex(svecval, idx))
866862
end
867-
elseif length(argtypes) == 2 && istopfunction(f, :typename)
863+
elseif la == 2 && istopfunction(f, :typename)
868864
return typename_static(argtypes[2])
869865
elseif max_methods > 1 && istopfunction(f, :copyto!)
870866
max_methods = 1
867+
elseif la == 3 && istopfunction(f, :typejoin)
868+
val = pure_eval_call(f, argtypes)
869+
return val === false ? Type : val
871870
end
872871

873872
atype = argtypes_to_type(argtypes)
874-
t = pure_eval_call(f, argtypes, atype, sv)
875-
t !== false && return t
876-
877-
if istopfunction(f, :typejoin) || is_return_type(f)
878-
return Type # don't try to infer these function edges directly -- it won't actually come up with anything useful
879-
end
880-
881873
return abstract_call_gf_by_type(f, argtypes, atype, sv, max_methods)
882874
end
883875

884-
# wrapper around `abstract_call` for first computing if `f` is available
885-
function abstract_eval_call(fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
876+
# call where the function is any lattice element
877+
function abstract_call(fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState,
878+
max_methods = sv.params.MAX_METHODS)
886879
#print("call ", e.args[1], argtypes, "\n\n")
887-
for x in argtypes
888-
x === Bottom && return Bottom
889-
end
890880
ft = argtypes[1]
891881
if isa(ft, Const)
892882
f = ft.val
@@ -895,19 +885,14 @@ function abstract_eval_call(fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{
895885
elseif isa(ft, DataType) && isdefined(ft, :instance)
896886
f = ft.instance
897887
else
898-
for i = 2:(length(argtypes) - 1)
899-
if isvarargtype(argtypes[i])
900-
return Any
901-
end
902-
end
903888
# non-constant function, but the number of arguments is known
904889
# and the ft is not a Builtin or IntrinsicFunction
905890
if typeintersect(widenconst(ft), Builtin) != Union{}
906891
return Any
907892
end
908-
return abstract_call_gf_by_type(nothing, argtypes, argtypes_to_type(argtypes), sv)
893+
return abstract_call_gf_by_type(nothing, argtypes, argtypes_to_type(argtypes), sv, max_methods)
909894
end
910-
return abstract_call(f, fargs, argtypes, vtypes, sv)
895+
return abstract_call_known(f, fargs, argtypes, vtypes, sv, max_methods)
911896
end
912897

913898
function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
@@ -956,7 +941,7 @@ function abstract_eval_cfunction(e::Expr, vtypes::VarTable, sv::InferenceState)
956941
# this may be the wrong world for the call,
957942
# but some of the result is likely to be valid anyways
958943
# and that may help generate better codegen
959-
abstract_eval_call(nothing, at, vtypes, sv)
944+
abstract_call(nothing, at, vtypes, sv)
960945
nothing
961946
end
962947

@@ -976,8 +961,17 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
976961
end
977962
e = e::Expr
978963
if e.head === :call
979-
argtypes = Any[ abstract_eval(a, vtypes, sv) for a in e.args ]
980-
t = abstract_eval_call(e.args, argtypes, vtypes, sv)
964+
ea = e.args
965+
n = length(ea)
966+
argtypes = Vector{Any}(undef, n)
967+
@inbounds for i = 1:n
968+
ai = abstract_eval(ea[i], vtypes, sv)
969+
if ai === Bottom
970+
return Bottom
971+
end
972+
argtypes[i] = ai
973+
end
974+
t = abstract_call(ea, argtypes, vtypes, sv)
981975
elseif e.head === :new
982976
t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1]
983977
if isconcretetype(t) && !t.mutable
@@ -1103,12 +1097,6 @@ function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo)
11031097
return typ
11041098
end
11051099

1106-
# determine whether `ex` abstractly evals to constant `c`
1107-
function abstract_evals_to_constant(@nospecialize(ex), @nospecialize(c), vtypes::VarTable, sv::InferenceState)
1108-
av = abstract_eval(ex, vtypes, sv)
1109-
return isa(av, Const) && av.val === c
1110-
end
1111-
11121100
# make as much progress on `frame` as possible (without handling cycles)
11131101
function typeinf_local(frame::InferenceState)
11141102
@assert !frame.inferred

base/compiler/tfuncs.jl

+1-8
Original file line numberDiff line numberDiff line change
@@ -1441,14 +1441,7 @@ function return_type_tfunc(argtypes::Vector{Any}, vtypes::VarTable, sv::Inferenc
14411441
if contains_is(argtypes_vec, Union{})
14421442
return Const(Union{})
14431443
end
1444-
astype = argtypes_to_type(argtypes_vec)
1445-
if isa(aft, Const)
1446-
rt = abstract_call(aft.val, nothing, argtypes_vec, vtypes, sv, -1)
1447-
elseif isconstType(aft)
1448-
rt = abstract_call(aft.parameters[1], nothing, argtypes_vec, vtypes, sv, -1)
1449-
else
1450-
rt = abstract_call_gf_by_type(nothing, argtypes_vec, astype, sv, -1)
1451-
end
1444+
rt = abstract_call(nothing, argtypes_vec, vtypes, sv, -1)
14521445
if isa(rt, Const)
14531446
# output was computed to be constant
14541447
return Const(typeof(rt.val))

base/compiler/typeutils.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,10 @@ _typename(union::UnionAll) = _typename(union.body)
9595
_typename(a::DataType) = Const(a.name)
9696

9797
function tuple_tail_elem(@nospecialize(init), ct::Vector{Any})
98-
# FIXME: this is broken: it violates subtyping relations and creates invalid types with free typevars
99-
tmerge_maybe_vararg(@nospecialize(a), @nospecialize(b)) = tmerge(a, tvar_extent(unwrapva(b)))
10098
t = init
10199
for x in ct
102-
t = tmerge_maybe_vararg(t, x)
100+
# FIXME: this is broken: it violates subtyping relations and creates invalid types with free typevars
101+
t = tmerge(t, tvar_extent(unwrapva(x)))
103102
end
104103
return Vararg{widenconst(t)}
105104
end

0 commit comments

Comments
 (0)