Skip to content
Prev Previous commit
Next Next commit
Finish testing and cleaning code on non_differentiable macro
  • Loading branch information
oxinabox committed Aug 28, 2020
commit 0d5b1d8077a6ae5ca7672ada70cffa8b70004d9b
50 changes: 24 additions & 26 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,9 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
@assert Meta.isexpr(call, :call)

# Annotate all arguments in the signature as scalars
inputs = _constrain_and_name.(call.args[2:end], :Number)

inputs = esc.(_constrain_and_name.(call.args[2:end], :Number))
# Remove annotations and escape names for the call
call.args = _unconstrain.(call.args)
call.args[2:end] .= _unconstrain.(call.args[2:end])
call.args = esc.(call.args)

# For consistency in code that follows we make all partials tuple expressions
Expand Down Expand Up @@ -186,7 +185,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)

# Δs is the input to the propagator rule
# because this is a pull-back there is one per output of function
Δs = [Symbol(string(:Δ, i)) for i in 1:n_outputs]
Δs = [Symbol(:Δ, i) for i in 1:n_outputs]

# 1 partial derivative per input
pullback_returns = map(1:n_inputs) do input_i
Expand All @@ -197,7 +196,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
# Multi-output functions have pullbacks with a tuple input that will be destructured
pullback_input = n_outputs == 1 ? first(Δs) : Expr(:tuple, Δs...)
pullback = quote
function $(propagator_name(f, :pullback))($pullback_input)
function $(esc(propagator_name(f, :pullback)))($pullback_input)
return (NO_FIELDS, $(pullback_returns...))
end
end
Expand All @@ -223,16 +222,14 @@ function propagation_expr(Δs, ∂s, _conj = false)
∂s = map(esc, ∂s)
n∂s = length(∂s)

# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression
# literals.
# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals.
∂_mul_Δs = if _conj
ntuple(i->:(conj($(∂s[i])) * $(Δs[i])), n∂s)
else
ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s)
end

# Avoiding the extra `+` operation, it is potentially expensive for vector
# mode AD.
# Avoiding the extra `+` operation, it is potentially expensive for vector mode AD.
sumed_∂_mul_Δs = if n∂s > 1
# we use `@.` to broadcast `*` and `+`
:(@. +($(∂_mul_Δs...)))
Expand Down Expand Up @@ -273,37 +270,38 @@ macro non_differentiable(call_expr)
primal_name, orig_args = Iterators.peel(call_expr.args)

constrained_args = _constrain_and_name.(orig_args, :Any)
primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]

unconstrained_args = _unconstrain.(constrained_args)
primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...)


primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]

quote
$(_nondiff_frule_expr(primal_sig_parts, primal_invoke))
$(_nondiff_rrule_expr(primal_sig_parts, primal_invoke))
end
end

# TODO Move to frule helper
frule_defn = Expr(
function _nondiff_frule_expr(primal_sig_parts, primal_invoke)
return Expr(
:(=),
Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...),
# How many outputs we have it doesn't matter: `DoesNotExist()` is a iterator that
# returns `DoesNotExist()` for every position.
# Julia functions always only have 1 output, so just return a single DoesNotExist()
Expr(:tuple, primal_invoke, DoesNotExist())
)
end

# TODO Move to rrule helper

function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
num_primal_inputs = length(primal_sig_parts) - 1
primal_name = first(primal_invoke.args)
pullback_expr = Expr(
:function,
Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)),
Expr(:tuple, NO_FIELDS, (DoesNotExist() for _ in constrained_args)...)
Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...)
)
rrule_defn = Expr(
:(=),
Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...),
Expr(:tuple, primal_invoke, pullback_expr),
)

quote
$frule_defn
$rrule_defn
end
end

return rrule_defn
end
39 changes: 34 additions & 5 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,40 @@
@testset "rule_definition_tools.jl" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like there's quite a lot of repeated code here. Did you consider writing a function to test that something has been successfully "non_differentiable"d?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repeating code mades it read straight forward, and each is different enough that abstracting the tests would be make them harder to read.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me that that is true. Something like the following ought to do the majority of the work:

function test_nondifferentiable(foo, args, dargs, dy)
    @test frule(dargs, foo, args...) == foo(args...)

    y, pb = rrule(foo, args...)
    @test y == foo(args...)
    @test pb(dy) == (Zero(), map(_ -> DoesNotExist(), args)...)
end

To my mind this is more readable.

I'm not going to object to merging this over this though -- I'm happy to stick with what you've done if you feel strongly that it's more readable.


@testset "@nondifferentiable" begin
@testset "@non_differentiable" begin
@testset "nondiff_2_1" begin
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
@non_differentiable nondiff_2_1(::Any, ::Any)
@test frule((Zero(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, DoesNotExist())
res, pullback = rrule(nondiff_2_1, 3, 2)
@test res == 7.5
@test pullback(4.5) == (NO_FIELDS, DoesNotExist(), DoesNotExist())
end

end
end
@testset "nondiff_1_2" begin
nondiff_1_2(x) = (5.0, 3.0)
@non_differentiable nondiff_1_2(::Any)
@test frule((Zero(), 1.2), nondiff_1_2, 3.1) == ((5.0, 3.0), DoesNotExist())
res, pullback = rrule(nondiff_1_2, 3.1)
@test res == (5.0, 3.0)
@test isequal(
pullback(Composite{Tuple{Float64, Float64}}(1.2, 3.2)),
(NO_FIELDS, DoesNotExist()),
)
end

@testset "specific signature" begin
nonembed_identity(x) = x
@non_differentiable nonembed_identity(::Integer)

@test frule((Zero(), 1.2), nonembed_identity, 2) == (2, DoesNotExist())
@test frule((Zero(), 1.2), nonembed_identity, 2.0) == nothing

Base.remove_linenums!(@macroexpand @non_differentiable println(io::IO))
res, pullback = rrule(nonembed_identity, 2)
@test res == 2
@test pullback(1.2) == (NO_FIELDS, DoesNotExist())

@test rrule(nonembed_identity, 2.0) == nothing
end
end
end

@non_differentiable println(io::IO)