Skip to content

Commit d8eae23

Browse files
committed
Switch to GPUCompiler.jl
1 parent 9dec4f0 commit d8eae23

File tree

6 files changed

+268
-217
lines changed

6 files changed

+268
-217
lines changed

enzyme/julia/Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu
44
version = "0.1.0"
55

66
[deps]
7+
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
78
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
89
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
9-
MCAnalyzer = "a81df072-f4bb-11e8-03d3-cfaeda626d18"
1010

1111
[compat]
12-
julia = "1.3"
12+
GPUCompiler = "0.1"
1313
LLVM = "1.3"
14-
MCAnalyzer = "0.1"
14+
julia = "1.3"

enzyme/julia/src/Enzyme.jl

Lines changed: 88 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -4,105 +4,113 @@ export autodiff
44

55
using LLVM
66
using LLVM.Interop
7-
import MCAnalyzer: irgen
87

98
include("utils.jl")
109
include("ad.jl")
11-
include("opt.jl")
12-
13-
using .Opt: optimize!
14-
15-
function emit(f, args)
16-
# Obtain the function and all its dependencies in one handy module
17-
diffetypes = []
18-
autodifftypes = Type[f]
19-
i = 1
20-
while i <= length(args)
21-
push!(autodifftypes, args[i])
22-
dt = whatType(args[i])
23-
push!(diffetypes, dt)
24-
if dt == "diffe_dup"
10+
include("compiler.jl")
11+
12+
struct Thunk{F, RT, TT, LLVMF}
13+
mod::LLVM.Module
14+
entry::LLVM.Function
15+
16+
function Thunk(f, rt, tt)
17+
# Drop from the signature the Enzyme calling-convention
18+
diffetypes = []
19+
autodifftypes = []
20+
i = 1
21+
while i <= length(tt)
22+
push!(autodifftypes, tt[i])
23+
dt = whatType(tt[i])
24+
push!(diffetypes, dt)
25+
if dt == "diffe_dup"
26+
i+=1
27+
end
2528
i+=1
2629
end
27-
i+=1
28-
end
29-
mod, ccf = irgen(Tuple{autodifftypes...})
30-
31-
ctx = context(mod)
32-
rettype = convert(LLVMType, Float64)
33-
34-
#argtypes2 = LLVMType[convert(LLVMType, T, true) for T in args]
35-
argtypes2 = LLVMType[]
36-
37-
i = 1
38-
j = 1
39-
orig_params = parameters(ccf)
40-
for p in orig_params
41-
push!(argtypes2, llvmtype(p))
42-
if diffetypes[i] == "diffe_dup"
30+
31+
source = Compiler.FunctionSpec(f, Base.to_tuple_type(autodifftypes), #=kernel=# false)
32+
target = Compiler.EnzymeTarget()
33+
job = Compiler.EnzymeJob(target, source)
34+
35+
# Codegen the primal function and all its dependencies in one module
36+
mod, primalf = Compiler.codegen(:llvm, job, optimize=false)
37+
38+
# Now build the actual wrapper function
39+
ctx = context(mod)
40+
rettype = convert(LLVMType, rt)
41+
42+
i = 1
43+
orig_params = parameters(primalf)
44+
argtypes2 = LLVMType[]
45+
for p in orig_params
4346
push!(argtypes2, llvmtype(p))
44-
i+=2
45-
else
46-
i+=1
47+
if diffetypes[i] == "diffe_dup"
48+
push!(argtypes2, llvmtype(p))
49+
i+=2
50+
else
51+
i+=1
52+
end
4753
end
48-
end
4954

50-
# TODO get function type from ccf
51-
ft2 = LLVM.FunctionType(rettype, argtypes2)
52-
53-
# create a wrapper Function that we will inline into the llvmcall
54-
# generated by in the end `call_function`
55-
llvmf = LLVM.Function(mod, "enzyme_entry", ft2)
56-
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0, ctx))
57-
58-
# Create the FunctionType and function declaration for the intrinsic
59-
pt = LLVM.PointerType(LLVM.Int8Type(ctx))
60-
ftd = LLVM.FunctionType(rettype, LLVMType[pt], true)
61-
autodiff = LLVM.Function(mod, "__enzyme_autodiff", ftd)
62-
63-
params = LLVM.Value[]
64-
i = 1
65-
j = 1
66-
llvm_params = parameters(llvmf)
67-
while j <= length(args)
68-
push!(params, MDString(diffetypes[i]))
69-
if diffetypes[i] == "diffe_dup"
55+
# TODO get function type from primalf
56+
ft2 = LLVM.FunctionType(rettype, argtypes2)
57+
58+
# create a wrapper Function that we will inline into the llvmcall
59+
# generated by `call_function` in `autodiff`
60+
llvmf = LLVM.Function(mod, "enzyme_entry", ft2)
61+
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0, ctx))
62+
63+
# Create the FunctionType and function declaration for the intrinsic
64+
pt = LLVM.PointerType(LLVM.Int8Type(ctx))
65+
ftd = LLVM.FunctionType(rettype, LLVMType[pt], true)
66+
autodiff = LLVM.Function(mod, "__enzyme_autodiff", ftd)
67+
68+
params = LLVM.Value[]
69+
i = 1
70+
j = 1
71+
llvm_params = parameters(llvmf)
72+
while j <= length(tt)
73+
push!(params, MDString(diffetypes[i]))
74+
if diffetypes[i] == "diffe_dup"
75+
push!(params, llvm_params[j])
76+
j+=1
77+
end
7078
push!(params, llvm_params[j])
71-
j+=1
79+
j += 1
80+
i += 1
7281
end
73-
push!(params, llvm_params[j])
74-
j += 1
75-
i += 1
76-
end
7782

78-
Builder(ctx) do builder
79-
entry = BasicBlock(llvmf, "entry", ctx)
80-
position!(builder, entry)
83+
Builder(ctx) do builder
84+
entry = BasicBlock(llvmf, "entry", ctx)
85+
position!(builder, entry)
8186

82-
tc = bitcast!(builder, ccf, pt)
83-
pushfirst!(params, tc)
87+
tc = bitcast!(builder, primalf, pt)
88+
pushfirst!(params, tc)
8489

85-
val = call!(builder, autodiff, params)
90+
val = call!(builder, autodiff, params)
8691

87-
#if T === Nothing
88-
# ret!(builder)
89-
#else
9092
ret!(builder, val)
91-
#end
92-
end
93-
94-
llvmf, mod
95-
end
93+
end
9694

97-
@generated function autodiff(f, args...)
98-
llvmf, mod = emit(f, args)
95+
# Run pipeline and Enzyme pass
96+
Compiler.optimize!(job, mod, llvmf)
97+
strip_debuginfo!(mod)
9998

100-
# Run pipeline and Enzyme pass
101-
optimize!(mod)
102-
strip_debuginfo!(mod)
99+
new{typeof(f), rt, Tuple{tt...}, llvmf}(mod, llvmf)
100+
end
101+
end
103102

103+
# This is rather wonky... we should instead integrate with the ORCJIT C-API
104+
# https://github.com/JuliaGPU/GPUCompiler.jl/issues/3
105+
# We are also re-running Julia's optimization pipeline again
106+
@generated function (thunk::Thunk{F, RT, TT, LLVMF})(args...) where {F, RT, TT, LLVMF}
104107
_args = (:(args[$i]) for i in 1:length(args))
105-
call_function(llvmf, Float64, Tuple{args...}, Expr(:tuple, _args...))
108+
call_function(LLVMF, Float64, Tuple{args...}, Expr(:tuple, _args...))
109+
end
110+
111+
function autodiff(f, args...)
112+
thunk = Thunk(f, Float64, map(Core.Typeof, args))
113+
thunk(args...)
106114
end
107115

108116
end # module

enzyme/julia/src/compiler.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
module Compiler
2+
3+
using GPUCompiler
4+
using LLVM
5+
using LLVM.Interop
6+
7+
import GPUCompiler: FunctionSpec, codegen
8+
9+
import Libdl
10+
llvmver = LLVM.version().major
11+
if haskey(ENV, "ENZYME_PATH")
12+
enzyme_path = ENV["ENZYME_PATH"]
13+
else
14+
error("Please set the environment variable ENZYME_PATH")
15+
end
16+
const libenzyme = abspath(joinpath(enzyme_path, "LLVMEnzyme-$(llvmver).$(Libdl.dlext)"))
17+
18+
if !isfile(libenzyme)
19+
error("$(libenzyme) does not exist, Please specify a correct path in ENZYME_PATH, and restart Julia.")
20+
end
21+
22+
if Libdl.dlopen_e(libenzyme) in (C_NULL, nothing)
23+
error("$(libenzyme) cannot be opened, Please specify a correct path in ENZYME_PATH, and restart Julia.")
24+
end
25+
26+
function __init__()
27+
Libdl.dlopen(libenzyme, Libdl.RTLD_GLOBAL)
28+
LLVM.clopts("-enzyme_preopt=0")
29+
end
30+
31+
# Define EnzymeTarget & EnzymeJob
32+
using LLVM: triple, Target, TargetMachine
33+
import GPUCompiler: llvm_triple
34+
35+
Base.@kwdef struct EnzymeTarget <: AbstractCompilerTarget
36+
end
37+
38+
GPUCompiler.isintrinsic(::EnzymeTarget, fn::String) = true
39+
GPUCompiler.can_throw(::EnzymeTarget) = true
40+
41+
llvm_triple(::EnzymeTarget) = triple()
42+
43+
# GPUCompiler.llvm_datalayout(::EnzymeTarget) = nothing
44+
45+
function GPUCompiler.llvm_machine(target::EnzymeTarget)
46+
t = Target(llvm_triple(target))
47+
tm = TargetMachine(t, llvm_triple(target))
48+
LLVM.asm_verbosity!(tm, true)
49+
50+
return tm
51+
end
52+
53+
module Runtime
54+
# the runtime library
55+
signal_exception() = return
56+
malloc(sz) = return
57+
report_oom(sz) = return
58+
report_exception(ex) = return
59+
report_exception_name(ex) = return
60+
report_exception_frame(idx, func, file, line) = return
61+
end
62+
63+
GPUCompiler.runtime_module(target::EnzymeTarget) = Runtime
64+
65+
## job
66+
67+
export EnzymeJob
68+
69+
Base.@kwdef struct EnzymeJob <: AbstractCompilerJob
70+
target::EnzymeTarget
71+
source::FunctionSpec
72+
end
73+
74+
import GPUCompiler: target, source
75+
target(job::EnzymeJob) = job.target
76+
source(job::EnzymeJob) = job.source
77+
78+
Base.similar(job::EnzymeJob, source::FunctionSpec) =
79+
EnzymeJob(target=job.target, source=source)
80+
81+
function Base.show(io::IO, job::EnzymeJob)
82+
print(io, "Enzyme CompilerJob of ", GPUCompiler.source(job))
83+
end
84+
85+
# TODO: encode debug build or not in the compiler job
86+
# https://github.com/JuliaGPU/CUDAnative.jl/issues/368
87+
GPUCompiler.runtime_slug(job::EnzymeJob) = "enzyme"
88+
89+
include("compiler/optimize.jl")
90+
91+
end
enzyme/julia/src/compiler/optimize.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
function enzyme!(pm)
2+
ccall((:AddEnzymePass, libenzyme), Nothing, (LLVM.API.LLVMPassManagerRef,), LLVM.ref(pm))
3+
end
4+
5+
import GPUCompiler: optimize!
6+
function optimize!(job::EnzymeJob, mod::LLVM.Module, entry::LLVM.Function)
7+
tm = GPUCompiler.llvm_machine(target(job))
8+
9+
# everything except unroll, slpvec, loop-vec
10+
# then finish Julia GC
11+
ModulePassManager() do pm
12+
add_library_info!(pm, triple(mod))
13+
add_transform_info!(pm, tm)
14+
15+
propagate_julia_addrsp!(pm)
16+
scoped_no_alias_aa!(pm)
17+
type_based_alias_analysis!(pm)
18+
basic_alias_analysis!(pm)
19+
cfgsimplification!(pm)
20+
# TODO: DCE (doesn't exist in llvm-c)
21+
scalar_repl_aggregates!(pm) # SSA variant?
22+
mem_cpy_opt!(pm)
23+
always_inliner!(pm)
24+
alloc_opt!(pm)
25+
instruction_combining!(pm)
26+
cfgsimplification!(pm)
27+
scalar_repl_aggregates!(pm) # SSA variant?
28+
instruction_combining!(pm)
29+
jump_threading!(pm)
30+
instruction_combining!(pm)
31+
reassociate!(pm)
32+
early_cse!(pm)
33+
alloc_opt!(pm)
34+
loop_idiom!(pm)
35+
loop_rotate!(pm)
36+
lower_simdloop!(pm)
37+
licm!(pm)
38+
loop_unswitch!(pm)
39+
instruction_combining!(pm)
40+
ind_var_simplify!(pm)
41+
loop_deletion!(pm)
42+
# SimpleLoopUnroll -- not for Enzyme
43+
alloc_opt!(pm)
44+
scalar_repl_aggregates!(pm) # SSA variant?
45+
instruction_combining!(pm)
46+
gvn!(pm)
47+
mem_cpy_opt!(pm)
48+
sccp!(pm)
49+
# TODO: Sinking Pass
50+
# TODO: LLVM <7 InstructionSimplifier
51+
instruction_combining!(pm)
52+
jump_threading!(pm)
53+
dead_store_elimination!(pm)
54+
alloc_opt!(pm)
55+
cfgsimplification!(pm)
56+
loop_idiom!(pm)
57+
loop_deletion!(pm)
58+
jump_threading!(pm)
59+
# SLP_Vectorizer -- not for Enzyme
60+
aggressive_dce!(pm)
61+
instruction_combining!(pm)
62+
# Loop Vectorize -- not for Enzyme
63+
# InstCombine
64+
65+
# GC passes
66+
barrier_noop!(pm)
67+
lower_exc_handlers!(pm)
68+
gc_invariant_verifier!(pm, false)
69+
late_lower_gc_frame!(pm)
70+
final_lower_gc!(pm)
71+
# TODO: DCE doesn't exist in llvm-c
72+
lower_ptls!(pm, #=dump_native=# false)
73+
cfgsimplification!(pm)
74+
instruction_combining!(pm) # Extra for Enzyme
75+
# CombineMulAddPass will run on second pass
76+
77+
# Enzyme pass
78+
barrier_noop!(pm)
79+
enzyme!(pm)
80+
81+
run!(pm, mod)
82+
end
83+
84+
return entry
85+
end

0 commit comments

Comments
 (0)