-
Notifications
You must be signed in to change notification settings - Fork 65
nondifferentiable macro #207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
7f35d07
fc42649
0d5b1d8
15577d3
88459bd
0835187
1252509
4188659
bb2d463
00b959d
9ab6955
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,40 @@ | ||
@testset "rule_definition_tools.jl" begin | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like there's quite a lot of repeated code here. Did you consider writing a function to test that something has been successfully "non_differentiable"d? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Repeating code mades it read straight forward, and each is different enough that abstracting the tests would be make them harder to read. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not clear to me that that is true. Something like the following ought to do the majority of the work: function test_nondifferentiable(foo, args, dargs, dy)
@test frule(dargs, foo, args...) == foo(args...)
y, pb = rrule(foo, args...)
@test y == foo(args...)
@test pb(dy) == (Zero(), map(_ -> DoesNotExist(), args)...)
end To my mind this is more readable. I'm not going to object to merging this over this though -- I'm happy to stick with what you've done if you feel strongly that it's more readable. |
||
|
||
@testset "@nondifferentiable" begin | ||
@testset "@non_differentiable" begin | ||
@testset "nondiff_2_1" begin | ||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
nondiff_2_1(x, y) = fill(7.5, 100)[x + y] | ||
@non_differentiable nondiff_2_1(::Any, ::Any) | ||
@test frule((Zero(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, DoesNotExist()) | ||
res, pullback = rrule(nondiff_2_1, 3, 2) | ||
@test res == 7.5 | ||
@test pullback(4.5) == (NO_FIELDS, DoesNotExist(), DoesNotExist()) | ||
end | ||
|
||
end | ||
end | ||
@testset "nondiff_1_2" begin | ||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
nondiff_1_2(x) = (5.0, 3.0) | ||
@non_differentiable nondiff_1_2(::Any) | ||
@test frule((Zero(), 1.2), nondiff_1_2, 3.1) == ((5.0, 3.0), DoesNotExist()) | ||
res, pullback = rrule(nondiff_1_2, 3.1) | ||
@test res == (5.0, 3.0) | ||
@test isequal( | ||
pullback(Composite{Tuple{Float64, Float64}}(1.2, 3.2)), | ||
(NO_FIELDS, DoesNotExist()), | ||
) | ||
end | ||
|
||
@testset "specific signature" begin | ||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
nonembed_identity(x) = x | ||
@non_differentiable nonembed_identity(::Integer) | ||
|
||
@test frule((Zero(), 1.2), nonembed_identity, 2) == (2, DoesNotExist()) | ||
@test frule((Zero(), 1.2), nonembed_identity, 2.0) == nothing | ||
|
||
Base.remove_linenums!(@macroexpand @non_differentiable println(io::IO)) | ||
res, pullback = rrule(nonembed_identity, 2) | ||
@test res == 2 | ||
@test pullback(1.2) == (NO_FIELDS, DoesNotExist()) | ||
|
||
@test rrule(nonembed_identity, 2.0) == nothing | ||
end | ||
end | ||
end | ||
|
||
@non_differentiable println(io::IO) |
Uh oh!
There was an error while loading. Please reload this page.