Skip to content

Commit fafaadb

Browse files
committed
cleanup
1 parent ffe719e commit fafaadb

File tree

6 files changed

+230
-196
lines changed

6 files changed

+230
-196
lines changed

jlpkg/Enzyme/Project.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,13 @@ BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
88
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
99
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1010
MCAnalyzer = "a81df072-f4bb-11e8-03d3-cfaeda626d18"
11+
12+
[compat]
13+
julia = "1.0"
14+
15+
[extras]
16+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
18+
19+
[targets]
20+
test = ["Test", "ReverseDiff"]

jlpkg/Enzyme/src/Enzyme.jl

Lines changed: 11 additions & 196 deletions
Original file line numberDiff line numberDiff line change
@@ -1,86 +1,16 @@
11
module Enzyme
22

3-
include(joinpath(@__DIR__, "..", "deps", "deps.jl"))
3+
export autodiff
44

55
using LLVM
66
using LLVM.Interop
77
import MCAnalyzer: irgen
88

9+
include("utils.jl")
10+
include("ad.jl")
11+
include("opt.jl")
912

10-
@enum Diffe begin
11-
Duplicate = 1
12-
Output = 2
13-
Constant = 3
14-
end
15-
16-
function hasfieldcount(@nospecialize(dt))
17-
try
18-
fieldcount(dt)
19-
catch
20-
return false
21-
end
22-
return true
23-
end
24-
25-
function whatType(@nospecialize(dt))
26-
if <:(dt, Array)
27-
sub = whatType(eltype(dt))
28-
if sub == "diffe_dup"
29-
return "diffe_dup"
30-
elseif sub == "diffe_out"
31-
return "diffe_dup"
32-
else
33-
@assert(sub == "diffe_const")
34-
return "diffe_const"
35-
end
36-
end
37-
if <:(dt, Real)
38-
return "diffe_out"
39-
end
40-
if <:(dt, Int)
41-
return "diffe_const"
42-
end
43-
if <:(dt, String)
44-
return "diffe_const"
45-
end
46-
47-
if ! hasfieldcount(dt)
48-
# just be safe for now
49-
return "diffe_dup"
50-
end
51-
@assert(hasfieldcount(dt))
52-
@assert(isstructtype(dt))
53-
passpointer = true
54-
if passpointer
55-
ty = "diffe_const"
56-
for (ft, fn) in zip(fieldtypes(dt), fieldnames(dt))
57-
sub = whatType(ft)
58-
if sub == "diffe_dup"
59-
ty = "diffe_dup"
60-
elseif sub == "diffe_out"
61-
ty = "diffe_dup"
62-
else
63-
@assert(sub == "diffe_const")
64-
end
65-
end
66-
return ty
67-
else
68-
ty = "diffe_const"
69-
for (ft, fn) in zip(fieldtypes(dt), fieldnames(dt))
70-
sub = whatType(ft)
71-
if sub == "diffe_dup"
72-
ty = "diffe_dup"
73-
elseif sub == "diffe_out"
74-
if ty != "diffe_dup"
75-
ty = "diffe_out"
76-
end
77-
else
78-
@assert(sub == "diffe_const")
79-
end
80-
end
81-
return ty
82-
end
83-
end
13+
using .Opt: optimize!
8414

8515
@generated function autodiff(f, args...)
8616
# Obtain the function and all it's dependencies in one handy module
@@ -106,12 +36,8 @@ end
10636

10737
i = 1
10838
j = 1
109-
Base.println(typeof(ccf))
110-
Base.println(typeof(llvmtype(ccf)))
111-
Base.println(llvmtype(ccf))
11239
orig_params = parameters(ccf)
11340
for p in orig_params
114-
Base.println(llvmtype(p))
11541
push!(argtypes2, llvmtype(p))
11642
if diffetypes[i] == "diffe_dup"
11743
push!(argtypes2, llvmtype(p))
@@ -120,8 +46,6 @@ end
12046
i+=1
12147
end
12248
end
123-
Base.println(argtypes2)
124-
12549

12650
# TODO get function type from ccf
12751
ft2 = LLVM.FunctionType(rettype, argtypes2)
@@ -148,8 +72,8 @@ end
14872
j+=1
14973
end
15074
push!(params, llvm_params[j])
151-
j+=1
152-
i+=1
75+
j += 1
76+
i += 1
15377
end
15478

15579
Builder(ctx) do builder
@@ -168,120 +92,11 @@ end
16892
#end
16993
end
17094

171-
# TODO: Run pipeline and Enzyme pass
95+
# Run pipeline and Enzyme pass
96+
optimize!(mod)
17297

17398
_args = (:(args[$i]) for i in 1:length(args))
17499
call_function(llvmf, Float64, Tuple{args...}, Expr(:tuple, _args...))
175100
end
176-
177-
178-
function jlpasses!(pm)
179-
ccall(:jl_add_optimization_passes, Nothing,
180-
(LLVM.API.LLVMPassManagerRef, Cint),
181-
LLVM.ref(pm), optlevel[])
182-
end
183-
184-
function enzyme_pass!(pm)
185-
ccall((libenzyme, :AddEnzymePass), Nothing, (LLVM.API.LLVMPassManagerRef,) LLVM.ref(pm))
186-
end
187-
188-
function optimize!(mod, opt_level=2)
189-
# everying except unroll, slpvec, loop-vec
190-
# then finish Julia GC
191-
ModulePassManager() do pm
192-
if opt_level < 2
193-
cfgsimplification!(pm)
194-
if opt_level == 1
195-
scalar_repl_aggregates!(pm) # SSA variant?
196-
instruction_combinining!(pm)
197-
early_cse!(pm)
198-
end
199-
mem_cpy_opt!(pm)
200-
always_inliner!(pm)
201-
lower_simdloop!(pm)
202-
203-
# GC passes
204-
barrier_noop!(pm)
205-
lower_exc_handlers!(pm)
206-
gc_invariant_verifier(pm, false)
207-
late_lower_gc_frame!(pm)
208-
# TODO: FinalLowerGCPass
209-
lower_ptls!(pm, #=dump_native=# false)
210-
211-
# Enzyme pass
212-
barrier_noop!(pm)
213-
enzyme_pass!(pm)
214-
else
215-
propagate_julia_addrsp!(pm)
216-
scoped_no_alias_aa!(pm)
217-
type_based_alias_analysis!(pm)
218-
if opt_level >= 2
219-
basic_alias_analysis!(pm)
220-
end
221-
cfgsimplification!(pm)
222-
# TODO: DCE
223-
scalar_repl_aggregates!(pm) # SSA variant?
224-
mem_cpy_opt!(pm)
225-
always_inliner!(pm)
226-
alloc_opt!(pm)
227-
instruction_combinining!(pm)
228-
cfgsimplification!(pm)
229-
scalar_repl_aggregates!(pm) # SSA variant?
230-
instruction_combinining!(pm)
231-
jump_threading!(pm)
232-
instruction_combinining!(pm)
233-
reassociate!(pm)
234-
early_cse!(pm)
235-
alloc_opt!(pm)
236-
loop_idiom!(pm)
237-
loop_rotate!(pm)
238-
lower_simdloop!(pm)
239-
licm!(pm)
240-
loop_unswitch!(pm)
241-
instruction_combinining!(pm)
242-
ind_var_simplify!(pm)
243-
loop_deletion!(pm)
244-
# SimpleLoopUnroll -- not for Enzyme
245-
alloc_opt!(pm)
246-
scalar_repl_aggregates!(pm) # SSA variant?
247-
instruction_combinining!(pm)
248-
gvn!(pm)
249-
mem_cpy_opt!(pm)
250-
sccp!(pm)
251-
# TODO: Sinking Pass
252-
# TODO: LLVM <7 InstructionSimplifier
253-
instruction_combinining!(pm)
254-
jump_threading!(pm)
255-
dead_store_elimination!(pm)
256-
alloc_opt!(pm)
257-
cfgsimplification!(pm)
258-
loop_idiom!(pm)
259-
loop_deletion!(pm)
260-
jump_threading!(pm)
261-
# SLP_Vectorizer -- not for Enzyme
262-
aggressive_dce!(pm)
263-
instruction_combinining!(pm)
264-
# Loop Vectorize -- not for Enzyme
265-
# InstCombine
266-
267-
# GC passes
268-
barrier_noop!(pm)
269-
lower_exc_handlers!(pm)
270-
gc_invariant_verifier(pm, false)
271-
late_lower_gc_frame!(pm)
272-
# TODO: FinalLowerGCPass
273-
# TODO: DCE
274-
lower_ptls!(pm, #=dump_native=# false)
275-
cfgsimplification!(pm)
276-
instruction_combinining!(pm) # Extra for Enzyme
277-
# CombineMulAddPass will run on second pass
278-
279-
# Enzyme pass
280-
barrier_noop!(pm)
281-
enzyme_pass!(pm)
282-
end
283-
run!(pm, mod)
284-
end
285-
end
286-
287-
end
101+
102+
end # module

jlpkg/Enzyme/src/ad.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
@enum Diffe begin
2+
Duplicate = 1
3+
Output = 2
4+
Constant = 3
5+
end
6+
7+
function whatType(@nospecialize(dt))
8+
if <:(dt, Array)
9+
sub = whatType(eltype(dt))
10+
if sub == "diffe_dup"
11+
return "diffe_dup"
12+
elseif sub == "diffe_out"
13+
return "diffe_dup"
14+
else
15+
@assert(sub == "diffe_const")
16+
return "diffe_const"
17+
end
18+
end
19+
if <:(dt, Real)
20+
return "diffe_out"
21+
end
22+
if <:(dt, Int)
23+
return "diffe_const"
24+
end
25+
if <:(dt, String)
26+
return "diffe_const"
27+
end
28+
29+
if !hasfieldcount(dt)
30+
# just be safe for now
31+
return "diffe_dup"
32+
end
33+
34+
@assert(hasfieldcount(dt))
35+
@assert(isstructtype(dt))
36+
passpointer = true
37+
if passpointer
38+
ty = "diffe_const"
39+
for (ft, fn) in zip(fieldtypes(dt), fieldnames(dt))
40+
sub = whatType(ft)
41+
if sub == "diffe_dup"
42+
ty = "diffe_dup"
43+
elseif sub == "diffe_out"
44+
ty = "diffe_dup"
45+
else
46+
@assert(sub == "diffe_const")
47+
end
48+
end
49+
return ty
50+
else
51+
ty = "diffe_const"
52+
for (ft, fn) in zip(fieldtypes(dt), fieldnames(dt))
53+
sub = whatType(ft)
54+
if sub == "diffe_dup"
55+
ty = "diffe_dup"
56+
elseif sub == "diffe_out"
57+
if ty != "diffe_dup"
58+
ty = "diffe_out"
59+
end
60+
else
61+
@assert(sub == "diffe_const")
62+
end
63+
end
64+
return ty
65+
end
66+
end

0 commit comments

Comments
 (0)