Skip to content

Commit 3cfc5e1

Browse files
committed
Mock Enzyme plugin
1 parent 4ef6019 commit 3cfc5e1

File tree

3 files changed

+200
-4
lines changed

3 files changed

+200
-4
lines changed

test/native_tests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,6 @@ end
171171
# smoke test
172172
job, _ = Native.create_job(eval(kernel), (Int64,))
173173

174-
# TODO: Add a `kernel=true` test
175-
176174
ci, rt = only(GPUCompiler.code_typed(job))
177175
@test rt === Ptr{Cvoid}
178176

test/plugin_testsetup.jl

Lines changed: 176 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct NeverInlineMeta <: InlineStateMeta end
3636
import GPUCompiler: abstract_call_known, GPUInterpreter
3737
import Core.Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
3838
StmtInfo, AbsIntState, EFFECTS_TOTAL,
39-
MethodResultPure
39+
MethodResultPure, CallInfo, IRCode
4040

4141
function abstract_call_known(meta::InlineStateMeta, interp::GPUInterpreter, @nospecialize(f),
4242
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
@@ -69,5 +69,179 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
6969
return nothing
7070
end
7171

72+
# Marker type passed as the `meta` argument of the plugin hooks defined in this
# file (`inlining_handler`, `abstract_call_known`) and stored as the meta of the
# `DeferredCallInfo` created for `autodiff` call sites.
struct MockEnzymeMeta end
7273

73-
end
74+
# Inlining hook for `MockEnzymeMeta`: ignores all of its arguments and always
# returns `nothing`.
# Having to define this no-op method for every meta type is annoying.
75+
# TODO: introduce an `abstract type InferenceMeta` (presumably so that a shared fallback can be defined once).
76+
function inlining_handler(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(atype), callinfo)
77+
return nothing
78+
end
79+
80+
# Generic function declared without methods. Calls to it are intercepted during
# abstract interpretation by the `abstract_call_known` method specialised on
# `typeof(autodiff)`, and rewritten during inlining by the `handle_call!` method
# for `AutodiffCallInfo`.
function autodiff end
81+
82+
import GPUCompiler: DeferredCallInfo
83+
# Call info attached to `autodiff` call sites by `abstract_call_known`; the
# `handle_call!` method below dispatches on it.
struct AutodiffCallInfo <: CallInfo
84+
# Return type recorded for the call (taken from the inner call's `rt`); the field is untyped.
rt
85+
# Deferred-call info wrapping the inner call's info together with a `MockEnzymeMeta()`.
info::DeferredCallInfo
86+
end
87+
88+
# Abstract-interpretation hook for calls to `autodiff` when no meta is active
# (`meta::Nothing`).
#
# Strips the leading `autodiff` entry from `arginfo`, runs the regular
# `Core.Compiler.abstract_call` on the remaining `f, args...`, and returns a
# `CallMeta` that carries the inner call's return type together with an
# `AutodiffCallInfo` wrapping a `DeferredCallInfo`.
# With no arguments besides `autodiff` itself, the result type is `Union{}`.
# The `@static` branches only differ in the `CallMeta` constructor arity
# (an extra exception-type argument on Julia >= 1.11).
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, f::typeof(autodiff),
89+
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
90+
(; fargs, argtypes) = arginfo
91+
92+
# Already guaranteed by dispatch on `f::typeof(autodiff)`; kept as a sanity check.
@assert f === autodiff
93+
# `argtypes[1]` is `autodiff` itself, so `<= 1` means there is no function to differentiate.
if length(argtypes) <= 1
94+
@static if VERSION < v"1.11.0-"
95+
return CallMeta(Union{}, Effects(), NoCallInfo())
96+
else
97+
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
98+
end
99+
end
100+
101+
# Drop the leading `autodiff` entry so the remaining arginfo describes `f(args...)`.
other_fargs = fargs === nothing ? nothing : fargs[2:end]
102+
other_arginfo = ArgInfo(other_fargs, argtypes[2:end])
103+
# TODO: Ought we not change absint to use MockEnzymeMeta(), otherwise we fill the cache for nothing.
104+
call = Core.Compiler.abstract_call(interp, other_arginfo, si, sv, max_methods)
105+
callinfo = DeferredCallInfo(MockEnzymeMeta(), call.rt, call.info)
106+
107+
# Real Enzyme must compute `rt` and `exct` according to Enzyme semantics
108+
# and likely perform an unwrapping of fargs...
109+
rt = call.rt
110+
111+
# TODO: Edges? Effects?
112+
@static if VERSION < v"1.11.0-"
113+
# Can't use call.effects since otherwise this call might be just replaced with rt
114+
return CallMeta(rt, Effects(), AutodiffCallInfo(rt, callinfo))
115+
else
116+
# Same reasoning as above for passing `Effects()` instead of `call.effects`.
return CallMeta(rt, call.exct, Effects(), AutodiffCallInfo(rt, callinfo))
117+
end
118+
end
119+
120+
# Abstract-call hook for `MockEnzymeMeta`: performs no special handling for any
# callee and always returns `nothing`.
function abstract_call_known(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(f),
121+
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
122+
return nothing
123+
end
124+
125+
import Core.Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature
126+
127+
# We really need a Compiler stdlib
# Forward `Base` indexing to `Core.Compiler`'s own `getindex`/`setindex!` so that
# the `ir[ssa][:field] = value` syntax used in `handle_call!` below works on
# `IRCode` and `Instruction`.
# NOTE(review): these add `Base` methods for types imported from `Core.Compiler`
# (type piracy); presumably acceptable because this is test-only setup — confirm.
128+
Base.getindex(ir::IRCode, i) = Core.Compiler.getindex(ir, i)
129+
Base.setindex!(inst::Instruction, val, i) = Core.Compiler.setindex!(inst, val, i)
130+
131+
# Type of the `flag` argument in the `handle_call!` signature below:
# `UInt32` on Julia >= 1.11, `UInt8` on earlier versions.
const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
132+
# Inlining hook for call sites whose call info is an `AutodiffCallInfo`.
#
# Builds a fresh `IRCode` consisting of a `gpuc.lookup` foreigncall (yielding a
# function pointer) followed by an `__autodiff` foreigncall that receives that
# pointer, and queues it in `todo` as an `InliningTodo` for statement `stmt_idx`.
# Always returns `nothing`; bails out without queuing anything unless the
# recorded method-match result contains exactly one match.
function Core.Compiler.handle_call!(todo::Vector{Pair{Int,Any}}, ir::IRCode, stmt_idx::Int,
133+
stmt::Expr, info::AutodiffCallInfo, flag::FlagType,
134+
sig::Signature, state::InliningState)
135+
136+
# Goal:
137+
# The IR we want to inline here is:
138+
# unpack the args ..
139+
# ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
140+
# ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid}, args...}, ptr, adjoint_args...)
141+
142+
# 0. Obtain primal mi from DeferredCallInfo
143+
# TODO: remove this code duplication
144+
deferred_info = info.info
145+
minfo = deferred_info.info
146+
results = minfo.results
147+
# Only a single, unambiguous method match is handled.
if length(results.matches) != 1
148+
return nothing
149+
end
150+
match = only(results.matches)
151+
152+
# lookup the target mi with correct edge tracking
153+
# TODO: Effects?
154+
case = Core.Compiler.compileable_specialization(
155+
match, Core.Compiler.Effects(), Core.Compiler.InliningEdgeTracker(state), info)
156+
# NOTE(review): `@assert` is used as control flow here; if `compileable_specialization`
# can return something other than an `InvokeCase`, this throws instead of bailing out — confirm.
@assert case isa Core.Compiler.InvokeCase
157+
@assert stmt.head === :call
158+
159+
# Now create the IR we want to inline
160+
# NOTE: this rebinds the local name `ir`; the `ir` argument is not referenced after this point.
ir = Core.Compiler.IRCode() # contains a placeholder
161+
# `stmt.args[1]` is `autodiff` itself, so these arguments cover `f, args...`.
args = [Core.Compiler.Argument(i) for i in 2:length(stmt.args)] # f, args...
162+
idx = 0
163+
164+
# 0. Enzyme proper: Desugar args
165+
primal_args = args
166+
# Skip the first signature parameter, which corresponds to `f`.
primal_argtypes = match.spec_types.parameters[2:end]
167+
168+
adjoint_rt = info.rt
169+
adjoint_args = args # TODO
170+
adjoint_argtypes = primal_argtypes
171+
172+
# 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
173+
expr = Expr(:foreigncall,
174+
"extern gpuc.lookup",
175+
Ptr{Cvoid},
176+
Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...), # Must use Any for MethodInstance or ftype
177+
0,
178+
QuoteNode(:llvmcall),
179+
deferred_info.meta,
180+
case.invoke,
181+
primal_args...
182+
)
183+
# `idx` becomes 1 here; the second node below is inserted at the same position.
ptr = insert_node!(ir, (idx += 1), NewInstruction(expr, Ptr{Cvoid}))
184+
185+
# 2. Call to magic `__autodiff`
186+
expr = Expr(:foreigncall,
187+
"extern __autodiff",
188+
adjoint_rt,
189+
Core.svec(Ptr{Cvoid}, Any, adjoint_argtypes...),
190+
0,
191+
QuoteNode(:llvmcall),
192+
ptr,
193+
adjoint_args...
194+
)
195+
ret = insert_node!(ir, idx, NewInstruction(expr, adjoint_rt))
196+
197+
# Finally replace placeholder return
198+
ir[Core.SSAValue(1)][:inst] = Core.ReturnNode(ret)
199+
# NOTE(review): the placeholder's type is set to `Ptr{Cvoid}` although it returns `ret`,
# which was inserted with type `adjoint_rt` — confirm this is intended.
ir[Core.SSAValue(1)][:type] = Ptr{Cvoid}
200+
201+
ir = Core.Compiler.compact!(ir)
202+
203+
# which mi to use here?
204+
# push inlining todos
205+
# TODO: Effects
206+
# aviatesk mentioned using inlining_policy instead...
207+
itodo = Core.Compiler.InliningTodo(case.invoke, ir, Core.Compiler.Effects())
208+
@assert itodo.linear_inline_eligible
209+
push!(todo, (stmt_idx=>itodo))
210+
211+
return nothing
212+
end
213+
214+
# Plugin callback for the `__autodiff` intrinsic.
#
# For every user of `intrinsic` (expected to be a call instruction), emit a
# direct call to the function referenced by the call's first operand, forwarding
# the call's arguments from the third one onwards, and replace the intrinsic
# call with it. The first two arguments (the looked-up function pointer and the
# `f` slot) are dropped from both the function type and the argument list.
#
# Arguments:
# - `job`:       unused here.
# - `intrinsic`: the LLVM value whose uses are rewritten.
# - `mod`:       the module being processed; unused here.
#
# Returns `true` if at least one intrinsic call was erased, `false` otherwise.
function mock_enzyme!(@nospecialize(job), intrinsic, mod::LLVM.Module)
    changed = false

    for use in LLVM.uses(intrinsic)
        call = LLVM.user(use)
        LLVM.@dispose builder=LLVM.IRBuilder() begin
            LLVM.position!(builder, call)
            ops = LLVM.operands(call)

            # The first operand references the primal function; look through a
            # constant `ptrtoint`/`bitcast` wrapper to reach the function itself.
            target = ops[1]
            if target isa LLVM.ConstantExpr &&
               (LLVM.opcode(target) == LLVM.API.LLVMPtrToInt ||
                LLVM.opcode(target) == LLVM.API.LLVMBitCast)
                target = first(LLVM.operands(target))
            end

            # Signature of the direct call: same return type, minus the two
            # leading parameters of the intrinsic.
            funcT = LLVM.called_type(call)
            funcT = LLVM.FunctionType(LLVM.return_type(funcT),
                                      LLVM.parameters(funcT)[3:end])

            # The operand list of a call instruction ends with the callee, so
            # stop one short of the end. Forwarding `3:length(ops)` would pass
            # the intrinsic itself as an extra trailing argument, which does not
            # match `funcT` and keeps the intrinsic alive through a new use.
            forwarded = [ops[i] for i in 3:(length(ops) - 1)]
            direct_call = LLVM.call!(builder, funcT, target, forwarded)

            LLVM.replace_uses!(call, direct_call)
        end
        if isempty(LLVM.uses(call))
            LLVM.erase!(call)
            changed = true
        else
            # the validator will detect this
        end
    end

    return changed
end
244+
245+
# Register `mock_enzyme!` with GPUCompiler under the name "__autodiff", the same
# symbol name used by the `extern __autodiff` foreigncall emitted in `handle_call!`.
# NOTE(review): presumably GPUCompiler invokes the callback with the matching
# intrinsic declaration for each compiled module — confirm against `register_plugin!`.
GPUCompiler.register_plugin!("__autodiff", mock_enzyme!)
246+
247+
end #module

test/ptx_tests.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,4 +504,28 @@ end
504504
ir = sprint(io->PTX.code_llvm(io, kernel_inline, Tuple{Ptr{Int64}, Int64}, meta=Plugin.NeverInlineMeta()))
505505
@test occursin("call fastcc i64 @julia_inline", ir)
506506
end
507+
508+
# End-to-end check of the mock Enzyme plugin on the PTX target: a kernel calling
# `Plugin.autodiff(f, x)` must show a call to `__autodiff` (and no direct call to
# `f`) in unoptimized LLVM IR, and the opposite in optimized IR.
@testset "Mock Enzyme" begin
509+
# Function handed to `Plugin.autodiff`.
function f(x)
510+
x^2
511+
end
512+
513+
# Kernel that stores the result of the `autodiff` call through pointer `a`.
function kernel(a, x)
514+
y = Plugin.autodiff(f, x)
515+
unsafe_store!(a, y)
516+
nothing
517+
end
518+
519+
# This tests deferred_codegen with kernel=true
520+
# NOTE(review): `@show` prints the typed code on every test run and asserts nothing;
# looks like leftover debugging output — confirm whether it should become a `@test`.
@show PTX.code_typed(kernel, Tuple{Ptr{Float64}, Float64})
521+
522+
# Unoptimized IR: the intrinsic call is still present and `f` is not called directly.
ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}, optimize=false))
523+
@test occursin("call double @__autodiff", ir)
524+
@test !occursin("call double @julia_f", ir)
525+
526+
# Optimized IR: the intrinsic call is gone and a direct call to `f` appears instead.
ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}, optimize=true))
527+
@test !occursin("call double @__autodiff", ir)
528+
@test occursin("call double @julia_f", ir)
529+
end
530+
507531
end #testitem

0 commit comments

Comments
 (0)