|
| 1 | +module Enzyme |
| 2 | + |
| 3 | +export autodiff |
| 4 | + |
| 5 | +using LLVM |
| 6 | +using LLVM.Interop |
| 7 | +import MCAnalyzer: irgen |
| 8 | + |
| 9 | +include("utils.jl") |
| 10 | +include("ad.jl") |
| 11 | +include("opt.jl") |
| 12 | + |
| 13 | +using .Opt: optimize! |
| 14 | + |
| 15 | +function emit(f, args) |
| 16 | + # Obtain the function and all it's dependencies in one handy module |
| 17 | + diffetypes = [] |
| 18 | + autodifftypes = Type[f] |
| 19 | + i = 1 |
| 20 | + while i <= length(args) |
| 21 | + push!(autodifftypes, args[i]) |
| 22 | + dt = whatType(args[i]) |
| 23 | + push!(diffetypes, dt) |
| 24 | + if dt == "diffe_dup" |
| 25 | + i+=1 |
| 26 | + end |
| 27 | + i+=1 |
| 28 | + end |
| 29 | + mod, ccf = irgen(Tuple{autodifftypes...}) |
| 30 | + |
| 31 | + ctx = context(mod) |
| 32 | + rettype = convert(LLVMType, Float64) |
| 33 | + |
| 34 | + #argtypes2 = LLVMType[convert(LLVMType, T, true) for T in args] |
| 35 | + argtypes2 = LLVMType[] |
| 36 | + |
| 37 | + i = 1 |
| 38 | + j = 1 |
| 39 | + orig_params = parameters(ccf) |
| 40 | + for p in orig_params |
| 41 | + push!(argtypes2, llvmtype(p)) |
| 42 | + if diffetypes[i] == "diffe_dup" |
| 43 | + push!(argtypes2, llvmtype(p)) |
| 44 | + i+=2 |
| 45 | + else |
| 46 | + i+=1 |
| 47 | + end |
| 48 | + end |
| 49 | + |
| 50 | + # TODO get function type from ccf |
| 51 | + ft2 = LLVM.FunctionType(rettype, argtypes2) |
| 52 | + |
| 53 | + # create a wrapper Function that we will inline into the llvmcall |
| 54 | + # generated by in the end `call_function` |
| 55 | + llvmf = LLVM.Function(mod, "enzyme_entry", ft2) |
| 56 | + push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0, ctx)) |
| 57 | + |
| 58 | + # Create the FunctionType and funtion decleration for the intrinsic |
| 59 | + pt = LLVM.PointerType(LLVM.Int8Type(ctx)) |
| 60 | + ftd = LLVM.FunctionType(rettype, LLVMType[pt], true) |
| 61 | + autodiff = LLVM.Function(mod, "__enzyme_autodiff", ftd) |
| 62 | + |
| 63 | + params = LLVM.Value[] |
| 64 | + i = 1 |
| 65 | + j = 1 |
| 66 | + llvm_params = parameters(llvmf) |
| 67 | + while j <= length(args) |
| 68 | + push!(params, MDString(diffetypes[i])) |
| 69 | + if diffetypes[i] == "diffe_dup" |
| 70 | + push!(params, llvm_params[j]) |
| 71 | + j+=1 |
| 72 | + end |
| 73 | + push!(params, llvm_params[j]) |
| 74 | + j += 1 |
| 75 | + i += 1 |
| 76 | + end |
| 77 | + |
| 78 | + Builder(ctx) do builder |
| 79 | + entry = BasicBlock(llvmf, "entry", ctx) |
| 80 | + position!(builder, entry) |
| 81 | + |
| 82 | + tc = bitcast!(builder, ccf, pt) |
| 83 | + pushfirst!(params, tc) |
| 84 | + |
| 85 | + val = call!(builder, autodiff, params) |
| 86 | + |
| 87 | + #if T === Nothing |
| 88 | + # ret!(builder) |
| 89 | + #else |
| 90 | + ret!(builder, val) |
| 91 | + #end |
| 92 | + end |
| 93 | + |
| 94 | + llvmf, mod |
| 95 | +end |
| 96 | + |
| 97 | +@generated function autodiff(f, args...) |
| 98 | + llvmf, mod = emit(f, args) |
| 99 | + |
| 100 | + # Run pipeline and Enzyme pass |
| 101 | + optimize!(mod) |
| 102 | + strip_debuginfo!(mod) |
| 103 | + |
| 104 | + _args = (:(args[$i]) for i in 1:length(args)) |
| 105 | + call_function(llvmf, Float64, Tuple{args...}, Expr(:tuple, _args...)) |
| 106 | +end |
| 107 | + |
| 108 | +end # module |
0 commit comments