@@ -36,7 +36,7 @@ struct NeverInlineMeta <: InlineStateMeta end
36
36
import GPUCompiler: abstract_call_known, GPUInterpreter
37
37
import Core. Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
38
38
StmtInfo, AbsIntState, EFFECTS_TOTAL,
39
- MethodResultPure
39
+ MethodResultPure, CallInfo, IRCode
40
40
41
41
function abstract_call_known (meta:: InlineStateMeta , interp:: GPUInterpreter , @nospecialize (f),
42
42
arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
@@ -69,5 +69,179 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
69
69
return nothing
70
70
end
71
71
72
# Field-less tag type. It is stored as the `meta` of a `DeferredCallInfo` and used to
# dispatch the mock-Enzyme `inlining_handler` / `abstract_call_known` methods.
struct MockEnzymeMeta
end
72
73
73
- end
74
# Mock-Enzyme metadata does not customize inlining: always answer `nothing`.
# TODO: having to define this method for every meta type is annoying; introducing an
# `abstract type InferenceMeta` with a default method would remove the need for it.
inlining_handler(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(atype), callinfo) = nothing
79
+
80
# Generic function declared without any methods. Calls to it are recognized by type
# (`typeof(autodiff)`) during abstract interpretation rather than dispatched at runtime.
function autodiff end
81
+
82
+ import GPUCompiler: DeferredCallInfo
83
# Call info attached to an `autodiff(f, args...)` call site by `abstract_call_known`.
# The inliner later dispatches `handle_call!` on this type.
struct AutodiffCallInfo <: CallInfo
    # Return type reported for the `autodiff` call (the mock reuses the primal's `rt`).
    rt::Any
    # Deferred-call description of the primal `f(args...)` call.
    info::DeferredCallInfo
end
87
+
88
# Abstract-interpretation hook for `autodiff(f, args...)` when no inference metadata is
# active (`meta === nothing`).
#
# The primal call `f(args...)` is inferred, its result is wrapped in a
# `DeferredCallInfo` tagged with `MockEnzymeMeta()`, and that in turn is wrapped in an
# `AutodiffCallInfo` so the inliner can find it again later.
# Returns a `CallMeta`; for a bare `autodiff()` the result type is `Union{}`.
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, f::typeof(autodiff),
                             arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
    @assert f === autodiff
    fargs, argtypes = arginfo.fargs, arginfo.argtypes

    # Only `autodiff` itself is present: there is no primal function to analyze.
    if length(argtypes) <= 1
        @static if VERSION < v"1.11.0-"
            return CallMeta(Union{}, Effects(), NoCallInfo())
        else
            return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
        end
    end

    # Drop the leading `autodiff` entry so what remains describes `f(args...)`.
    primal_fargs = fargs === nothing ? nothing : fargs[2:end]
    # TODO: Ought we not change absint to use MockEnzymeMeta(), otherwise we fill the cache for nothing.
    primal_call = Core.Compiler.abstract_call(
        interp, ArgInfo(primal_fargs, argtypes[2:end]), si, sv, max_methods)
    deferred = DeferredCallInfo(MockEnzymeMeta(), primal_call.rt, primal_call.info)

    # Real Enzyme must compute `rt` and `exct` according to Enzyme semantics
    # and likely perform an unwrapping of fargs...
    rt = primal_call.rt
    wrapped = AutodiffCallInfo(rt, deferred)

    # TODO: Edges? Effects?
    # `primal_call.effects` is deliberately not forwarded: otherwise this call might
    # simply be replaced with `rt`.
    @static if VERSION < v"1.11.0-"
        return CallMeta(rt, Effects(), wrapped)
    else
        return CallMeta(rt, primal_call.exct, Effects(), wrapped)
    end
end
119
+
120
# Under `MockEnzymeMeta` no call receives special treatment: returning `nothing`
# declines to handle the call in this hook.
abstract_call_known(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(f),
                    arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int) = nothing
124
+
125
+ import Core. Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature
126
+
127
# We really need a Compiler stdlib.
# Until then, forward `Base` indexing syntax to the `Core.Compiler` implementations so
# that `ir[ssa][:field] = value` works outside of `Core.Compiler`.
function Base.getindex(ir::IRCode, i)
    return Core.Compiler.getindex(ir, i)
end

function Base.setindex!(inst::Instruction, val, i)
    return Core.Compiler.setindex!(inst, val, i)
end
130
+
131
# Integer type of the inliner's statement flag argument: `UInt32` from Julia 1.11 on,
# `UInt8` before that.
const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8

# Inlining hook for call sites annotated with `AutodiffCallInfo`.
#
# Goal: the IR we want to inline at the call site is
#     unpack the args ..
#     ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
#     ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid}, args...}, ptr, adjoint_args...)
#
# On success an `InliningTodo` for `stmt_idx` is pushed onto `todo`. If the primal call
# did not resolve to exactly one method match nothing is pushed. Always returns `nothing`.
function Core.Compiler.handle_call!(todo::Vector{Pair{Int,Any}}, ir::IRCode, stmt_idx::Int,
                                    stmt::Expr, info::AutodiffCallInfo, flag::FlagType,
                                    sig::Signature, state::InliningState)
    # 0. Obtain the primal method match from the DeferredCallInfo.
    # TODO: remove this code duplication
    deferred_info = info.info
    matches = deferred_info.info.results.matches
    length(matches) == 1 || return nothing
    primal_match = only(matches)

    # Look up the target mi with correct edge tracking.
    # TODO: Effects?
    invoke_case = Core.Compiler.compileable_specialization(
        primal_match, Core.Compiler.Effects(), Core.Compiler.InliningEdgeTracker(state), info)
    @assert invoke_case isa Core.Compiler.InvokeCase
    @assert stmt.head === :call

    # Now create the IR we want to inline. A fresh IRCode contains a placeholder
    # statement, which is overwritten with the return further down.
    stub = Core.Compiler.IRCode()
    # `stmt.args` is (autodiff, f, args...); skip `autodiff` and keep f, args...
    call_args = map(Core.Compiler.Argument, 2:length(stmt.args))

    # 0. Enzyme proper: desugar args (the mock passes everything through unchanged).
    primal_args = call_args
    primal_argtypes = primal_match.spec_types.parameters[2:end]

    adjoint_rt = info.rt
    adjoint_args = call_args # TODO
    adjoint_argtypes = primal_argtypes

    # 1. Since Julia's inliner goes bottom up we need to pretend that we already
    #    inlined the deferred call.
    lookup = Expr(:foreigncall,
        "extern gpuc.lookup",
        Ptr{Cvoid},
        # Must use Any for MethodInstance or ftype
        Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...),
        0,
        QuoteNode(:llvmcall),
        deferred_info.meta,
        invoke_case.invoke,
        primal_args...
    )
    # Both new statements are inserted at position 1, i.e. in front of the placeholder.
    ptr = insert_node!(stub, 1, NewInstruction(lookup, Ptr{Cvoid}))

    # 2. Call to the magic `__autodiff`.
    autodiff_call = Expr(:foreigncall,
        "extern __autodiff",
        adjoint_rt,
        Core.svec(Ptr{Cvoid}, Any, adjoint_argtypes...),
        0,
        QuoteNode(:llvmcall),
        ptr,
        adjoint_args...
    )
    ret = insert_node!(stub, 1, NewInstruction(autodiff_call, adjoint_rt))

    # Finally replace the placeholder with a return of the `__autodiff` result.
    placeholder = stub[Core.SSAValue(1)]
    placeholder[:inst] = Core.ReturnNode(ret)
    placeholder[:type] = Ptr{Cvoid}

    stub = Core.Compiler.compact!(stub)

    # Which mi to use here?
    # Push inlining todos.
    # TODO: Effects
    # aviatesk mentioned using inlining_policy instead...
    itodo = Core.Compiler.InliningTodo(invoke_case.invoke, stub, Core.Compiler.Effects())
    @assert itodo.linear_inline_eligible
    push!(todo, stmt_idx => itodo)

    return nothing
end
213
+
214
# Plugin callback for the `__autodiff` intrinsic.
#
# Every call `__autodiff(ptr, f, args...)` is replaced by a direct call `target(args...)`,
# where `target` is the first call operand with one level of ptrtoint/bitcast constant
# expression stripped. The first two arguments (the pointer and `f`) are dropped.
#
# Arguments:
#   job       - compiler job (unused by the mock)
#   intrinsic - LLVM value of the `__autodiff` declaration whose uses are rewritten
#   mod       - module being processed (unused by the mock)
# Returns `true` if at least one intrinsic call was erased, `false` otherwise.
function mock_enzyme!(@nospecialize(job), intrinsic, mod::LLVM.Module)
    changed = false

    for use in LLVM.uses(intrinsic)
        call = LLVM.user(use)
        LLVM.@dispose builder=LLVM.IRBuilder() begin
            LLVM.position!(builder, call)
            ops = LLVM.operands(call)

            # Operand 1 is the function pointer; look through a constant cast.
            target = ops[1]
            if target isa LLVM.ConstantExpr && (LLVM.opcode(target) == LLVM.API.LLVMPtrToInt ||
                                                LLVM.opcode(target) == LLVM.API.LLVMBitCast)
                target = first(LLVM.operands(target))
            end

            # Slice the parameter types and the operands with the SAME index range.
            # A call instruction's operand list also carries the callee, so slicing
            # by `length(ops)` would forward the intrinsic itself as an extra
            # argument that the new function type has no parameter for.
            intrinsicT = LLVM.called_type(call)
            paramTs = LLVM.parameters(intrinsicT)
            forwarded = 3:length(paramTs)
            funcT = LLVM.FunctionType(LLVM.return_type(intrinsicT), paramTs[forwarded])
            direct_call = LLVM.call!(builder, funcT, target,
                                     LLVM.Value[ops[i] for i in forwarded])

            LLVM.replace_uses!(call, direct_call)
        end
        if isempty(LLVM.uses(call))
            LLVM.erase!(call)
            changed = true
        else
            # the validator will detect this
        end
    end

    return changed
end

# Run `mock_enzyme!` for the `__autodiff` intrinsic.
GPUCompiler.register_plugin!("__autodiff", mock_enzyme!)
246
+
247
+ end # module
0 commit comments