-
Notifications
You must be signed in to change notification settings - Fork 64
nondifferentiable macro #207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7f35d07
fc42649
0d5b1d8
15577d3
88459bd
0835187
1252509
4188659
bb2d463
00b959d
9ab6955
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
""" | ||
Along same lines as `@test_throws` but to test if a macro throw an exception when it is | ||
expanded. | ||
""" | ||
macro test_macro_throws(err_expr, expr) | ||
quote | ||
err = nothing | ||
try | ||
@macroexpand($(esc(expr))) | ||
catch load_err | ||
# all errors thrown at macro expansion time are LoadErrors, we need to unwrap | ||
@assert load_err isa LoadError | ||
err = load_err.error | ||
end | ||
# Reuse `@test_throws` logic | ||
if err!==nothing | ||
@test_throws $(esc(err_expr)) ($(Meta.quot(expr)); throw(err)) | ||
else | ||
@test_throws $(esc(err_expr)) $(Meta.quot(expr)) | ||
end | ||
end | ||
end | ||
|
||
|
||
@testset "rule_definition_tools.jl" begin | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like there's quite a lot of repeated code here. Did you consider writing a function to test that something has been successfully "non_differentiable"d? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Repeating code mades it read straight forward, and each is different enough that abstracting the tests would be make them harder to read. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not clear to me that that is true. Something like the following ought to do the majority of the work: function test_nondifferentiable(foo, args, dargs, dy)
@test frule(dargs, foo, args...) == foo(args...)
y, pb = rrule(foo, args...)
@test y == foo(args...)
@test pb(dy) == (Zero(), map(_ -> DoesNotExist(), args)...)
end To my mind this is more readable. I'm not going to object to merging this over this though -- I'm happy to stick with what you've done if you feel strongly that it's more readable. |
||
@testset "@non_differentiable" begin | ||
@testset "two input one output function" begin | ||
nondiff_2_1(x, y) = fill(7.5, 100)[x + y] | ||
@non_differentiable nondiff_2_1(::Any, ::Any) | ||
@test frule((Zero(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, DoesNotExist()) | ||
res, pullback = rrule(nondiff_2_1, 3, 2) | ||
@test res == 7.5 | ||
@test pullback(4.5) == (NO_FIELDS, DoesNotExist(), DoesNotExist()) | ||
end | ||
|
||
@testset "one input, 2-tuple output function" begin | ||
nondiff_1_2(x) = (5.0, 3.0) | ||
@non_differentiable nondiff_1_2(::Any) | ||
@test frule((Zero(), 1.2), nondiff_1_2, 3.1) == ((5.0, 3.0), DoesNotExist()) | ||
res, pullback = rrule(nondiff_1_2, 3.1) | ||
@test res == (5.0, 3.0) | ||
@test isequal( | ||
pullback(Composite{Tuple{Float64, Float64}}(1.2, 3.2)), | ||
(NO_FIELDS, DoesNotExist()), | ||
) | ||
end | ||
|
||
@testset "constrained signature" begin | ||
nonembed_identity(x) = x | ||
@non_differentiable nonembed_identity(::Integer) | ||
|
||
@test frule((Zero(), 1.2), nonembed_identity, 2) == (2, DoesNotExist()) | ||
@test frule((Zero(), 1.2), nonembed_identity, 2.0) == nothing | ||
|
||
res, pullback = rrule(nonembed_identity, 2) | ||
@test res == 2 | ||
@test pullback(1.2) == (NO_FIELDS, DoesNotExist()) | ||
|
||
@test rrule(nonembed_identity, 2.0) == nothing | ||
end | ||
|
||
@testset "Pointy UnionAll constraints" begin | ||
pointy_identity(x) = x | ||
@non_differentiable pointy_identity(::Vector{<:AbstractString}) | ||
|
||
@test frule((Zero(), 1.2), pointy_identity, ["2"]) == (["2"], DoesNotExist()) | ||
@test frule((Zero(), 1.2), pointy_identity, 2.0) == nothing | ||
|
||
res, pullback = rrule(pointy_identity, ["2"]) | ||
@test res == ["2"] | ||
@test pullback(1.2) == (NO_FIELDS, DoesNotExist()) | ||
|
||
@test rrule(pointy_identity, 2.0) == nothing | ||
end | ||
|
||
@testset "Not supported (Yet)" begin | ||
# Varargs are not supported | ||
@test_macro_throws ErrorException @non_differentiable vararg1(xs...) | ||
@test_macro_throws ErrorException @non_differentiable vararg1(xs::Vararg) | ||
|
||
# Where clauses are not supported. | ||
@test_macro_throws( | ||
ErrorException, | ||
(@non_differentiable where_identity(::Vector{T}) where T<:AbstractString) | ||
) | ||
end | ||
end | ||
end |
Uh oh!
There was an error while loading. Please reload this page.