Skip to content

Commit 95e094e

Browse files
authored
Merge pull request #207 from JuliaDiff/ox/nondiff
nondifferentiable macro
2 parents 0fe9da8 + 9ab6955 commit 95e094e

File tree

5 files changed

+239
-19
lines changed

5 files changed

+239
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.6"
3+
version = "0.9.7"
44

55
[deps]
66
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using MuladdMacro: @muladd
44

55
export on_new_rule, refresh_rules # generation tools
66
export frule, rrule # core function
7-
export @scalar_rule, @thunk # definition helper macros
7+
export @non_differentiable, @scalar_rule, @thunk # definition helper macros
88
export canonicalize, extern, unthunk # differential operations
99
# differentials
1010
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk

src/rule_definition_tools.jl

Lines changed: 148 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,10 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
117117
@assert Meta.isexpr(call, :call)
118118

119119
# Annotate all arguments in the signature as scalars
120-
inputs = map(call.args[2:end]) do arg
121-
esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number))
122-
end
123-
120+
inputs = esc.(_constrain_and_name.(call.args[2:end], :Number))
124121
# Remove annotations and escape names for the call
125-
for (i, arg) in enumerate(call.args)
126-
if Meta.isexpr(arg, :(::))
127-
call.args[i] = esc(first(arg.args))
128-
else
129-
call.args[i] = esc(arg)
130-
end
131-
end
122+
call.args[2:end] .= _unconstrain.(call.args[2:end])
123+
call.args = esc.(call.args)
132124

133125
# For consistency in code that follows we make all partials tuple expressions
134126
partials = map(partials) do partial
@@ -143,6 +135,7 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
143135
return call, setup_stmts, inputs, partials
144136
end
145137

138+
146139
function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
147140
n_outputs = length(partials)
148141
n_inputs = length(inputs)
@@ -178,7 +171,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
178171

179172
# Δs is the input to the propagator rule
180173
# because this is a pull-back there is one per output of function
181-
Δs = [Symbol(string(, i)) for i in 1:n_outputs]
174+
Δs = [Symbol(, i) for i in 1:n_outputs]
182175

183176
# 1 partial derivative per input
184177
pullback_returns = map(1:n_inputs) do input_i
@@ -189,7 +182,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
189182
# Multi-output functions have pullbacks with a tuple input that will be destructured
190183
pullback_input = n_outputs == 1 ? first(Δs) : Expr(:tuple, Δs...)
191184
pullback = quote
192-
function $(propagator_name(f, :pullback))($pullback_input)
185+
function $(esc(propagator_name(f, :pullback)))($pullback_input)
193186
return (NO_FIELDS, $(pullback_returns...))
194187
end
195188
end
@@ -215,16 +208,14 @@ function propagation_expr(Δs, ∂s, _conj = false)
215208
∂s = map(esc, ∂s)
216209
n∂s = length(∂s)
217210

218-
# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression
219-
# literals.
211+
# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals.
220212
∂_mul_Δs = if _conj
221213
ntuple(i->:(conj($(∂s[i])) * $(Δs[i])), n∂s)
222214
else
223215
ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s)
224216
end
225217

226-
# Avoiding the extra `+` operation, it is potentially expensive for vector
227-
# mode AD.
218+
# Avoiding the extra `+` operation, it is potentially expensive for vector mode AD.
228219
sumed_∂_mul_Δs = if n∂s > 1
229220
# we use `@.` to broadcast `*` and `+`
230221
:(@. +($(∂_mul_Δs...)))
@@ -258,3 +249,143 @@ This is able to deal with fairly complex expressions for `f`:
258249
propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname)
259250
propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
260251
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)
252+
253+
"""
254+
@non_differentiable(signature_expression)
255+
256+
A helper to make it easier to declare that a method is not not differentiable.
257+
This is a short-hand for defining an [`frule`](@ref) and [`rrule`](@ref) that
258+
return [`DoesNotExist()`](@ref) for all partials (except for the function `s̄elf`-partial
259+
itself which is `NO_FIELDS`)
260+
261+
Keyword arguments should not be included.
262+
263+
```jldoctest
264+
julia> @non_differentiable Base.:(==)(a, b)
265+
266+
julia> _, pullback = rrule(==, 2.0, 3.0);
267+
268+
julia> pullback(1.0)
269+
(Zero(), DoesNotExist(), DoesNotExist())
270+
```
271+
272+
You can place type-constraints in the signature:
273+
```jldoctest
274+
julia> @non_differentiable Base.length(xs::Union{Number, Array})
275+
276+
julia> frule((Zero(), 1), length, [2.0, 3.0])
277+
(2, DoesNotExist())
278+
```
279+
280+
!!! warning
281+
This helper macro covers only the simple common cases.
282+
It does not support Varargs, or `where`-clauses.
283+
For these you can declare the `rrule` and `frule` directly
284+
285+
"""
286+
macro non_differentiable(sig_expr)
287+
Meta.isexpr(sig_expr, :call) || error("Invalid use of `@non_differentiable`")
288+
for arg in sig_expr.args
289+
_isvararg(arg) && error("@non_differentiable does not support Varargs like: $arg")
290+
end
291+
292+
primal_name, orig_args = Iterators.peel(sig_expr.args)
293+
294+
constrained_args = _constrain_and_name.(orig_args, :Any)
295+
primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]
296+
297+
unconstrained_args = _unconstrain.(constrained_args)
298+
primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...)
299+
300+
quote
301+
$(_nondiff_frule_expr(primal_sig_parts, primal_invoke))
302+
$(_nondiff_rrule_expr(primal_sig_parts, primal_invoke))
303+
end
304+
end
305+
306+
function _nondiff_frule_expr(primal_sig_parts, primal_invoke)
307+
return Expr(
308+
:(=),
309+
Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...),
310+
# Julia functions always only have 1 output, so just return a single DoesNotExist()
311+
Expr(:tuple, primal_invoke, DoesNotExist()),
312+
)
313+
end
314+
315+
function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
316+
num_primal_inputs = length(primal_sig_parts) - 1
317+
primal_name = first(primal_invoke.args)
318+
pullback_expr = Expr(
319+
:function,
320+
Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)),
321+
Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...)
322+
)
323+
rrule_defn = Expr(
324+
:(=),
325+
Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...),
326+
Expr(:tuple, primal_invoke, pullback_expr),
327+
)
328+
return rrule_defn
329+
end
330+
331+
332+
###########
333+
# Helpers
334+
335+
"""
336+
_isvararg(expr)
337+
338+
returns true if the expression could represent a vararg
339+
340+
```jldoctest
341+
julia> ChainRulesCore._isvararg(:(x...))
342+
true
343+
344+
julia> ChainRulesCore._isvararg(:(x::Int...))
345+
true
346+
347+
julia> ChainRulesCore._isvararg(:(::Int...))
348+
true
349+
350+
julia> ChainRulesCore._isvararg(:(x::Vararg))
351+
true
352+
353+
julia> ChainRulesCore._isvararg(:(x::Vararg{Int}))
354+
true
355+
356+
julia> ChainRulesCore._isvararg(:(::Vararg))
357+
true
358+
359+
julia> ChainRulesCore._isvararg(:(::Vararg{Int}))
360+
true
361+
362+
julia> ChainRulesCore._isvararg(:(x))
363+
false
364+
````
365+
"""
366+
_isvararg(expr) = false
367+
function _isvararg(expr::Expr)
368+
Meta.isexpr(expr, :...) && return true
369+
if Meta.isexpr(expr, :(::))
370+
constraint = last(expr.args)
371+
constraint == :Vararg && return true
372+
Meta.isexpr(constraint, :curly) && first(constraint.args) == :Vararg && return true
373+
end
374+
return false
375+
end
376+
377+
378+
"turn both `a` and `a::S` into `a`"
379+
_unconstrain(arg::Symbol) = arg
380+
function _unconstrain(arg::Expr)
381+
Meta.isexpr(arg, :(::), 2) && return arg.args[1] # drop constraint.
382+
error("malformed arguments: $arg")
383+
end
384+
385+
"turn both `a` and `::constraint` into `a::constraint` etc"
386+
function _constrain_and_name(arg::Expr, _)
387+
Meta.isexpr(arg, :(::), 2) && return arg # it is already fine.
388+
Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name
389+
error("malformed arguments: $arg")
390+
end
391+
_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type

test/rule_definition_tools.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""
2+
Along same lines as `@test_throws` but to test if a macro throw an exception when it is
3+
expanded.
4+
"""
5+
macro test_macro_throws(err_expr, expr)
6+
quote
7+
err = nothing
8+
try
9+
@macroexpand($(esc(expr)))
10+
catch load_err
11+
# all errors thrown at macro expansion time are LoadErrors, we need to unwrap
12+
@assert load_err isa LoadError
13+
err = load_err.error
14+
end
15+
# Reuse `@test_throws` logic
16+
if err!==nothing
17+
@test_throws $(esc(err_expr)) ($(Meta.quot(expr)); throw(err))
18+
else
19+
@test_throws $(esc(err_expr)) $(Meta.quot(expr))
20+
end
21+
end
22+
end
23+
24+
25+
@testset "rule_definition_tools.jl" begin
26+
@testset "@non_differentiable" begin
27+
@testset "two input one output function" begin
28+
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
29+
@non_differentiable nondiff_2_1(::Any, ::Any)
30+
@test frule((Zero(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, DoesNotExist())
31+
res, pullback = rrule(nondiff_2_1, 3, 2)
32+
@test res == 7.5
33+
@test pullback(4.5) == (NO_FIELDS, DoesNotExist(), DoesNotExist())
34+
end
35+
36+
@testset "one input, 2-tuple output function" begin
37+
nondiff_1_2(x) = (5.0, 3.0)
38+
@non_differentiable nondiff_1_2(::Any)
39+
@test frule((Zero(), 1.2), nondiff_1_2, 3.1) == ((5.0, 3.0), DoesNotExist())
40+
res, pullback = rrule(nondiff_1_2, 3.1)
41+
@test res == (5.0, 3.0)
42+
@test isequal(
43+
pullback(Composite{Tuple{Float64, Float64}}(1.2, 3.2)),
44+
(NO_FIELDS, DoesNotExist()),
45+
)
46+
end
47+
48+
@testset "constrained signature" begin
49+
nonembed_identity(x) = x
50+
@non_differentiable nonembed_identity(::Integer)
51+
52+
@test frule((Zero(), 1.2), nonembed_identity, 2) == (2, DoesNotExist())
53+
@test frule((Zero(), 1.2), nonembed_identity, 2.0) == nothing
54+
55+
res, pullback = rrule(nonembed_identity, 2)
56+
@test res == 2
57+
@test pullback(1.2) == (NO_FIELDS, DoesNotExist())
58+
59+
@test rrule(nonembed_identity, 2.0) == nothing
60+
end
61+
62+
@testset "Pointy UnionAll constraints" begin
63+
pointy_identity(x) = x
64+
@non_differentiable pointy_identity(::Vector{<:AbstractString})
65+
66+
@test frule((Zero(), 1.2), pointy_identity, ["2"]) == (["2"], DoesNotExist())
67+
@test frule((Zero(), 1.2), pointy_identity, 2.0) == nothing
68+
69+
res, pullback = rrule(pointy_identity, ["2"])
70+
@test res == ["2"]
71+
@test pullback(1.2) == (NO_FIELDS, DoesNotExist())
72+
73+
@test rrule(pointy_identity, 2.0) == nothing
74+
end
75+
76+
@testset "Not supported (Yet)" begin
77+
# Varargs are not supported
78+
@test_macro_throws ErrorException @non_differentiable vararg1(xs...)
79+
@test_macro_throws ErrorException @non_differentiable vararg1(xs::Vararg)
80+
81+
# Where clauses are not supported.
82+
@test_macro_throws(
83+
ErrorException,
84+
(@non_differentiable where_identity(::Vector{T}) where T<:AbstractString)
85+
)
86+
end
87+
end
88+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using Test
1515

1616
include("ruleset_loading.jl")
1717
include("rules.jl")
18+
include("rule_definition_tools.jl")
1819

1920

2021
@testset "demos" begin

0 commit comments

Comments
 (0)