Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support non-standard scalars in test_scalar #61

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.5.3"
version = "0.5.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -14,4 +14,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ChainRulesCore = "0.9.1"
Compat = "3"
FiniteDifferences = "0.11.2"
Quaternions = "0.4"
julia = "1"

[extras]
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"

[targets]
test = ["Quaternions"]
66 changes: 34 additions & 32 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,21 @@ function _make_jvp_call(fdm, f, xs, ẋs, ignores)
return jvp(fdm, f2, sigargs...)
end

"""
_basis_vectors(x::T) -> Vector{T}

Get a set of basis (co)tangent vectors for `x`.

This function assumes that the (co)tangent vectors are of the same type as `x` and requires
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be nice to give basis vectors of the same type as the result of rand_tangent instead, but I'm not certain how to do that.

that `FiniteDifferences.to_vec` be implemented for inputs of the same type as `x`.
"""
function _basis_vectors(x)
v, from_vec = FiniteDifferences.to_vec(x)
basis_coords = Diagonal(ones(eltype(v), length(v)))
basis_vecs = [from_vec(@view basis_coords[:, i]) for i in axes(basis_coords, 2)]
return basis_vecs
end

"""
test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...)

Expand All @@ -112,61 +127,48 @@ at input point `z` to confirm that there are correct `frule` and `rrule`s provid

`fkwargs` are passed to `f` as keyword arguments.
All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`.

To use this tester for a scalar type `MyNumber <: Number`,
`FiniteDifferences.to_vec(::MyNumber)` must be implemented.
"""
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simplify this code by defining a seperate method for:

Suggested change
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
function test_scalar(f, z::Real; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would simplify the frule test because we wouldn't need the basis, but if the output is non-real we still need the basis on the output for the rrule test. Adding a separate method would require us to maintain that code in two places.

_ensure_not_running_on_functor(f, "test_scalar")
# z = x + im * y
# Ω = u(x, y) + im * v(x, y)
Ω = f(z; fkwargs...)

Δzs = _basis_vectors(z)
Δx = first(Δzs)
ΔΩs = _basis_vectors(Ω)

# test jacobian using forward mode
Δx = one(z)
@testset "$f at $z, with tangent $Δx" begin
# check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode
frule_test(f, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
if z isa Complex
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
@testset "$f at $z, with tangent $Δz" for (i, Δz) in enumerate(Δzs)
frule_test(f, (z, Δz); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
if !isa(Δz, Real) && i == 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be:

Suggested change
if !isa(Δz, Real) && i == 1
if !isa(Δz, Real) && length(Δzs) == 1

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, no. i == 1 when when the given tangent vector is purely real, even if it isn't a Real. And this test checks that using an actually Real tangent vector gives the same result.

# check that same tangent is produced for tangent real(one(z)) and one(z)
@test isapprox(
frule((Zero(), real(Δx)), f, z; fkwargs...)[2],
frule((Zero(), Δx), f, z; fkwargs...)[2],
frule((Zero(), real(Δz)), f, z; fkwargs...)[2],
frule((Zero(), Δz), f, z; fkwargs...)[2],
rtol=rtol,
atol=atol,
kwargs...,
)
end
end
if z isa Complex
Δy = one(z) * im
@testset "$f at $z, with tangent $Δy" begin
# check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode
frule_test(f, (z, Δy); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
end
end

# test jacobian transpose using reverse mode
Δu = one(Ω)
@testset "$f at $z, with cotangent $Δu" begin
# check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode
rrule_test(f, Δu, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
if Ω isa Complex
# check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im
@testset "$f at $z, with cotangent $ΔΩ" for (i, ΔΩ) in enumerate(ΔΩs)
rrule_test(f, ΔΩ, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
if !isa(ΔΩ, Real) && i == 1
# check that same cotangent is produced for cotangent real(one(Ω)) and one(Ω)
back = rrule(f, z)[2]
@test isapprox(
extern(back(real(Δu))[2]),
extern(back(Δu)[2]),
extern(back(real(ΔΩ))[2]),
extern(back(ΔΩ)[2]),
rtol=rtol,
atol=atol,
kwargs...,
)
end
end
if Ω isa Complex
Δv = one(Ω) * im
@testset "$f at $z, with cotangent $Δv" begin
# check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
rrule_test(f, Δv, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
end
end
end

"""
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using ChainRulesCore
using ChainRulesTestUtils
using FiniteDifferences
using LinearAlgebra
using Quaternions
using Random
using Test

Expand Down
27 changes: 27 additions & 0 deletions test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ sinconj(x) = sin(x)
primalapprox(x) = x


quatfun(q::Quaternion) = Quaternion(q.v3, 2 * q.v1, 3 * q.s, 4 * q.v2)

@testset "testers.jl" begin
@testset "test_scalar" begin
@testset "Ensure correct rules succeed" begin
Expand Down Expand Up @@ -367,4 +369,29 @@ primalapprox(x) = x
@test fails(()->rrule_test(my_identity2, 4.1, (2.2, 3.3)))
end
end

@testset "test quaternion non-standard scalar" begin
function FiniteDifferences.to_vec(q::Quaternion)
function Quaternion_from_vec(q_vec)
return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4])
end
return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec
end
Comment on lines +374 to +379
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should move this to be defined in the package itself.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean define this in Quaternions or FiniteDifferences?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or ChainRulesTestUtils?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FiniteDifferences

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if it makes sense to make Quaternions an optional dependency for FiniteDifferences. Since I am only defining this for the purpose of testing, I'm comfortable with being type-piratical but just in the test suite where it can't pollute the methods table for other users. Thoughts?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @oxinabox what do you think?


function ChainRulesCore.frule((_, Δq), ::typeof(quatfun), q)
∂q = Quaternion(Δq)
return quatfun(q), Quaternion(∂q.v3, 2 * ∂q.v1, 3 * ∂q.s, 4 * ∂q.v2)
end

function ChainRulesCore.rrule(::typeof(quatfun), q)
function quatfun_pullback(ΔΩ)
∂Ω = Quaternion(ΔΩ)
return (NO_FIELDS, Quaternion(3 * ∂Ω.v2, 2 * ∂Ω.v1, 4 * ∂Ω.v3, ∂Ω.s))
end
return quatfun(q), quatfun_pullback
end

q = quatrand()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we define rand_tangent(:: Quaternion) in this package also?
@willtebbutt do you have plans around further advancing rand_tangent ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not currently -- not sure that there's much to do beyond integrating it in with ChainRulesTestUtils in some way or another and continuing to add new methods where necessary.

test_scalar(quatfun, q)
end
end