Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support non-standard scalars in test_scalar #61

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

sethaxen
Copy link
Member

@sethaxen sethaxen commented Oct 6, 2020

test_scalar currently is very Real and Complex focused. This PR generalizes test_scalar to work the same for any scalar for which FiniteDifferences.to_vec (and a handful of base functions) are implemented.

We test it with Quaternions.Quaternion. We'd ideally test against a more minimal number, but it turns out one needs to implement quite a few base methods to get a new number to work correctly.

@oxinabox
Copy link
Member

oxinabox commented Oct 7, 2020

will review tomorrow

src/testers.jl Outdated Show resolved Hide resolved
src/testers.jl Outdated Show resolved Hide resolved
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
@testset "$f at $z, with tangent $Δz" for (i, Δz) in enumerate(Δzs)
frule_test(f, (z, Δz); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
if !isa(Δz, Real) && i == 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be:

Suggested change
if !isa(Δz, Real) && i == 1
if !isa(Δz, Real) && length(Δzs) == 1

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, no. i == 1 when when the given tangent vector is purely real, even if it isn't a Real. And this test checks that using an actually Real tangent vector gives the same result.

Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think i would like to look at this again after the comments are addressed.
I am not sure i properly undestand what is going on for the
if !isa(Δz, Real) && i == 1
branches,
and I think I would be better able to, if we have split this into two methods, one for real and one not for real.

src/testers.jl Outdated
Comment on lines 143 to 146
vΩ, Ω_from_vec = to_vec(Ω)
# orthonormal cotangent vectors
vΩ_basis = Diagonal(ones(eltype(vΩ), length(vΩ)))
ΔΩs = [Ω_from_vec(vΩ_basis[:, i]) for i in axes(vΩ_basis, 2)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should move this out into a helper function basis_vectors

@@ -112,61 +112,55 @@ at input point `z` to confirm that there are correct `frule` and `rrule`s provid

`fkwargs` are passed to `f` as keyword arguments.
All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`.

To use this tester for a scalar type `MyNumber <: AbstractNumber`,
`FiniteDifferences.to_vec(::MyNumber)` must be implemented.
"""
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simplify this code by defining a seperate method for:

Suggested change
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
function test_scalar(f, z::Real; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would simplify the frule test because we wouldn't need the basis, but if the output is non-real we still need the basis on the output for the rrule test. Adding a separate method would require us to maintain that code in two places.

Comment on lines +270 to +275
function FiniteDifferences.to_vec(q::Quaternion)
function Quaternion_from_vec(q_vec)
return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4])
end
return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should move this to be defined in the package itself.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean define this in Quaternions or FiniteDifferences?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or ChainRulesTestUtils?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FiniteDifferences

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if it makes sense to make Quaternions an optional dependency for FiniteDifferences. Since I am only defining this for the purpose of testing, I'm comfortable with being type-piratical but just in the test suite where it can't pollute the methods table for other users. Thoughts?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @oxinabox what do you think?

return quatfun(q), quatfun_pullback
end

q = quatrand()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we define rand_tangent(:: Quaternion) in this package also?
@willtebbutt do you have plans around further advancing rand_tangent ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not currently -- not sure that there's much to do beyond integrating it in with ChainRulesTestUtils in some way or another and continuing to add new methods where necessary.


Get a set of basis (co)tangent vectors for `x`.

This function assumes that the (co)tangent vectors are of the same type as `x` and requires
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be nice to give basis vectors of the same type as the result of rand_tangent instead, but I'm not certain how to do that.

@oxinabox
Copy link
Member

Is this ready for rereview?

@sethaxen
Copy link
Member Author

Is this ready for rereview?

Not yet, I'll try to finish it up this week.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants