@@ -4,105 +4,113 @@ export autodiff
44
55using LLVM
66using LLVM. Interop
7- import MCAnalyzer: irgen
87
98include (" utils.jl" )
109include (" ad.jl" )
11- include (" opt.jl" )
12-
13- using . Opt: optimize!
14-
15- function emit (f, args)
16- # Obtain the function and all it's dependencies in one handy module
17- diffetypes = []
18- autodifftypes = Type[f]
19- i = 1
20- while i <= length (args)
21- push! (autodifftypes, args[i])
22- dt = whatType (args[i])
23- push! (diffetypes, dt)
24- if dt == " diffe_dup"
10+ include (" compiler.jl" )
11+
"""
    Thunk{F, RT, TT, LLVMF}

Holds the LLVM module and wrapper function produced for one call signature.

The inner constructor `Thunk(f, rt, tt)` code-generates `f` for the argument
types in `tt`, builds an `enzyme_entry` wrapper that calls the
`__enzyme_autodiff` intrinsic on it, and runs `Compiler.optimize!` over the
result. The wrapper is stored both as the `entry` field and as the `LLVMF`
type parameter so that a `@generated` call overload can reach it.

# Fields
- `mod::LLVM.Module`: module containing the primal function and the wrapper.
- `entry::LLVM.Function`: the generated `enzyme_entry` wrapper.

# Arguments (inner constructor)
- `f`: the function to differentiate; `typeof(f)` becomes `F`.
- `rt`: Julia return type of the wrapper; converted to an `LLVMType`.
- `tt`: indexable collection of argument types.
"""
struct Thunk{F, RT, TT, LLVMF}
    mod::LLVM.Module
    entry::LLVM.Function

    function Thunk(f, rt, tt)
        # Drop from the signature the Enzyme calling-convention: an entry that
        # `whatType` classifies as "diffe_dup" consumes the following entry of
        # `tt` as well (presumably its shadow -- confirm against `whatType`),
        # so only the first of the pair is kept as a primal argument type.
        diffetypes = []     # one activity annotation per primal argument
        autodifftypes = []  # primal argument types handed to codegen
        i = 1
        while i <= length(tt)
            push!(autodifftypes, tt[i])
            dt = whatType(tt[i])
            push!(diffetypes, dt)
            if dt == "diffe_dup"
                i += 1
            end
            i += 1
        end

        source = Compiler.FunctionSpec(f, Base.to_tuple_type(autodifftypes), #=kernel=# false)
        target = Compiler.EnzymeTarget()
        job = Compiler.EnzymeJob(target, source)

        # Codegen the primal function and all its dependencies in one module
        mod, primalf = Compiler.codegen(:llvm, job, optimize=false)

        # Now build the actual wrapper function
        ctx = context(mod)
        rettype = convert(LLVMType, rt)

        # Every primal parameter appears once in the wrapper signature, and a
        # second time when it is duplicated. `diffetypes` holds exactly one
        # entry per primal argument, so it is indexed by parameter position.
        # (Stepping the index by 2 after a "diffe_dup" entry ran past the end
        # of `diffetypes` as soon as another argument followed.)
        orig_params = parameters(primalf)
        argtypes2 = LLVMType[]
        for (k, p) in enumerate(orig_params)
            push!(argtypes2, llvmtype(p))
            if diffetypes[k] == "diffe_dup"
                push!(argtypes2, llvmtype(p))
            end
        end

        # TODO get function type from primalf
        ft2 = LLVM.FunctionType(rettype, argtypes2)

        # Create a wrapper function that we will inline into the llvmcall
        # generated by `call_function` in `autodiff`
        llvmf = LLVM.Function(mod, "enzyme_entry", ft2)
        push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0, ctx))

        # Create the FunctionType and function declaration for the intrinsic
        # (vararg, taking the differentiated function as an i8* first argument)
        pt = LLVM.PointerType(LLVM.Int8Type(ctx))
        ftd = LLVM.FunctionType(rettype, LLVMType[pt], true)
        intrinsic = LLVM.Function(mod, "__enzyme_autodiff", ftd)

        # Assemble the intrinsic's arguments: for each primal argument its
        # activity annotation as metadata, followed by one wrapper parameter,
        # or two when the argument is duplicated.
        params = LLVM.Value[]
        i = 1  # index into diffetypes
        j = 1  # index into the wrapper's parameters
        llvm_params = parameters(llvmf)
        while j <= length(tt)
            push!(params, MDString(diffetypes[i]))
            if diffetypes[i] == "diffe_dup"
                push!(params, llvm_params[j])
                j += 1
            end
            push!(params, llvm_params[j])
            j += 1
            i += 1
        end

        Builder(ctx) do builder
            entry = BasicBlock(llvmf, "entry", ctx)
            position!(builder, entry)

            # The primal function is passed to the intrinsic as an opaque pointer
            tc = bitcast!(builder, primalf, pt)
            pushfirst!(params, tc)

            val = call!(builder, intrinsic, params)

            ret!(builder, val)
        end

        # Run pipeline and Enzyme pass
        Compiler.optimize!(job, mod, llvmf)
        strip_debuginfo!(mod)

        return new{typeof(f), rt, Tuple{tt...}, llvmf}(mod, llvmf)
    end
end
103102
# This is rather wonky... we should instead integrate with the ORCJIT C-API
# https://github.com/JuliaGPU/GPUCompiler.jl/issues/3
# We are also re-running Julia's optimization pipeline again
#
# Calling a `Thunk` splices its wrapper function (carried in the `LLVMF` type
# parameter) into the caller through `call_function`. The return type comes
# from the thunk's own `RT` parameter, so it matches the `rt` the wrapper was
# built with instead of being fixed to `Float64`.
@generated function (thunk::Thunk{F, RT, TT, LLVMF})(args...) where {F, RT, TT, LLVMF}
    # Inside the generator `args` is a tuple of types; build `args[1]`,
    # `args[2]`, ... expressions that refer to the runtime values.
    _args = (:(args[$i]) for i in 1:length(args))
    return call_function(LLVMF, RT, Tuple{args...}, Expr(:tuple, _args...))
end
110+
"""
    autodiff(f, args...)

Build a [`Thunk`](@ref) for `f` with return type `Float64` and the runtime
types of `args` (obtained with `Core.Typeof`), then call it on `args`.
"""
function autodiff(f, args...)
    argtypes = map(Core.Typeof, args)
    compiled = Thunk(f, Float64, argtypes)
    return compiled(args...)
end
107115
108116end # module
0 commit comments