Skip to content

Commit

Permalink
More inferrability improvements
Browse files Browse the repository at this point in the history
Detected as part of my work on JuliaLang/julia#37163
  • Loading branch information
timholy committed Aug 23, 2020
1 parent 0440d20 commit 07ffecb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
12 changes: 10 additions & 2 deletions src/Cthulhu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ using UUIDs
using Core: MethodInstance
const Compiler = Core.Compiler

if isdefined(Base, :mapany)
const mapany = Base.mapany
else
mapany(f, itr) = map!(f, Vector{Any}(undef, length(itr)::Int), itr) # convenient for Expr.args
end

Base.@kwdef mutable struct CthulhuConfig
enable_highlighter::Bool = false
highlighter::Cmd = `pygmentize -l`
Expand Down Expand Up @@ -174,7 +180,9 @@ function _descend(mi::MethodInstance; iswarn::Bool, params=current_params(), opt
callsite = callsites[cid]

if callsite.info isa MultiCallInfo
sub_callsites = map(ci->Callsite(callsite.id, ci), callsite.info.callinfos)
sub_callsites = let callsite=callsite
map(ci->Callsite(callsite.id, ci), callsite.info.callinfos)
end
if isempty(sub_callsites)
@warn "Expected multiple callsites, but found none. Please fill an issue with a reproducing example" callsite.info
continue
Expand Down Expand Up @@ -231,7 +239,7 @@ function _descend(mi::MethodInstance; iswarn::Bool, params=current_params(), opt
id = Base.PkgId(UUID("295af30f-e4ad-537b-8983-00126c2a3abe"), "Revise")
mod = get(Base.loaded_modules, id, nothing)
if mod !== nothing
revise = getfield(mod, :revise)
revise = getfield(mod, :revise)::Function
revise()
mi = first_method_instance(mi.specTypes)
end
Expand Down
16 changes: 8 additions & 8 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ function transform(::Val{:CuFunction}, callsite, callexpr, CI, mi, slottypes; pa
return Callsite(callsite.id, CuCallInfo(callinfo(Tuple{widenconst(ft), tt.val.parameters...}, Nothing, params=params)))
end

function find_callsites(CI, mi, slottypes; params=current_params(), multichoose::Bool=false, kwargs...)
function find_callsites(CI::Core.CodeInfo, mi::Core.MethodInstance, slottypes; params=current_params(), multichoose::Bool=false, kwargs...)
sptypes = sptypes_from_meth_instance(mi)
callsites = Callsite[]

Expand Down Expand Up @@ -87,7 +87,7 @@ function find_callsites(CI, mi, slottypes; params=current_params(), multichoose:
end
elseif c.head === :call
rt = CI.ssavaluetypes[id]
types = map(arg -> widenconst(argextype(arg, CI, sptypes, slottypes)), c.args)
types = mapany(arg -> widenconst(argextype(arg, CI, sptypes, slottypes)), c.args)

# Look through _apply
ok = true
Expand Down Expand Up @@ -142,9 +142,9 @@ function find_callsites(CI, mi, slottypes; params=current_params(), multichoose:
end
thatcher(types[1])
sigs = let types=types
map(ft-> [ft, types[2:end]...], fts)
mapany(ft-> Any[ft, types[2:end]...], fts)
end
cis = map(t -> callinfo(Tuple{t...}, rt, params=params), sigs)
cis = CallInfo[callinfo(Tuple{t...}, rt, params=params) for t in sigs]
callsite = Callsite(id, MultiCallInfo(Tuple{types...}, rt, cis))
else
ft = Base.unwrap_unionall(types[1])
Expand Down Expand Up @@ -201,30 +201,30 @@ if isdefined(Core.Compiler, :AbstractInterpreter)
ccall(:jl_typeinf_begin, Cvoid, ())
result = Core.Compiler.InferenceResult(mi)
frame = Core.Compiler.InferenceState(result, false, interp)
frame === nothing && return (nothing, Any)
frame === nothing && return (nothing, Any, Any[])
if Compiler.typeinf(interp, frame) && run_optimizer
oparams = Core.Compiler.OptimizationParams(interp)
opt = Compiler.OptimizationState(frame, oparams, interp)
Compiler.optimize(opt, oparams, result.result)
opt.src.inferred = true
end
ccall(:jl_typeinf_end, Cvoid, ())
frame.inferred || return (nothing, Any)
frame.inferred || return (nothing, Any, Any[])
return (frame.src, result.result, frame.slottypes)
end
else
function do_typeinf_slottypes(mi::Core.Compiler.MethodInstance, run_optimizer::Bool, params::Core.Compiler.Params)
ccall(:jl_typeinf_begin, Cvoid, ())
result = Core.Compiler.InferenceResult(mi)
frame = Core.Compiler.InferenceState(result, false, params)
frame === nothing && return (nothing, Any)
frame === nothing && return (nothing, Any, Any[])
if Compiler.typeinf(frame) && run_optimizer
opt = Compiler.OptimizationState(frame)
Compiler.optimize(opt, result.result)
opt.src.inferred = true
end
ccall(:jl_typeinf_end, Cvoid, ())
frame.inferred || return (nothing, Any)
frame.inferred || return (nothing, Any, Any[])
return (frame.src, result.result, frame.slottypes)
end
end
Expand Down

0 comments on commit 07ffecb

Please sign in to comment.