Skip to content

Support execution of code from external abstract interpreters #52964

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions Compiler/src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,30 @@ include("bootstrap.jl")
include("reflection_interface.jl")
include("opaque_closure.jl")

abstract type AbstractCompiler end
const CompilerInstance = Union{AbstractCompiler, Nothing}
const NativeCompiler = Nothing

current_compiler() = ccall(:jl_get_current_task, Ref{Task}, ()).compiler::CompilerInstance

"""
abstract_interpreter(::CompilerInstance, world::UInt)

Construct an appropriate abstract interpreter for the given compiler instance.
"""
function abstract_interpreter end

abstract_interpreter(::Nothing, world::UInt) = NativeInterpreter(world)

"""
compiler_world(::CompilerInstance)

The compiler world to execute this compiler instance in.
"""

compiler_world(::Nothing) = unsafe_load(cglobal(:jl_typeinf_world, UInt))
compiler_world(::AbstractCompiler) = get_world_counter() # equivalent to invokelatest

macro __SOURCE_FILE__()
__source__.file === nothing && return nothing
return QuoteNode(__source__.file::Symbol)
Expand Down
29 changes: 29 additions & 0 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2391,6 +2391,33 @@ function abstract_throw_methoderror(interp::AbstractInterpreter, argtypes::Vecto
return Future(CallMeta(Union{}, exct, EFFECTS_THROWS, NoCallInfo()))
end

function abstract_call_within(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, si::StmtInfo,
sv::AbsIntState, max_methods::Int=get_max_methods(interp, sv))
if length(argtypes) < 2

return CallMeta(Union{}, Any, Effects(), NoCallInfo())
end
CT = argtypes[2]
other_compiler = singleton_type(CT)
if other_compiler === nothing
if CT isa Const
other_compiler = CT.val
else
# Compiler is not a singleton type result may depend on runtime configuration
add_remark!(interp, sv, "Skipped call_within since compiler plugin not constant")
return CallMeta(Any, Any, Effects(), NoCallInfo())
end
end
# Change world to one where our methods exist.
cworld = invokelatest(compiler_world, other_compiler)::UInt
other_interp = Core._call_in_world(cworld, abstract_interpreter, other_compiler, get_inference_world(interp))
other_fargs = fargs === nothing ? nothing : fargs[3:end]
other_arginfo = ArgInfo(other_fargs, argtypes[3:end])
call = Core._call_in_world(cworld, abstract_call, other_interp, other_arginfo, si, sv, max_methods)
# TODO: Edges? Effects?
return CallMeta(call.rt, call.exct, call.effects, WithinCallInfo(other_compiler, call.info))
end

const generic_getglobal_effects = Effects(EFFECTS_THROWS, consistent=ALWAYS_FALSE, inaccessiblememonly=ALWAYS_FALSE)
const generic_getglobal_exct = Union{ArgumentError, TypeError, ConcurrencyViolationError, UndefVarError}
function abstract_eval_getglobal(interp::AbstractInterpreter, sv::AbsIntState, saw_latestworld::Bool, @nospecialize(M), @nospecialize(s))
Expand Down Expand Up @@ -2631,6 +2658,8 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
return abstract_throw(interp, argtypes, sv)
elseif f === Core.throw_methoderror
return abstract_throw_methoderror(interp, argtypes, sv)
elseif f === Core._call_within
return abstract_call_within(interp, arginfo, si, sv, max_methods)
elseif f === Core.getglobal
return Future(abstract_eval_getglobal(interp, sv, si.saw_latestworld, argtypes))
elseif f === Core.setglobal!
Expand Down
5 changes: 5 additions & 0 deletions Compiler/src/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,4 +497,9 @@ function add_edges_impl(edges::Vector{Any}, info::GlobalAccessInfo)
push!(edges, info.b)
end

struct WithinCallInfo <: CallInfo
compiler::CompilerInstance
info::CallInfo
end

@specialize
12 changes: 12 additions & 0 deletions Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,18 @@ function collectinvokes!(wq::Vector{CodeInstance}, ci::CodeInfo)
end
end

# typeinf_ext_toplevel is going to be executed within `jl_typeinf_world`
function typeinf_ext_toplevel(compiler::CompilerInstance, mi::MethodInstance, world::UInt, source_mode::UInt8)
if compiler === nothing
return typeinf_ext_toplevel(abstract_interpreter(compiler, world), mi, source_mode)
else
# Change world to one where our methods exist.
cworld = invokelatest(compiler_world, compiler)::UInt
absint = Core._call_in_world(cworld, abstract_interpreter, compiler, world)
return Core._call_in_world(cworld, typeinf_ext_toplevel, absint, mi, source_mode)
end
end

# This is a bridge for the C code calling `jl_typeinf_func()` on a single Method match
function typeinf_ext_toplevel(mi::MethodInstance, world::UInt, source_mode::UInt8)
interp = NativeInterpreter(world)
Expand Down
13 changes: 9 additions & 4 deletions Compiler/test/newinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ sessions. However it is an usual Julia object of the type `code_cache::IdDict{Me
making it easier for debugging and inspecting the compiler behavior.
"""
macro newinterp(InterpName, ephemeral_cache::Bool=false)
cache_token = QuoteNode(gensym(string(InterpName, "CacheToken")))
InterpCompilerName = esc(Symbol(string(InterpName, "Compiler")))
InterpCacheName = esc(Symbol(string(InterpName, "Cache")))
InterpName = esc(InterpName)
C = Core
Expand All @@ -27,31 +27,36 @@ macro newinterp(InterpName, ephemeral_cache::Bool=false)
end
$InterpCacheName() = $InterpCacheName(IdDict{$C.MethodInstance,$C.CodeInstance}())
end)
struct $InterpCompilerName <: $CC.AbstractCompiler end
$CC.abstract_interpreter(compiler::$InterpCompilerName, world::UInt) =
$InterpName(;world, compiler)
struct $InterpName <: $Compiler.AbstractInterpreter
meta # additional information
world::UInt
inf_params::$Compiler.InferenceParams
opt_params::$Compiler.OptimizationParams
inf_cache::Vector{$Compiler.InferenceResult}
$(ephemeral_cache && :(code_cache::$InterpCacheName))
compiler::$InterpCompilerName
function $InterpName(meta = nothing;
world::UInt = Base.get_world_counter(),
compiler::$InterpCompilerName = $InterpCompilerName(),
inf_params::$Compiler.InferenceParams = $Compiler.InferenceParams(),
opt_params::$Compiler.OptimizationParams = $Compiler.OptimizationParams(),
inf_cache::Vector{$Compiler.InferenceResult} = $Compiler.InferenceResult[],
$(ephemeral_cache ?
Expr(:kw, :(code_cache::$InterpCacheName), :($InterpCacheName())) :
Expr(:kw, :_, :nothing)))
return $(ephemeral_cache ?
:(new(meta, world, inf_params, opt_params, inf_cache, code_cache)) :
:(new(meta, world, inf_params, opt_params, inf_cache)))
:(new(meta, world, inf_params, opt_params, inf_cache, code_cache, compiler)) :
:(new(meta, world, inf_params, opt_params, inf_cache, compiler)))
end
end
$Compiler.InferenceParams(interp::$InterpName) = interp.inf_params
$Compiler.OptimizationParams(interp::$InterpName) = interp.opt_params
$Compiler.get_inference_world(interp::$InterpName) = interp.world
$Compiler.get_inference_cache(interp::$InterpName) = interp.inf_cache
$Compiler.cache_owner(::$InterpName) = $cache_token
$Compiler.cache_owner(interp::$InterpName) = interp.compiler
$(ephemeral_cache && quote
$Compiler.code_cache(interp::$InterpName) = $Compiler.WorldView(interp.code_cache, $Compiler.WorldRange(interp.world))
$Compiler.get(wvc::$Compiler.WorldView{$InterpCacheName}, mi::$C.MethodInstance, default) = get(wvc.cache.dict, mi, default)
Expand Down
67 changes: 67 additions & 0 deletions Compiler/test/plugins.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
original_load_path = copy(Base.LOAD_PATH)
pushfirst!(Base.LOAD_PATH, joinpath(@__DIR__, "plugins"))

using Test
using Tracer

# XXX: should these be in `Tracer/test/runtests.jl`?

function fib(x)
if x <= 1
return x
else
return fib(x-1) + fib(x-2)
end
end

let tr = trace(fib, 1)
@test tr.f == fib
@test tr.args == (1,)
child = only(tr.children)
@test child.f == Base.:<=
@test child.args == (1,1)
end

let tr = trace(fib, 2)
@test tr.f == fib
@test tr.args == (2,)
@test length(tr.children) == 6
end


using MultilineFusion

# XXX: should these be in `MultilineFusion/test/runtests.jl`?

function multiline(A, B)
C = A .* B
D = C .+ A
end

let A = ones(3,3)
B = ones(3)
@test (@inferred multiline_fusion(multiline, A, B))::Matrix{Float64} == multiline(A, B)
end

let (ir, _) = only(Base.code_ircode(multiline, (Matrix{Float64}, Vector{Float64}), optimize_until="compact 1"))
@test length(ir.stmts) == 5
@test ir.stmts[2][:stmt].args[1] == GlobalRef(Base, :materialize)
end

let (ir, _) = only(Base.code_ircode(multiline, (Matrix{Float64}, Vector{Float64}), optimize_until="compact 1", interp=MultilineFusion.MLFInterp()))
@test length(ir.stmts) == 4
end

# XXX: should these be in `CustomMethodTables/test/runtests.jl`?
using CustomMethodTables

Base.Experimental.@MethodTable(CustomMT)
Base.Experimental.@overlay CustomMT Base.sin(x::Float64) = Base.cos(x)

# FIXME: Currently doesn't infer and ends in "Skipped call_within since compiler plugin not constant"
overlay(f, args...) = CustomMethodTables.overlay(CustomMT, f, args...)
@test_broken overlay(sin, 1.0) == cos(1.0) # Bug in inference, not using the method_table for initial lookup
@test overlay((x)->sin(x), 1.0) == cos(1.0)

empty!(Base.LOAD_PATH)
append!(Base.LOAD_PATH, original_load_path)

Check warning on line 67 in Compiler/test/plugins.jl

View workflow job for this annotation

GitHub Actions / Check whitespace

Whitespace check

no trailing newline
3 changes: 3 additions & 0 deletions Compiler/test/plugins/MultilineFusion/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
name = "MultilineFusion"
uuid = "bb4966f2-fd13-4cc8-856b-cab8c274a504"
version = "0.1.0"
Loading
Loading