Skip to content

Commit daa0849

Browse files
authored
AbstractInterpreter: make it easier to overload pure/concrete-eval (#44224)
fix #44174.
1 parent b8e5d7e commit daa0849

File tree

2 files changed

+88
-67
lines changed

2 files changed

+88
-67
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 80 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
6565
const_results = Union{InferenceResult,Nothing,ConstResult}[]
6666
multiple_matches = napplicable > 1
6767

68-
if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch)
69-
val = pure_eval_call(f, argtypes)
70-
if val !== nothing
71-
# TODO: add some sort of edge(s)
72-
return CallMeta(val, MethodResultPure(info))
73-
end
74-
end
68+
val = pure_eval_call(interp, f, applicable, arginfo, sv)
69+
val !== nothing && return CallMeta(val, MethodResultPure(info)) # TODO: add some sort of edge(s)
7570

7671
fargs = arginfo.fargs
7772
for i in 1:napplicable
@@ -619,27 +614,85 @@ struct MethodCallResult
619614
end
620615
end
621616

617+
function pure_eval_eligible(interp::AbstractInterpreter,
618+
@nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo, sv::InferenceState)
619+
return !isoverlayed(method_table(interp, sv)) &&
620+
f !== nothing &&
621+
length(applicable) == 1 &&
622+
is_method_pure(applicable[1]::MethodMatch) &&
623+
is_all_const_arg(arginfo)
624+
end
625+
626+
function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector)
627+
if isdefined(method, :generator)
628+
method.generator.expand_early || return false
629+
mi = specialize_method(method, sig, sparams)
630+
isa(mi, MethodInstance) || return false
631+
staged = get_staged(mi)
632+
(staged isa CodeInfo && (staged::CodeInfo).pure) || return false
633+
return true
634+
end
635+
return method.pure
636+
end
637+
is_method_pure(match::MethodMatch) = is_method_pure(match.method, match.spec_types, match.sparams)
638+
639+
function pure_eval_call(interp::AbstractInterpreter,
640+
@nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo, sv::InferenceState)
641+
pure_eval_eligible(interp, f, applicable, arginfo, sv) || return nothing
642+
return _pure_eval_call(f, arginfo)
643+
end
644+
function _pure_eval_call(@nospecialize(f), arginfo::ArgInfo)
645+
args = collect_const_args(arginfo)
646+
try
647+
value = Core._apply_pure(f, args)
648+
return Const(value)
649+
catch
650+
return nothing
651+
end
652+
end
653+
654+
function concrete_eval_eligible(interp::AbstractInterpreter,
655+
@nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState)
656+
return !isoverlayed(method_table(interp, sv)) &&
657+
f !== nothing &&
658+
result.edge !== nothing &&
659+
is_total_or_error(result.edge_effects) &&
660+
is_all_const_arg(arginfo)
661+
end
662+
622663
function is_all_const_arg((; argtypes)::ArgInfo)
623-
for a in argtypes
624-
if !isa(a, Const) && !isconstType(a) && !issingletontype(a)
625-
return false
626-
end
664+
for i = 2:length(argtypes)
665+
a = widenconditional(argtypes[i])
666+
isa(a, Const) || isconstType(a) || issingletontype(a) || return false
627667
end
628668
return true
629669
end
630670

631-
function concrete_eval_const_proven_total_or_error(interp::AbstractInterpreter,
632-
@nospecialize(f), (; argtypes)::ArgInfo, _::InferenceState)
633-
args = Any[ (a = widenconditional(argtypes[i]);
634-
isa(a, Const) ? a.val :
635-
isconstType(a) ? (a::DataType).parameters[1] :
636-
(a::DataType).instance) for i in 2:length(argtypes) ]
671+
function collect_const_args((; argtypes)::ArgInfo)
672+
return Any[ let a = widenconditional(argtypes[i])
673+
isa(a, Const) ? a.val :
674+
isconstType(a) ? (a::DataType).parameters[1] :
675+
(a::DataType).instance
676+
end for i in 2:length(argtypes) ]
677+
end
678+
679+
function concrete_eval_call(interp::AbstractInterpreter,
680+
@nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState)
681+
concrete_eval_eligible(interp, f, result, arginfo, sv) || return nothing
682+
args = collect_const_args(arginfo)
637683
try
638684
value = Core._call_in_world_total(get_world_counter(interp), f, args...)
639-
return Const(value)
640-
catch e
641-
return nothing
685+
if is_inlineable_constant(value) || call_result_unused(sv)
686+
# If the constant is not inlineable, still do the const-prop, since the
687+
# code that led to the creation of the Const may be inlineable in the same
688+
# circumstance and may be optimizable.
689+
return ConstCallResults(Const(value), ConstResult(result.edge, value), EFFECTS_TOTAL)
690+
end
691+
catch
692+
# The evaulation threw. By :consistent-cy, we're guaranteed this would have happened at runtime
693+
return ConstCallResults(Union{}, ConstResult(result.edge), result.edge_effects)
642694
end
695+
return nothing
643696
end
644697

645698
function const_prop_enabled(interp::AbstractInterpreter, sv::InferenceState, match::MethodMatch)
@@ -671,19 +724,10 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
671724
if !const_prop_enabled(interp, sv, match)
672725
return nothing
673726
end
674-
if f !== nothing && result.edge !== nothing && is_total_or_error(result.edge_effects) && is_all_const_arg(arginfo)
675-
rt = concrete_eval_const_proven_total_or_error(interp, f, arginfo, sv)
727+
val = concrete_eval_call(interp, f, result, arginfo, sv)
728+
if val !== nothing
676729
add_backedge!(result.edge, sv)
677-
if rt === nothing
678-
# The evaulation threw. By :consistent-cy, we're guaranteed this would have happened at runtime
679-
return ConstCallResults(Union{}, ConstResult(result.edge), result.edge_effects)
680-
end
681-
if is_inlineable_constant(rt.val) || call_result_unused(sv)
682-
# If the constant is not inlineable, still do the const-prop, since the
683-
# code that led to the creation of the Const may be inlineable in the same
684-
# circumstance and may be optimizable.
685-
return ConstCallResults(rt, ConstResult(result.edge, rt.val), EFFECTS_TOTAL)
686-
end
730+
return val
687731
end
688732
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
689733
mi === nothing && return nothing
@@ -1218,36 +1262,6 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::
12181262
return CallMeta(res, retinfo)
12191263
end
12201264

1221-
function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector)
1222-
if isdefined(method, :generator)
1223-
method.generator.expand_early || return false
1224-
mi = specialize_method(method, sig, sparams)
1225-
isa(mi, MethodInstance) || return false
1226-
staged = get_staged(mi)
1227-
(staged isa CodeInfo && (staged::CodeInfo).pure) || return false
1228-
return true
1229-
end
1230-
return method.pure
1231-
end
1232-
is_method_pure(match::MethodMatch) = is_method_pure(match.method, match.spec_types, match.sparams)
1233-
1234-
function pure_eval_call(@nospecialize(f), argtypes::Vector{Any})
1235-
for i = 2:length(argtypes)
1236-
a = widenconditional(argtypes[i])
1237-
if !(isa(a, Const) || isconstType(a))
1238-
return nothing
1239-
end
1240-
end
1241-
args = Any[ (a = widenconditional(argtypes[i]);
1242-
isa(a, Const) ? a.val : (a::DataType).parameters[1]) for i in 2:length(argtypes) ]
1243-
try
1244-
value = Core._apply_pure(f, args)
1245-
return Const(value)
1246-
catch
1247-
return nothing
1248-
end
1249-
end
1250-
12511265
function argtype_by_index(argtypes::Vector{Any}, i::Int)
12521266
n = length(argtypes)
12531267
na = argtypes[n]
@@ -1586,8 +1600,10 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
15861600
elseif max_methods > 1 && istopfunction(f, :copyto!)
15871601
max_methods = 1
15881602
elseif la == 3 && istopfunction(f, :typejoin)
1589-
val = pure_eval_call(f, argtypes)
1590-
return CallMeta(val === nothing ? Type : val, MethodResultPure())
1603+
if is_all_const_arg(arginfo)
1604+
val = _pure_eval_call(f, arginfo)
1605+
return CallMeta(val === nothing ? Type : val, MethodResultPure())
1606+
end
15911607
end
15921608
atype = argtypes_to_type(argtypes)
15931609
return abstract_call_gf_by_type(interp, f, arginfo, atype, sv, max_methods)

base/compiler/methodtable.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
8484
_min_val[] = typemin(UInt)
8585
_max_val[] = typemax(UInt)
8686
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
87-
end
88-
if ms === false
89-
return missing
87+
if ms === false
88+
return missing
89+
end
9090
end
9191
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
9292
end
@@ -123,3 +123,8 @@ end
123123

124124
# This query is not cached
125125
findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table)
126+
127+
isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
128+
isoverlayed(::InternalMethodTable) = false
129+
isoverlayed(::OverlayMethodTable) = true
130+
isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)

0 commit comments

Comments
 (0)