Skip to content

Commit 1cc17db

Browse files
authored
Merge pull request #36 from wsmoses/vc/julia_integration
Import Julia source and redo CI for 1.3
2 parents 89fe0ab + 313131e commit 1cc17db

File tree

10 files changed

+400
-27
lines changed

10 files changed

+400
-27
lines changed

.github/workflows/julia.yml

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,53 +9,43 @@ jobs:
99
strategy:
1010
fail-fast: false
1111
matrix:
12-
julia-version: [1.2.0]
12+
julia-version: [1.3]
1313
os: [ubuntu-18.04]
1414

1515
steps:
16-
- name: "Checkout Enzyme.jl code"
17-
uses: actions/checkout@v1
18-
with:
19-
repository: wsmoses/Enzyme.jl
20-
clean: false
21-
path: Enzyme.jl
22-
token: ${{secrets.enzymejl_secret}}
23-
ref: master
2416
- name: "Checkout Enzyme code"
2517
uses: actions/checkout@v1
2618
with:
2719
fetch-depth: 1
2820
- name: "Set up Julia"
29-
uses: julia-actions/setup-julia@v0.2
21+
uses: julia-actions/setup-julia@v1
3022
with:
3123
version: ${{ matrix.julia-version }}
3224
- name: "Download LLVM"
3325
run: |
34-
julia --project=contrib -e 'using Pkg; pkg"instantiate"'
35-
julia --project=contrib contrib/build_LLVM.v6.0.1.jl
26+
julia -e 'using Pkg; pkg"add LLVM_jll"'
27+
ARTIFACT_DIR=`julia -e "using LLVM_jll; print(LLVM_jll.artifact_dir)"`
28+
echo "ARTIFACT_DIR=${ARTIFACT_DIR}"
29+
# workaround https://github.com/JuliaPackaging/Yggdrasil/issues/652
3630
sudo mkdir -p /workspace
37-
sudo ln -s `pwd`/contrib/usr /workspace/destdir
38-
sudo ln -s `pwd`/contrib/usr/tools/opt `pwd`/contrib/usr/bin/opt
39-
sudo ln -s `pwd`/contrib/usr/tools/FileCheck `pwd`/contrib/usr/bin/FileCheck
40-
sudo ln -s `pwd`/contrib/usr/tools/llvm-config `pwd`/contrib/usr/bin/llvm-config
31+
sudo ln -s ${ARTIFACT_DIR} /workspace/destdir
32+
sudo mkdir -p /workspace/destdir/bin/
33+
sudo ln -s /workspace/destdir/tools/opt /workspace/destdir/bin/opt
34+
sudo ln -s /workspace/destdir/tools/FileCheck /workspace/destdir/bin/FileCheck
35+
sudo ln -s /workspace/destdir/tools/llvm-config /workspace/destdir/bin/llvm-config
4136
env:
42-
NO_DEPS: true
37+
JULIA_PROJECT: "/tmp/proj"
38+
4339
- name: "Build Enzyme"
4440
run: |
4541
mkdir build
4642
cd build
4743
cmake --version
48-
ls /home/runner/work/Enzyme/Enzyme/enzyme/Enzyme/
49-
cmake -DLLVM_DIR=../contrib/usr/lib/cmake/llvm -DLLVM_EXTERNAL_LIT=../contrib/usr/tools/lit/lit.py ../enzyme
44+
cmake -DLLVM_DIR=/workspace/destdir/lib/cmake/llvm \
45+
-DLLVM_EXTERNAL_LIT=/workspace/destdir/tools/lit/lit.py ../enzyme
5046
make -j
5147
- name: "Julia tests"
5248
run: |
53-
cd ../Enzyme.jl
54-
julia --project=. -e 'using Pkg; pkg"instantiate"; pkg"add ReverseDiff"'
55-
julia --project=. test/runtests.jl
49+
ENZYME_PATH=`pwd`/build/Enzyme julia -e 'using Pkg; pkg"test"'
5650
env:
57-
ENZYME_PATH: "../Enzyme/build/Enzyme"
58-
- name: "Check Enzyme"
59-
run: |
60-
#cd build
61-
#make check-enzyme
51+
JULIA_PROJECT: "enzyme/julia"

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
*.swp
22
*.swo
3+
build
4+
.vscode

enzyme/julia/LICENSE.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
The MIT License (MIT)
2+
3+
Copyright © 2019 William Moses, Valentin Churavy, and other contributors
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in
13+
all copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21+
THE SOFTWARE.

enzyme/julia/Project.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
name = "Enzyme"
2+
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3+
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
8+
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
9+
MCAnalyzer = "a81df072-f4bb-11e8-03d3-cfaeda626d18"
10+
11+
[compat]
12+
julia = "1.3"
13+
LLVM = "1.3"
14+
MCAnalyzer = "0.1"

enzyme/julia/src/Enzyme.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
module Enzyme
2+
3+
export autodiff
4+
5+
using LLVM
6+
using LLVM.Interop
7+
import MCAnalyzer: irgen
8+
9+
include("utils.jl")
10+
include("ad.jl")
11+
include("opt.jl")
12+
13+
using .Opt: optimize!
14+
15+
function emit(f, args)
16+
# Obtain the function and all it's dependencies in one handy module
17+
diffetypes = []
18+
autodifftypes = Type[f]
19+
i = 1
20+
while i <= length(args)
21+
push!(autodifftypes, args[i])
22+
dt = whatType(args[i])
23+
push!(diffetypes, dt)
24+
if dt == "diffe_dup"
25+
i+=1
26+
end
27+
i+=1
28+
end
29+
mod, ccf = irgen(Tuple{autodifftypes...})
30+
31+
ctx = context(mod)
32+
rettype = convert(LLVMType, Float64)
33+
34+
#argtypes2 = LLVMType[convert(LLVMType, T, true) for T in args]
35+
argtypes2 = LLVMType[]
36+
37+
i = 1
38+
j = 1
39+
orig_params = parameters(ccf)
40+
for p in orig_params
41+
push!(argtypes2, llvmtype(p))
42+
if diffetypes[i] == "diffe_dup"
43+
push!(argtypes2, llvmtype(p))
44+
i+=2
45+
else
46+
i+=1
47+
end
48+
end
49+
50+
# TODO get function type from ccf
51+
ft2 = LLVM.FunctionType(rettype, argtypes2)
52+
53+
# create a wrapper Function that we will inline into the llvmcall
54+
# generated by in the end `call_function`
55+
llvmf = LLVM.Function(mod, "enzyme_entry", ft2)
56+
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0, ctx))
57+
58+
# Create the FunctionType and funtion decleration for the intrinsic
59+
pt = LLVM.PointerType(LLVM.Int8Type(ctx))
60+
ftd = LLVM.FunctionType(rettype, LLVMType[pt], true)
61+
autodiff = LLVM.Function(mod, "__enzyme_autodiff", ftd)
62+
63+
params = LLVM.Value[]
64+
i = 1
65+
j = 1
66+
llvm_params = parameters(llvmf)
67+
while j <= length(args)
68+
push!(params, MDString(diffetypes[i]))
69+
if diffetypes[i] == "diffe_dup"
70+
push!(params, llvm_params[j])
71+
j+=1
72+
end
73+
push!(params, llvm_params[j])
74+
j += 1
75+
i += 1
76+
end
77+
78+
Builder(ctx) do builder
79+
entry = BasicBlock(llvmf, "entry", ctx)
80+
position!(builder, entry)
81+
82+
tc = bitcast!(builder, ccf, pt)
83+
pushfirst!(params, tc)
84+
85+
val = call!(builder, autodiff, params)
86+
87+
#if T === Nothing
88+
# ret!(builder)
89+
#else
90+
ret!(builder, val)
91+
#end
92+
end
93+
94+
llvmf, mod
95+
end
96+
97+
@generated function autodiff(f, args...)
98+
llvmf, mod = emit(f, args)
99+
100+
# Run pipeline and Enzyme pass
101+
optimize!(mod)
102+
strip_debuginfo!(mod)
103+
104+
_args = (:(args[$i]) for i in 1:length(args))
105+
call_function(llvmf, Float64, Tuple{args...}, Expr(:tuple, _args...))
106+
end
107+
108+
end # module

enzyme/julia/src/ad.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
@enum Diffe begin
2+
Duplicate = 1
3+
Output = 2
4+
Constant = 3
5+
end
6+
7+
function whatType(@nospecialize(dt))
8+
if <:(dt, Array)
9+
sub = whatType(eltype(dt))
10+
if sub == "diffe_dup"
11+
return "diffe_dup"
12+
elseif sub == "diffe_out"
13+
return "diffe_dup"
14+
else
15+
@assert(sub == "diffe_const")
16+
return "diffe_const"
17+
end
18+
end
19+
if <:(dt, Real)
20+
return "diffe_out"
21+
end
22+
if <:(dt, Int)
23+
return "diffe_const"
24+
end
25+
if <:(dt, String)
26+
return "diffe_const"
27+
end
28+
29+
if !hasfieldcount(dt)
30+
# just be safe for now
31+
return "diffe_dup"
32+
end
33+
34+
@assert(hasfieldcount(dt))
35+
@assert(isstructtype(dt))
36+
passpointer = true
37+
if passpointer
38+
ty = "diffe_const"
39+
for (ft, fn) in zip(fieldtypes(dt), fieldnames(dt))
40+
sub = whatType(ft)
41+
if sub == "diffe_dup"
42+
ty = "diffe_dup"
43+
elseif sub == "diffe_out"
44+
ty = "diffe_dup"
45+
else
46+
@assert(sub == "diffe_const")
47+
end
48+
end
49+
return ty
50+
else
51+
ty = "diffe_const"
52+
for (ft, fn) in zip(fieldtypes(dt), fieldnames(dt))
53+
sub = whatType(ft)
54+
if sub == "diffe_dup"
55+
ty = "diffe_dup"
56+
elseif sub == "diffe_out"
57+
if ty != "diffe_dup"
58+
ty = "diffe_out"
59+
end
60+
else
61+
@assert(sub == "diffe_const")
62+
end
63+
end
64+
return ty
65+
end
66+
end

0 commit comments

Comments
 (0)