-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support non-standard scalars in test_scalar #61
base: main
Are you sure you want to change the base?
Changes from all commits
e87aba8
17c6bd1
2f4ebbf
0b84838
c6e67da
62c4d1a
21ed347
1d275ac
828278c
630feac
d213840
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -100,6 +100,21 @@ function _make_jvp_call(fdm, f, xs, ẋs, ignores) | |||||
return jvp(fdm, f2, sigargs...) | ||||||
end | ||||||
|
||||||
""" | ||||||
_basis_vectors(x::T) -> Vector{T} | ||||||
|
||||||
Get a set of basis (co)tangent vectors for `x`. | ||||||
|
||||||
This function assumes that the (co)tangent vectors are of the same type as `x` and requires | ||||||
that `FiniteDifferences.to_vec` be implemented for inputs of the same type as `x`. | ||||||
""" | ||||||
function _basis_vectors(x) | ||||||
v, from_vec = FiniteDifferences.to_vec(x) | ||||||
basis_coords = Diagonal(ones(eltype(v), length(v))) | ||||||
basis_vecs = [from_vec(@view basis_coords[:, i]) for i in axes(basis_coords, 2)] | ||||||
return basis_vecs | ||||||
end | ||||||
|
||||||
""" | ||||||
test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...) | ||||||
|
||||||
|
@@ -112,61 +127,48 @@ at input point `z` to confirm that there are correct `frule` and `rrule`s provid | |||||
|
||||||
`fkwargs` are passed to `f` as keyword arguments. | ||||||
All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`. | ||||||
|
||||||
To use this tester for a scalar type `MyNumber <: Number`, | ||||||
`FiniteDifferences.to_vec(::MyNumber)` must be implemented. | ||||||
""" | ||||||
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we simplify this code by defining a seperate method for:
Suggested change
? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would simplify the |
||||||
_ensure_not_running_on_functor(f, "test_scalar") | ||||||
# z = x + im * y | ||||||
# Ω = u(x, y) + im * v(x, y) | ||||||
Ω = f(z; fkwargs...) | ||||||
|
||||||
Δzs = _basis_vectors(z) | ||||||
Δx = first(Δzs) | ||||||
ΔΩs = _basis_vectors(Ω) | ||||||
|
||||||
# test jacobian using forward mode | ||||||
Δx = one(z) | ||||||
@testset "$f at $z, with tangent $Δx" begin | ||||||
# check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode | ||||||
frule_test(f, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) | ||||||
if z isa Complex | ||||||
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im | ||||||
@testset "$f at $z, with tangent $Δz" for (i, Δz) in enumerate(Δzs) | ||||||
frule_test(f, (z, Δz); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) | ||||||
if !isa(Δz, Real) && i == 1 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this be:
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case, no. |
||||||
# check that same tangent is produced for tangent real(one(z)) and one(z) | ||||||
@test isapprox( | ||||||
frule((Zero(), real(Δx)), f, z; fkwargs...)[2], | ||||||
frule((Zero(), Δx), f, z; fkwargs...)[2], | ||||||
frule((Zero(), real(Δz)), f, z; fkwargs...)[2], | ||||||
frule((Zero(), Δz), f, z; fkwargs...)[2], | ||||||
rtol=rtol, | ||||||
atol=atol, | ||||||
kwargs..., | ||||||
) | ||||||
end | ||||||
end | ||||||
if z isa Complex | ||||||
Δy = one(z) * im | ||||||
@testset "$f at $z, with tangent $Δy" begin | ||||||
# check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode | ||||||
frule_test(f, (z, Δy); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) | ||||||
end | ||||||
end | ||||||
|
||||||
# test jacobian transpose using reverse mode | ||||||
Δu = one(Ω) | ||||||
@testset "$f at $z, with cotangent $Δu" begin | ||||||
# check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode | ||||||
rrule_test(f, Δu, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) | ||||||
if Ω isa Complex | ||||||
# check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im | ||||||
@testset "$f at $z, with cotangent $ΔΩ" for (i, ΔΩ) in enumerate(ΔΩs) | ||||||
rrule_test(f, ΔΩ, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) | ||||||
if !isa(ΔΩ, Real) && i == 1 | ||||||
# check that same cotangent is produced for cotangent real(one(Ω)) and one(Ω) | ||||||
back = rrule(f, z)[2] | ||||||
@test isapprox( | ||||||
extern(back(real(Δu))[2]), | ||||||
extern(back(Δu)[2]), | ||||||
extern(back(real(ΔΩ))[2]), | ||||||
extern(back(ΔΩ)[2]), | ||||||
rtol=rtol, | ||||||
atol=atol, | ||||||
kwargs..., | ||||||
) | ||||||
end | ||||||
end | ||||||
if Ω isa Complex | ||||||
Δv = one(Ω) * im | ||||||
@testset "$f at $z, with cotangent $Δv" begin | ||||||
# check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode | ||||||
rrule_test(f, Δv, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) | ||||||
end | ||||||
end | ||||||
end | ||||||
|
||||||
""" | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,8 @@ sinconj(x) = sin(x) | |
primalapprox(x) = x | ||
|
||
|
||
quatfun(q::Quaternion) = Quaternion(q.v3, 2 * q.v1, 3 * q.s, 4 * q.v2) | ||
|
||
@testset "testers.jl" begin | ||
@testset "test_scalar" begin | ||
@testset "Ensure correct rules succeed" begin | ||
|
@@ -367,4 +369,29 @@ primalapprox(x) = x | |
@test fails(()->rrule_test(my_identity2, 4.1, (2.2, 3.3))) | ||
end | ||
end | ||
|
||
@testset "test quaternion non-standard scalar" begin | ||
function FiniteDifferences.to_vec(q::Quaternion) | ||
function Quaternion_from_vec(q_vec) | ||
return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4]) | ||
end | ||
return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec | ||
end | ||
Comment on lines
+374
to
+379
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should move this to be defined in the package itself. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean define this in Quaternions or FiniteDifferences? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or ChainRulesTestUtils? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FiniteDifferences There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know if it makes sense to make Quaternions an optional dependency for FiniteDifferences. Since I am only defining this for the purpose of testing, I'm comfortable with being type-piratical but just in the test suite where it can't pollute the methods table for other users. Thoughts? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey @oxinabox what do you think? |
||
|
||
function ChainRulesCore.frule((_, Δq), ::typeof(quatfun), q) | ||
∂q = Quaternion(Δq) | ||
return quatfun(q), Quaternion(∂q.v3, 2 * ∂q.v1, 3 * ∂q.s, 4 * ∂q.v2) | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(quatfun), q) | ||
function quatfun_pullback(ΔΩ) | ||
∂Ω = Quaternion(ΔΩ) | ||
return (NO_FIELDS, Quaternion(3 * ∂Ω.v2, 2 * ∂Ω.v1, 4 * ∂Ω.v3, ∂Ω.s)) | ||
end | ||
return quatfun(q), quatfun_pullback | ||
end | ||
|
||
q = quatrand() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we define There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not currently -- not sure that there's much to do beyond integrating it in with |
||
test_scalar(quatfun, q) | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It'd be nice to give basis vectors of the same type as the result of
rand_tangent
instead, but I'm not certain how to do that.