Skip to content

Commit

Permalink
inference: Remove special casing for !
Browse files Browse the repository at this point in the history
We have a handful of cases in inference where we look up functions by name
(using `istopfunction`) and give them special behavior. I'd like to remove
these. They're not only aesthetically ugly, but because they depend on binding
lookups, rather than values, they have unclear semantics as those bindings
change. They are also unsound should a user use the same name for something
different in their own top modules (of course, it's unlikely that a user would
do such a thing, but it's bad that they can't).

This particular PR removes the special case for `!`, which was there to
strengthen the inference result for `!` on Conditional. However, with
a little bit of strengthening of the rest of the system, this can be
equally well evaluated through the existing InterConditional mechanism.
  • Loading branch information
Keno committed Jul 31, 2024
1 parent 125bac4 commit e98f2ce
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 9 deletions.
46 changes: 38 additions & 8 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,31 @@ call_result_unused(sv::InferenceState, currpc::Int) =
isexpr(sv.src.code[currpc], :call) && isempty(sv.ssavalue_uses[currpc])
call_result_unused(si::StmtInfo) = !si.used

is_const_bool_or_bottom(@nospecialize(b)) = (isa(b, Const) && isa(b.val, Bool)) || b == Bottom
function can_propagate_conditional(@nospecialize(rt), argtypes::Vector{Any})
isa(rt, InterConditional) || return false
if rt.slot > length(argtypes)
# In the vararg tail - can't be conditional
@assert isvarargtype(argtypes[end])
return false
end
return isa(argtypes[rt.slot], Conditional) &&
is_const_bool_or_bottom(rt.thentype) && is_const_bool_or_bottom(rt.thentype)
end

function propagate_conditional(rt::InterConditional, cond::Conditional)
new_thentype = rt.thentype === Const(false) ? cond.elsetype : cond.thentype
new_elsetype = rt.elsetype === Const(true) ? cond.thentype : cond.elsetype
if rt.thentype == Bottom
@assert rt.elsetype != Bottom
return Conditional(cond.slot, Bottom, new_elsetype)
elseif rt.elsetype == Bottom
@assert rt.thentype != Bottom
return Conditional(cond.slot, new_thentype, Bottom)
end
return Conditional(cond.slot, new_thentype, new_elsetype)
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
sv::AbsIntState, max_methods::Int)
Expand Down Expand Up @@ -156,6 +181,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
@assert !(this_conditional isa Conditional || this_rt isa MustAlias) "invalid lattice element returned from inter-procedural context"
seen += 1

if can_propagate_conditional(this_conditional, argtypes)
# The only case where we need to keep this in rt is where
# we can directly propagate the conditional to a slot argument
# that is not one of our arguments, otherwise we keep all the
# relevant information in `conditionals` below.
this_rt = this_conditional
end

rettype = rettype βŠ”β‚š this_rt
exctype = exctype βŠ”β‚š this_exct
if has_conditional(π•ƒβ‚š, sv) && this_conditional !== Bottom && is_lattice_bool(π•ƒβ‚š, rettype) && fargs !== nothing
Expand Down Expand Up @@ -409,6 +443,9 @@ function from_interconditional(𝕃ᡒ::AbstractLattice, @nospecialize(rt), sv::
has_conditional(𝕃ᡒ, sv) || return widenconditional(rt)
(; fargs, argtypes) = arginfo
fargs === nothing && return widenconditional(rt)
if can_propagate_conditional(rt, argtypes)
return propagate_conditional(rt, argtypes[rt.slot]::Conditional)
end
slot = 0
alias = nothing
thentype = elsetype = Any
Expand Down Expand Up @@ -2217,13 +2254,6 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
end
elseif is_return_type(f)
return return_type_tfunc(interp, argtypes, si, sv)
elseif la == 2 && istopfunction(f, :!)
# handle Conditional propagation through !Bool
aty = argtypes[2]
if isa(aty, Conditional)
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), si, Tuple{typeof(f), Bool}, sv, max_methods) # make sure we've inferred `!(::Bool)`
return CallMeta(Conditional(aty.slot, aty.elsetype, aty.thentype), Any, call.effects, call.info)
end
elseif la == 3 && istopfunction(f, :!==)
# mark !== as exactly a negated call to ===
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Any, Any]), si, Tuple{typeof(f), Any, Any}, sv, max_methods)
Expand Down Expand Up @@ -3194,7 +3224,7 @@ function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState,
# narrow representation of bestguess slightly to prepare for tmerge with rt
if rt isa InterConditional && bestguess isa Const
slot_id = rt.slot
old_id_type = slottypes[slot_id]
old_id_type = widenconditional(slottypes[slot_id])
if bestguess.val === true && rt.elsetype !== Bottom
bestguess = InterConditional(slot_id, old_id_type, Bottom)
elseif bestguess.val === false && rt.thentype !== Bottom
Expand Down
3 changes: 3 additions & 0 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ mutable struct InferenceState
nargtypes = length(argtypes)
for i = 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
if argtyp === Bool && has_conditional(typeinf_lattice(interp))
argtyp = Conditional(i, Const(true), Const(false))
end
slottypes[i] = argtyp
bb_vartable1[i] = VarState(argtyp, i > nargtypes)
end
Expand Down
11 changes: 10 additions & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,19 @@ end
@nospecs shift_tfunc(𝕃::AbstractLattice, x, y) = shift_tfunc(widenlattice(𝕃), x, y)
@nospecs shift_tfunc(::JLTypeLattice, x, y) = widenconst(x)

function not_tfunc(𝕃::AbstractLattice, @nospecialize(b))
if isa(b, Conditional)
return Conditional(b.slot, b.elsetype, b.thentype)
elseif isa(b, Const)
return Const(not_int(b.val))
end
return math_tfunc(𝕃, b)
end

add_tfunc(and_int, 2, 2, and_int_tfunc, 1)
add_tfunc(or_int, 2, 2, or_int_tfunc, 1)
add_tfunc(xor_int, 2, 2, math_tfunc, 1)
add_tfunc(not_int, 1, 1, math_tfunc, 0) # usually used as not_int(::Bool) to negate a condition
add_tfunc(not_int, 1, 1, not_tfunc, 0) # usually used as not_int(::Bool) to negate a condition
add_tfunc(shl_int, 2, 2, shift_tfunc, 1)
add_tfunc(lshr_int, 2, 2, shift_tfunc, 1)
add_tfunc(ashr_int, 2, 2, shift_tfunc, 1)
Expand Down
5 changes: 5 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5866,3 +5866,8 @@ end
bar54341(args...) = foo54341(4, args...)

@test Core.Compiler.return_type(bar54341, Tuple{Vararg{Int}}) === Int

# InterConditional rt with Vararg argtypes
fcondvarargs(a, b, c, d) = isa(d, Int64)
gcondvarargs(a, x...) = return fcondvarargs(a, x...) ? isa(a, Int64) : !isa(a, Int64)
@test Core.Compiler.return_type(gcondvarargs, Tuple{Vararg{Any}}) === Bool

0 comments on commit e98f2ce

Please sign in to comment.