diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 22bc1cf046d98..a7602f8cd4134 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -13,6 +13,31 @@ call_result_unused(sv::InferenceState, currpc::Int) = isexpr(sv.src.code[currpc], :call) && isempty(sv.ssavalue_uses[currpc]) call_result_unused(si::StmtInfo) = !si.used +is_const_bool_or_bottom(@nospecialize(b)) = (isa(b, Const) && isa(b.val, Bool)) || b == Bottom +function can_propagate_conditional(@nospecialize(rt), argtypes::Vector{Any}) + isa(rt, InterConditional) || return false + if rt.slot > length(argtypes) + # In the vararg tail - can't be conditional + @assert isvarargtype(argtypes[end]) + return false + end + return isa(argtypes[rt.slot], Conditional) && + is_const_bool_or_bottom(rt.thentype) && is_const_bool_or_bottom(rt.thentype) +end + +function propagate_conditional(rt::InterConditional, cond::Conditional) + new_thentype = rt.thentype === Const(false) ? cond.elsetype : cond.thentype + new_elsetype = rt.elsetype === Const(true) ? cond.thentype : cond.elsetype + if rt.thentype == Bottom + @assert rt.elsetype != Bottom + return Conditional(cond.slot, Bottom, new_elsetype) + elseif rt.elsetype == Bottom + @assert rt.thentype != Bottom + return Conditional(cond.slot, new_thentype, Bottom) + end + return Conditional(cond.slot, new_thentype, new_elsetype) +end + function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), sv::AbsIntState, max_methods::Int) @@ -156,6 +181,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end @assert !(this_conditional isa Conditional || this_rt isa MustAlias) "invalid lattice element returned from inter-procedural context" seen += 1 + + if can_propagate_conditional(this_conditional, argtypes) + # The only case where we need to keep this in rt is where + # we can directly propagate the conditional to a slot argument + # that is not one of our arguments, otherwise we keep all the + # relevant information in `conditionals` below. + this_rt = this_conditional + end + rettype = rettype โŠ”โ‚š this_rt exctype = exctype โŠ”โ‚š this_exct if has_conditional(๐•ƒโ‚š, sv) && this_conditional !== Bottom && is_lattice_bool(๐•ƒโ‚š, rettype) && fargs !== nothing @@ -409,6 +443,9 @@ function from_interconditional(๐•ƒแตข::AbstractLattice, @nospecialize(rt), sv:: has_conditional(๐•ƒแตข, sv) || return widenconditional(rt) (; fargs, argtypes) = arginfo fargs === nothing && return widenconditional(rt) + if can_propagate_conditional(rt, argtypes) + return propagate_conditional(rt, argtypes[rt.slot]::Conditional) + end slot = 0 alias = nothing thentype = elsetype = Any @@ -2217,13 +2254,6 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), end elseif is_return_type(f) return return_type_tfunc(interp, argtypes, si, sv) - elseif la == 2 && istopfunction(f, :!) - # handle Conditional propagation through !Bool - aty = argtypes[2] - if isa(aty, Conditional) - call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), si, Tuple{typeof(f), Bool}, sv, max_methods) # make sure we've inferred `!(::Bool)` - return CallMeta(Conditional(aty.slot, aty.elsetype, aty.thentype), Any, call.effects, call.info) - end elseif la == 3 && istopfunction(f, :!==) # mark !== as exactly a negated call to === call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Any, Any]), si, Tuple{typeof(f), Any, Any}, sv, max_methods) @@ -3194,7 +3224,7 @@ function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState, # narrow representation of bestguess slightly to prepare for tmerge with rt if rt isa InterConditional && bestguess isa Const slot_id = rt.slot - old_id_type = slottypes[slot_id] + old_id_type = widenconditional(slottypes[slot_id]) if bestguess.val === true && rt.elsetype !== Bottom bestguess = InterConditional(slot_id, old_id_type, Bottom) elseif bestguess.val === false && rt.thentype !== Bottom diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index c358b1177251f..38011656e41ea 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -312,6 +312,9 @@ mutable struct InferenceState nargtypes = length(argtypes) for i = 1:nslots argtyp = (i > nargtypes) ? Bottom : argtypes[i] + if argtyp === Bool && has_conditional(typeinf_lattice(interp)) + argtyp = Conditional(i, Const(true), Const(false)) + end slottypes[i] = argtyp bb_vartable1[i] = VarState(argtyp, i > nargtypes) end diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 28e883d83312c..b40f65ab3ca1d 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -227,10 +227,19 @@ end @nospecs shift_tfunc(๐•ƒ::AbstractLattice, x, y) = shift_tfunc(widenlattice(๐•ƒ), x, y) @nospecs shift_tfunc(::JLTypeLattice, x, y) = widenconst(x) +function not_tfunc(๐•ƒ::AbstractLattice, @nospecialize(b)) + if isa(b, Conditional) + return Conditional(b.slot, b.elsetype, b.thentype) + elseif isa(b, Const) + return Const(not_int(b.val)) + end + return math_tfunc(๐•ƒ, b) +end + add_tfunc(and_int, 2, 2, and_int_tfunc, 1) add_tfunc(or_int, 2, 2, or_int_tfunc, 1) add_tfunc(xor_int, 2, 2, math_tfunc, 1) -add_tfunc(not_int, 1, 1, math_tfunc, 0) # usually used as not_int(::Bool) to negate a condition +add_tfunc(not_int, 1, 1, not_tfunc, 0) # usually used as not_int(::Bool) to negate a condition add_tfunc(shl_int, 2, 2, shift_tfunc, 1) add_tfunc(lshr_int, 2, 2, shift_tfunc, 1) add_tfunc(ashr_int, 2, 2, shift_tfunc, 1) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 8b6da828af54d..9ae98b884bef4 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -5866,3 +5866,8 @@ end bar54341(args...) = foo54341(4, args...) @test Core.Compiler.return_type(bar54341, Tuple{Vararg{Int}}) === Int + +# InterConditional rt with Vararg argtypes +fcondvarargs(a, b, c, d) = isa(d, Int64) +gcondvarargs(a, x...) = return fcondvarargs(a, x...) ? isa(a, Int64) : !isa(a, Int64) +@test Core.Compiler.return_type(gcondvarargs, Tuple{Vararg{Any}}) === Bool