@@ -117,18 +117,10 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
117
117
@assert Meta. isexpr (call, :call )
118
118
119
119
# Annotate all arguments in the signature as scalars
120
- inputs = map (call. args[2 : end ]) do arg
121
- esc (Meta. isexpr (arg, :(:: )) ? arg : Expr (:(:: ), arg, :Number ))
122
- end
123
-
120
+ inputs = esc .(_constrain_and_name .(call. args[2 : end ], :Number ))
124
121
# Remove annotations and escape names for the call
125
- for (i, arg) in enumerate (call. args)
126
- if Meta. isexpr (arg, :(:: ))
127
- call. args[i] = esc (first (arg. args))
128
- else
129
- call. args[i] = esc (arg)
130
- end
131
- end
122
+ call. args[2 : end ] .= _unconstrain .(call. args[2 : end ])
123
+ call. args = esc .(call. args)
132
124
133
125
# For consistency in code that follows we make all partials tuple expressions
134
126
partials = map (partials) do partial
@@ -143,6 +135,7 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
143
135
return call, setup_stmts, inputs, partials
144
136
end
145
137
138
+
146
139
function scalar_frule_expr (f, call, setup_stmts, inputs, partials)
147
140
n_outputs = length (partials)
148
141
n_inputs = length (inputs)
@@ -178,7 +171,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
178
171
179
172
# Δs is the input to the propagator rule
180
173
# because this is a pull-back there is one per output of function
181
- Δs = [Symbol (string ( :Δ , i) ) for i in 1 : n_outputs]
174
+ Δs = [Symbol (:Δ , i) for i in 1 : n_outputs]
182
175
183
176
# 1 partial derivative per input
184
177
pullback_returns = map (1 : n_inputs) do input_i
@@ -189,7 +182,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
189
182
# Multi-output functions have pullbacks with a tuple input that will be destructured
190
183
pullback_input = n_outputs == 1 ? first (Δs) : Expr (:tuple , Δs... )
191
184
pullback = quote
192
- function $ (propagator_name (f, :pullback ))($ pullback_input)
185
+ function $ (esc ( propagator_name (f, :pullback ) ))($ pullback_input)
193
186
return (NO_FIELDS, $ (pullback_returns... ))
194
187
end
195
188
end
@@ -215,16 +208,14 @@ function propagation_expr(Δs, ∂s, _conj = false)
215
208
∂s = map (esc, ∂s)
216
209
n∂s = length (∂s)
217
210
218
- # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression
219
- # literals.
211
+ # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals.
220
212
∂_mul_Δs = if _conj
221
213
ntuple (i-> :(conj ($ (∂s[i])) * $ (Δs[i])), n∂s)
222
214
else
223
215
ntuple (i-> :($ (∂s[i]) * $ (Δs[i])), n∂s)
224
216
end
225
217
226
- # Avoiding the extra `+` operation, it is potentially expensive for vector
227
- # mode AD.
218
+ # Avoiding the extra `+` operation, it is potentially expensive for vector mode AD.
228
219
sumed_∂_mul_Δs = if n∂s > 1
229
220
# we use `@.` to broadcast `*` and `+`
230
221
:(@. + ($ (∂_mul_Δs... )))
@@ -258,3 +249,143 @@ This is able to deal with fairly complex expressions for `f`:
258
249
propagator_name (f:: Expr , propname:: Symbol ) = propagator_name (f. args[end ], propname)
259
250
propagator_name (fname:: Symbol , propname:: Symbol ) = Symbol (fname, :_ , propname)
260
251
propagator_name (fname:: QuoteNode , propname:: Symbol ) = propagator_name (fname. value, propname)
252
+
253
+ """
254
+ @non_differentiable(signature_expression)
255
+
256
+ A helper to make it easier to declare that a method is not not differentiable.
257
+ This is a short-hand for defining an [`frule`](@ref) and [`rrule`](@ref) that
258
+ return [`DoesNotExist()`](@ref) for all partials (except for the function `s̄elf`-partial
259
+ itself which is `NO_FIELDS`)
260
+
261
+ Keyword arguments should not be included.
262
+
263
+ ```jldoctest
264
+ julia> @non_differentiable Base.:(==)(a, b)
265
+
266
+ julia> _, pullback = rrule(==, 2.0, 3.0);
267
+
268
+ julia> pullback(1.0)
269
+ (Zero(), DoesNotExist(), DoesNotExist())
270
+ ```
271
+
272
+ You can place type-constraints in the signature:
273
+ ```jldoctest
274
+ julia> @non_differentiable Base.length(xs::Union{Number, Array})
275
+
276
+ julia> frule((Zero(), 1), length, [2.0, 3.0])
277
+ (2, DoesNotExist())
278
+ ```
279
+
280
+ !!! warning
281
+ This helper macro covers only the simple common cases.
282
+ It does not support Varargs, or `where`-clauses.
283
+ For these you can declare the `rrule` and `frule` directly
284
+
285
+ """
286
+ macro non_differentiable (sig_expr)
287
+ Meta. isexpr (sig_expr, :call ) || error (" Invalid use of `@non_differentiable`" )
288
+ for arg in sig_expr. args
289
+ _isvararg (arg) && error (" @non_differentiable does not support Varargs like: $arg " )
290
+ end
291
+
292
+ primal_name, orig_args = Iterators. peel (sig_expr. args)
293
+
294
+ constrained_args = _constrain_and_name .(orig_args, :Any )
295
+ primal_sig_parts = [:(:: typeof ($ primal_name)), constrained_args... ]
296
+
297
+ unconstrained_args = _unconstrain .(constrained_args)
298
+ primal_invoke = Expr (:call , esc (primal_name), esc .(unconstrained_args)... )
299
+
300
+ quote
301
+ $ (_nondiff_frule_expr (primal_sig_parts, primal_invoke))
302
+ $ (_nondiff_rrule_expr (primal_sig_parts, primal_invoke))
303
+ end
304
+ end
305
+
306
+ function _nondiff_frule_expr (primal_sig_parts, primal_invoke)
307
+ return Expr (
308
+ :(= ),
309
+ Expr (:call , :(ChainRulesCore. frule), esc (:_ ), esc .(primal_sig_parts)... ),
310
+ # Julia functions always only have 1 output, so just return a single DoesNotExist()
311
+ Expr (:tuple , primal_invoke, DoesNotExist ()),
312
+ )
313
+ end
314
+
315
+ function _nondiff_rrule_expr (primal_sig_parts, primal_invoke)
316
+ num_primal_inputs = length (primal_sig_parts) - 1
317
+ primal_name = first (primal_invoke. args)
318
+ pullback_expr = Expr (
319
+ :function ,
320
+ Expr (:call , esc (propagator_name (primal_name, :pullback )), esc (:_ )),
321
+ Expr (:tuple , NO_FIELDS, ntuple (_-> DoesNotExist (), num_primal_inputs)... )
322
+ )
323
+ rrule_defn = Expr (
324
+ :(= ),
325
+ Expr (:call , :(ChainRulesCore. rrule), esc .(primal_sig_parts)... ),
326
+ Expr (:tuple , primal_invoke, pullback_expr),
327
+ )
328
+ return rrule_defn
329
+ end
330
+
331
+
332
+ # ##########
333
+ # Helpers
334
+
335
+ """
336
+ _isvararg(expr)
337
+
338
+ returns true if the expression could represent a vararg
339
+
340
+ ```jldoctest
341
+ julia> ChainRulesCore._isvararg(:(x...))
342
+ true
343
+
344
+ julia> ChainRulesCore._isvararg(:(x::Int...))
345
+ true
346
+
347
+ julia> ChainRulesCore._isvararg(:(::Int...))
348
+ true
349
+
350
+ julia> ChainRulesCore._isvararg(:(x::Vararg))
351
+ true
352
+
353
+ julia> ChainRulesCore._isvararg(:(x::Vararg{Int}))
354
+ true
355
+
356
+ julia> ChainRulesCore._isvararg(:(::Vararg))
357
+ true
358
+
359
+ julia> ChainRulesCore._isvararg(:(::Vararg{Int}))
360
+ true
361
+
362
+ julia> ChainRulesCore._isvararg(:(x))
363
+ false
364
+ ````
365
+ """
366
+ _isvararg (expr) = false
367
+ function _isvararg (expr:: Expr )
368
+ Meta. isexpr (expr, :... ) && return true
369
+ if Meta. isexpr (expr, :(:: ))
370
+ constraint = last (expr. args)
371
+ constraint == :Vararg && return true
372
+ Meta. isexpr (constraint, :curly ) && first (constraint. args) == :Vararg && return true
373
+ end
374
+ return false
375
+ end
376
+
377
+
378
+ " turn both `a` and `a::S` into `a`"
379
+ _unconstrain (arg:: Symbol ) = arg
380
+ function _unconstrain (arg:: Expr )
381
+ Meta. isexpr (arg, :(:: ), 2 ) && return arg. args[1 ] # drop constraint.
382
+ error (" malformed arguments: $arg " )
383
+ end
384
+
385
+ " turn both `a` and `::constraint` into `a::constraint` etc"
386
+ function _constrain_and_name (arg:: Expr , _)
387
+ Meta. isexpr (arg, :(:: ), 2 ) && return arg # it is already fine.
388
+ Meta. isexpr (arg, :(:: ), 1 ) && return Expr (:(:: ), gensym (), arg. args[1 ]) # add name
389
+ error (" malformed arguments: $arg " )
390
+ end
391
+ _constrain_and_name (name:: Symbol , constraint) = Expr (:(:: ), name, constraint) # add type
0 commit comments