Skip to content

Build compiled methods to handle llvmcall #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.1.1"

[deps]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[extras]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
84 changes: 79 additions & 5 deletions src/JuliaInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@ import Base: +, -, convert, isless
using Core: CodeInfo, SSAValue, SlotNumber, TypeMapEntry, SimpleVector, LineInfoNode, GotoNode, Slot,
GeneratedFunctionStub, MethodInstance, NewvarNode, TypeName

using UUIDs

export @enter, @make_stack, @interpret, Compiled, JuliaStackFrame

module CompiledCalls
# This module is for handling intrinsics that must be compiled (llvmcall)
end

"""
`Compiled` is a trait indicating that any `:call` expressions should be evaluated
using Julia's normal compiled-code evaluation. The alternative is to pass `stack=JuliaStackFrame[]`,
Expand Down Expand Up @@ -50,7 +56,7 @@ Important fields:
struct JuliaFrameCode
scope::Union{Method,Module}
code::CodeInfo
methodtables::Vector{TypeMapEntry} # line-by-line method tables for generic-function :call Exprs
methodtables::Vector{Union{Compiled,TypeMapEntry}} # line-by-line method tables for generic-function :call Exprs
used::BitSet
wrapper::Bool
generator::Bool
Expand All @@ -63,10 +69,14 @@ function JuliaFrameCode(frame::JuliaFrameCode; wrapper = frame.wrapper, generato
wrapper, generator, fullpath)
end

function JuliaFrameCode(scope, code::CodeInfo; wrapper=false, generator=false, fullpath=true)
code = optimize!(copy_codeinfo(code), moduleof(scope))
function JuliaFrameCode(scope, code::CodeInfo; wrapper=false, generator=false, fullpath=true, optimize=true)
if optimize
code, methodtables = optimize!(copy_codeinfo(code), moduleof(scope))
else
code = copy_codeinfo(code)
methodtables = Vector{Union{Compiled,TypeMapEntry}}(undef, length(code.code))
end
used = find_used(code)
methodtables = Vector{TypeMapEntry}(undef, length(code.code))
return JuliaFrameCode(scope, code, methodtables, used, wrapper, generator, fullpath)
end

Expand Down Expand Up @@ -612,6 +622,29 @@ function renumber_ssa!(stmts::Vector{Any}, ssalookup)
return stmts
end

# Pre-frame-construction lookup
function lookup_stmt(stmts, arg)
if isa(arg, SSAValue)
arg = stmts[arg.id]
end
if isa(arg, QuoteNode)
arg = arg.value
end
return arg
end

function smallest_ref(stmts, arg, idmin)
if isa(arg, SSAValue)
idmin = min(idmin, arg.id)
return smallest_ref(stmts, stmts[arg.id], idmin)
elseif isa(arg, Expr)
for a in arg.args
idmin = smallest_ref(stmts, a, idmin)
end
end
return idmin
end

function lookup_global_refs!(ex::Expr)
(ex.head == :isdefined || ex.head == :thunk || ex.head == :toplevel) && return nothing
for (i, a) in enumerate(ex.args)
Expand Down Expand Up @@ -676,7 +709,48 @@ function optimize!(code::CodeInfo, mod::Module)
ssalookup = cumsum(ssainc)
renumber_ssa!(new_code, ssalookup)
code.ssavaluetypes = length(new_code)
return code

# Replace :llvmcall and :foreigncall with compiled variants. See
# https://github.com/JuliaDebug/JuliaInterpreter.jl/issues/13#issuecomment-464880123
methodtables = Vector{Union{Compiled,TypeMapEntry}}(undef, length(code.code))
for (idx, stmt) in enumerate(code.code)
if isexpr(stmt, :call)
# Check for :llvmcall
arg1 = stmt.args[1]
if arg1 == :llvmcall || lookup_stmt(code.code, arg1) == Base.llvmcall
uuid = uuid4()
ustr = replace(string(uuid), '-'=>'_')
methname = Symbol("llvmcall_", ustr)
nargs = length(stmt.args)-4
argnames = [Symbol("arg", string(i)) for i = 1:nargs]
# Run a mini-interpreter to extract the types
framecode = JuliaFrameCode(CompiledCalls, code; optimize=false)
frame = prepare_locals(framecode, [])
idxstart = idx
for i = 2:4
idxstart = smallest_ref(code.code, stmt.args[i], idxstart)
end
frame.pc[] = JuliaProgramCounter(idxstart)
while true
pc = step_expr!(Compiled(), frame)
convert(Int, pc) == idx && break
pc === nothing && error("this should never happen")
end
str, RetType, ArgType = @lookup(frame, stmt.args[2]), @lookup(frame, stmt.args[3]), @lookup(frame, stmt.args[4])
def = quote
function $methname($(argnames...))
return Base.llvmcall($str, $RetType, $ArgType, $(argnames...))
end
end
f = Core.eval(CompiledCalls, def)
stmt.args[1] = QuoteNode(f)
deleteat!(stmt.args, 2:4)
methodtables[idx] = Compiled()
end
end
end

return code, methodtables
end

function prepare_locals(framecode, argvals::Vector{Any})
Expand Down
9 changes: 9 additions & 0 deletions src/interpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,15 @@ function evaluate_call!(::Compiled, frame::JuliaStackFrame, call_expr::Expr, pc;
end

function evaluate_call!(stack, frame::JuliaStackFrame, call_expr::Expr, pc; exec!::Function=finish_and_return!)
idx = convert(Int, pc)
if isassigned(frame.code.methodtables, idx)
tme = frame.code.methodtables[idx]
if isa(tme, Compiled)
fargs = collect_args(frame, call_expr)
f = to_function(fargs[1])
return f(fargs[2:end]...)
end
end
ret = maybe_evaluate_builtin(frame, call_expr)
isa(ret, Some{Any}) && return ret.value
fargs = collect_args(frame, call_expr)
Expand Down
25 changes: 25 additions & 0 deletions test/interpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,28 @@ if isdefined(Core.Compiler, :SNCA)
cfg = Core.Compiler.compute_basic_blocks(ci.code)
@test isa(@interpret(Core.Compiler.SNCA(cfg)), Vector{Int})
end

# llvmcall
function add1234(x::Tuple{Int32,Int32,Int32,Int32})
Base.llvmcall("""%3 = extractvalue [4 x i32] %0, 0
%4 = extractvalue [4 x i32] %0, 1
%5 = extractvalue [4 x i32] %0, 2
%6 = extractvalue [4 x i32] %0, 3
%7 = extractvalue [4 x i32] %1, 0
%8 = extractvalue [4 x i32] %1, 1
%9 = extractvalue [4 x i32] %1, 2
%10 = extractvalue [4 x i32] %1, 3
%11 = add i32 %3, %7
%12 = add i32 %4, %8
%13 = add i32 %5, %9
%14 = add i32 %6, %10
%15 = insertvalue [4 x i32] undef, i32 %11, 0
%16 = insertvalue [4 x i32] %15, i32 %12, 1
%17 = insertvalue [4 x i32] %16, i32 %13, 2
%18 = insertvalue [4 x i32] %17, i32 %14, 3
ret [4 x i32] %18""",Tuple{Int32,Int32,Int32,Int32},
Tuple{Tuple{Int32,Int32,Int32,Int32},Tuple{Int32,Int32,Int32,Int32}},
(Int32(1),Int32(2),Int32(3),Int32(4)),
x)
end
@test @interpret(add1234(map(Int32,(2,3,4,5)))) === map(Int32,(3,5,7,9))