Skip to content

inference: Remove special casing for ! #55271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,31 @@ call_result_unused(sv::InferenceState, currpc::Int) =
isexpr(sv.src.code[currpc], :call) && isempty(sv.ssavalue_uses[currpc])
call_result_unused(si::StmtInfo) = !si.used

is_const_bool_or_bottom(@nospecialize(b)) = (isa(b, Const) && isa(b.val, Bool)) || b == Bottom
function can_propagate_conditional(@nospecialize(rt), argtypes::Vector{Any})
isa(rt, InterConditional) || return false
if rt.slot > length(argtypes)
# In the vararg tail - can't be conditional
@assert isvarargtype(argtypes[end])
return false
end
return isa(argtypes[rt.slot], Conditional) &&
is_const_bool_or_bottom(rt.thentype) && is_const_bool_or_bottom(rt.thentype)
end

function propagate_conditional(rt::InterConditional, cond::Conditional)
new_thentype = rt.thentype === Const(false) ? cond.elsetype : cond.thentype
new_elsetype = rt.elsetype === Const(true) ? cond.thentype : cond.elsetype
if rt.thentype == Bottom
@assert rt.elsetype != Bottom
return Conditional(cond.slot, Bottom, new_elsetype)
elseif rt.elsetype == Bottom
@assert rt.thentype != Bottom
return Conditional(cond.slot, new_thentype, Bottom)
end
return Conditional(cond.slot, new_thentype, new_elsetype)
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
sv::AbsIntState, max_methods::Int)
Expand Down Expand Up @@ -156,6 +181,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
@assert !(this_conditional isa Conditional || this_rt isa MustAlias) "invalid lattice element returned from inter-procedural context"
seen += 1

if can_propagate_conditional(this_conditional, argtypes)
# The only case where we need to keep this in rt is where
# we can directly propagate the conditional to a slot argument
# that is not one of our arguments, otherwise we keep all the
# relevant information in `conditionals` below.
this_rt = this_conditional
end

rettype = rettype ⊔ₚ this_rt
exctype = exctype ⊔ₚ this_exct
if has_conditional(𝕃ₚ, sv) && this_conditional !== Bottom && is_lattice_bool(𝕃ₚ, rettype) && fargs !== nothing
Expand Down Expand Up @@ -409,6 +443,9 @@ function from_interconditional(𝕃ᵢ::AbstractLattice, @nospecialize(rt), sv::
has_conditional(𝕃ᵢ, sv) || return widenconditional(rt)
(; fargs, argtypes) = arginfo
fargs === nothing && return widenconditional(rt)
if can_propagate_conditional(rt, argtypes)
return propagate_conditional(rt, argtypes[rt.slot]::Conditional)
end
slot = 0
alias = nothing
thentype = elsetype = Any
Expand Down Expand Up @@ -2217,13 +2254,6 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
end
elseif is_return_type(f)
return return_type_tfunc(interp, argtypes, si, sv)
elseif la == 2 && istopfunction(f, :!)
# handle Conditional propagation through !Bool
aty = argtypes[2]
if isa(aty, Conditional)
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), si, Tuple{typeof(f), Bool}, sv, max_methods) # make sure we've inferred `!(::Bool)`
return CallMeta(Conditional(aty.slot, aty.elsetype, aty.thentype), Any, call.effects, call.info)
end
elseif la == 3 && istopfunction(f, :!==)
# mark !== as exactly a negated call to ===
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Any, Any]), si, Tuple{typeof(f), Any, Any}, sv, max_methods)
Expand Down Expand Up @@ -3194,7 +3224,7 @@ function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState,
# narrow representation of bestguess slightly to prepare for tmerge with rt
if rt isa InterConditional && bestguess isa Const
slot_id = rt.slot
old_id_type = slottypes[slot_id]
old_id_type = widenconditional(slottypes[slot_id])
if bestguess.val === true && rt.elsetype !== Bottom
bestguess = InterConditional(slot_id, old_id_type, Bottom)
elseif bestguess.val === false && rt.thentype !== Bottom
Expand Down
3 changes: 3 additions & 0 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ mutable struct InferenceState
nargtypes = length(argtypes)
for i = 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
if argtyp === Bool && has_conditional(typeinf_lattice(interp))
argtyp = Conditional(i, Const(true), Const(false))
end
slottypes[i] = argtyp
bb_vartable1[i] = VarState(argtyp, i > nargtypes)
end
Expand Down
11 changes: 10 additions & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,19 @@ end
@nospecs shift_tfunc(𝕃::AbstractLattice, x, y) = shift_tfunc(widenlattice(𝕃), x, y)
@nospecs shift_tfunc(::JLTypeLattice, x, y) = widenconst(x)

function not_tfunc(𝕃::AbstractLattice, @nospecialize(b))
if isa(b, Conditional)
return Conditional(b.slot, b.elsetype, b.thentype)
elseif isa(b, Const)
return Const(not_int(b.val))
end
return math_tfunc(𝕃, b)
end

add_tfunc(and_int, 2, 2, and_int_tfunc, 1)
add_tfunc(or_int, 2, 2, or_int_tfunc, 1)
add_tfunc(xor_int, 2, 2, math_tfunc, 1)
add_tfunc(not_int, 1, 1, math_tfunc, 0) # usually used as not_int(::Bool) to negate a condition
add_tfunc(not_int, 1, 1, not_tfunc, 0) # usually used as not_int(::Bool) to negate a condition
add_tfunc(shl_int, 2, 2, shift_tfunc, 1)
add_tfunc(lshr_int, 2, 2, shift_tfunc, 1)
add_tfunc(ashr_int, 2, 2, shift_tfunc, 1)
Expand Down
5 changes: 5 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5866,3 +5866,8 @@ end
bar54341(args...) = foo54341(4, args...)

@test Core.Compiler.return_type(bar54341, Tuple{Vararg{Int}}) === Int

# InterConditional rt with Vararg argtypes
fcondvarargs(a, b, c, d) = isa(d, Int64)
gcondvarargs(a, x...) = return fcondvarargs(a, x...) ? isa(a, Int64) : !isa(a, Int64)
@test Core.Compiler.return_type(gcondvarargs, Tuple{Vararg{Any}}) === Bool
Loading