11module Enzyme
22
3- include ( joinpath ( @__DIR__ , " .. " , " deps " , " deps.jl " ))
3+ export autodiff
44
55using LLVM
66using LLVM. Interop
77import MCAnalyzer: irgen
88
9+ include (" utils.jl" )
10+ include (" ad.jl" )
11+ include (" opt.jl" )
912
10- @enum Diffe begin
11- Duplicate = 1
12- Output = 2
13- Constant = 3
14- end
15-
16- function hasfieldcount (@nospecialize (dt))
17- try
18- fieldcount (dt)
19- catch
20- return false
21- end
22- return true
23- end
24-
25- function whatType (@nospecialize (dt))
26- if <: (dt, Array)
27- sub = whatType (eltype (dt))
28- if sub == " diffe_dup"
29- return " diffe_dup"
30- elseif sub == " diffe_out"
31- return " diffe_dup"
32- else
33- @assert (sub == " diffe_const" )
34- return " diffe_const"
35- end
36- end
37- if <: (dt, Real)
38- return " diffe_out"
39- end
40- if <: (dt, Int)
41- return " diffe_const"
42- end
43- if <: (dt, String)
44- return " diffe_const"
45- end
46-
47- if ! hasfieldcount (dt)
48- # just be safe for now
49- return " diffe_dup"
50- end
51- @assert (hasfieldcount (dt))
52- @assert (isstructtype (dt))
53- passpointer = true
54- if passpointer
55- ty = " diffe_const"
56- for (ft, fn) in zip (fieldtypes (dt), fieldnames (dt))
57- sub = whatType (ft)
58- if sub == " diffe_dup"
59- ty = " diffe_dup"
60- elseif sub == " diffe_out"
61- ty = " diffe_dup"
62- else
63- @assert (sub == " diffe_const" )
64- end
65- end
66- return ty
67- else
68- ty = " diffe_const"
69- for (ft, fn) in zip (fieldtypes (dt), fieldnames (dt))
70- sub = whatType (ft)
71- if sub == " diffe_dup"
72- ty = " diffe_dup"
73- elseif sub == " diffe_out"
74- if ty != " diffe_dup"
75- ty = " diffe_out"
76- end
77- else
78- @assert (sub == " diffe_const" )
79- end
80- end
81- return ty
82- end
83- end
13+ using . Opt: optimize!
8414
8515@generated function autodiff (f, args... )
8616 # Obtain the function and all it's dependencies in one handy module
10636
10737 i = 1
10838 j = 1
109- Base. println (typeof (ccf))
110- Base. println (typeof (llvmtype (ccf)))
111- Base. println (llvmtype (ccf))
11239 orig_params = parameters (ccf)
11340 for p in orig_params
114- Base. println (llvmtype (p))
11541 push! (argtypes2, llvmtype (p))
11642 if diffetypes[i] == " diffe_dup"
11743 push! (argtypes2, llvmtype (p))
12046 i+= 1
12147 end
12248 end
123- Base. println (argtypes2)
124-
12549
12650 # TODO get function type from ccf
12751 ft2 = LLVM. FunctionType (rettype, argtypes2)
14872 j+= 1
14973 end
15074 push! (params, llvm_params[j])
151- j+= 1
152- i+= 1
75+ j += 1
76+ i += 1
15377 end
15478
15579 Builder (ctx) do builder
@@ -168,120 +92,11 @@ end
16892 # end
16993 end
17094
171- # TODO : Run pipeline and Enzyme pass
95+ # Run pipeline and Enzyme pass
96+ optimize! (mod)
17297
17398 _args = (:(args[$ i]) for i in 1 : length (args))
17499 call_function (llvmf, Float64, Tuple{args... }, Expr (:tuple , _args... ))
175100end
176-
177-
178- function jlpasses! (pm)
179- ccall (:jl_add_optimization_passes , Nothing,
180- (LLVM. API. LLVMPassManagerRef, Cint),
181- LLVM. ref (pm), optlevel[])
182- end
183-
184- function enzyme_pass! (pm)
185- ccall ((libenzyme, :AddEnzymePass ), Nothing, (LLVM. API. LLVMPassManagerRef,) LLVM. ref (pm))
186- end
187-
188- function optimize! (mod, opt_level= 2 )
189- # everying except unroll, slpvec, loop-vec
190- # then finish Julia GC
191- ModulePassManager () do pm
192- if opt_level < 2
193- cfgsimplification! (pm)
194- if opt_level == 1
195- scalar_repl_aggregates! (pm) # SSA variant?
196- instruction_combinining! (pm)
197- early_cse! (pm)
198- end
199- mem_cpy_opt! (pm)
200- always_inliner! (pm)
201- lower_simdloop! (pm)
202-
203- # GC passes
204- barrier_noop! (pm)
205- lower_exc_handlers! (pm)
206- gc_invariant_verifier (pm, false )
207- late_lower_gc_frame! (pm)
208- # TODO : FinalLowerGCPass
209- lower_ptls! (pm, #= dump_native=# false )
210-
211- # Enzyme pass
212- barrier_noop! (pm)
213- enzyme_pass! (pm)
214- else
215- propagate_julia_addrsp! (pm)
216- scoped_no_alias_aa! (pm)
217- type_based_alias_analysis! (pm)
218- if opt_level >= 2
219- basic_alias_analysis! (pm)
220- end
221- cfgsimplification! (pm)
222- # TODO : DCE
223- scalar_repl_aggregates! (pm) # SSA variant?
224- mem_cpy_opt! (pm)
225- always_inliner! (pm)
226- alloc_opt! (pm)
227- instruction_combinining! (pm)
228- cfgsimplification! (pm)
229- scalar_repl_aggregates! (pm) # SSA variant?
230- instruction_combinining! (pm)
231- jump_threading! (pm)
232- instruction_combinining! (pm)
233- reassociate! (pm)
234- early_cse! (pm)
235- alloc_opt! (pm)
236- loop_idiom! (pm)
237- loop_rotate! (pm)
238- lower_simdloop! (pm)
239- licm! (pm)
240- loop_unswitch! (pm)
241- instruction_combinining! (pm)
242- ind_var_simplify! (pm)
243- loop_deletion! (pm)
244- # SimpleLoopUnroll -- not for Enzyme
245- alloc_opt! (pm)
246- scalar_repl_aggregates! (pm) # SSA variant?
247- instruction_combinining! (pm)
248- gvn! (pm)
249- mem_cpy_opt! (pm)
250- sccp! (pm)
251- # TODO : Sinking Pass
252- # TODO : LLVM <7 InstructionSimplifier
253- instruction_combinining! (pm)
254- jump_threading! (pm)
255- dead_store_elimination! (pm)
256- alloc_opt! (pm)
257- cfgsimplification! (pm)
258- loop_idiom! (pm)
259- loop_deletion! (pm)
260- jump_threading! (pm)
261- # SLP_Vectorizer -- not for Enzyme
262- aggressive_dce! (pm)
263- instruction_combinining! (pm)
264- # Loop Vectorize -- not for Enzyme
265- # InstCombine
266-
267- # GC passes
268- barrier_noop! (pm)
269- lower_exc_handlers! (pm)
270- gc_invariant_verifier (pm, false )
271- late_lower_gc_frame! (pm)
272- # TODO : FinalLowerGCPass
273- # TODO : DCE
274- lower_ptls! (pm, #= dump_native=# false )
275- cfgsimplification! (pm)
276- instruction_combinining! (pm) # Extra for Enzyme
277- # CombineMulAddPass will run on second pass
278-
279- # Enzyme pass
280- barrier_noop! (pm)
281- enzyme_pass! (pm)
282- end
283- run! (pm, mod)
284- end
285- end
286-
287- end
101+
102+ end # module
0 commit comments