
Tests are slow: use vjvp ? #204

Open
mzgubic opened this issue Aug 6, 2021 · 5 comments

Comments

@mzgubic
Member

mzgubic commented Aug 6, 2021

Compare Zygote's gradtest:

```julia
julia> @btime gradtest(x -> sum(abs2, x), randn(4, 3, 2))
  17.977 μs (236 allocations: 5.86 KiB)
```

and ChainRulesTestUtils' (CRTU's) test_rrule:

```julia
julia> @btime test_rrule(Zygote.ZygoteRuleConfig(), x -> sum(abs2, x), randn(4, 3, 2); rrule_f=rrule_via_ad)
  1.490 ms (6884 allocations: 480.10 KiB)
```

That's over 80x slower (1.490 ms vs 17.977 μs). Do we understand why? I didn't have time to look into it, so I'm just opening an issue.

ChainRules tests take a pretty long time to run, so this might be worth improving.

@oxinabox
Member

oxinabox commented Aug 9, 2021

A big part of it is that the FiniteDifferences.jl configuration we use is slower, but more accurate, than what Zygote's gradtest does.
gradtest is equivalent to central_fdm(3, 1; adapt=0), whereas we use central_fdm(5, 1; adapt=1),
so we evaluate 2 additional points near the input, and we spend 1 step of adaptation determining the optimal step size.
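To make the cost difference concrete, here is a hand-rolled sketch that counts how many calls to the function a full finite-difference gradient needs with a fixed-step 3-point versus 5-point central stencil. It is not FiniteDifferences.jl's implementation and it leaves out the extra evaluations that adapt=1 spends choosing a step size; the names NEVALS, counted, fd3, fd5 and fd_grad are hypothetical.

```julia
# Hypothetical helpers, not part of FiniteDifferences.jl: fixed-step central
# stencils plus a call counter, to compare what a full gradient costs.
const NEVALS = Ref(0)

# Wrap f so that every call increments NEVALS.
counted(f) = x -> (NEVALS[] += 1; f(x))

# 3-point central stencil for g'(t); the centre weight is zero, so 2 calls.
fd3(g, t, h) = (g(t + h) - g(t - h)) / (2h)

# 5-point central stencil for g'(t); the centre weight is zero, so 4 calls.
fd5(g, t, h) = (g(t - 2h) - 8 * g(t - h) + 8 * g(t + h) - g(t + 2h)) / (12h)

# Full gradient of a scalar-valued f, perturbing one coordinate at a time.
function fd_grad(stencil, f, x::AbstractArray, h)
    grad = similar(x)
    for i in eachindex(x)
        g = t -> f(setindex!(copy(x), t, i))   # f as a function of x[i] only
        grad[i] = stencil(g, x[i], h)
    end
    return grad
end

x = randn(4, 3, 2)                 # 24 inputs, as in the benchmark above
sq = counted(y -> sum(abs2, y))

NEVALS[] = 0; g3 = fd_grad(fd3, sq, x, 1e-6); n3 = NEVALS[]
NEVALS[] = 0; g5 = fd_grad(fd5, sq, x, 1e-3); n5 = NEVALS[]

println((n3, n5))                  # (48, 96): 2 vs 4 calls per input
```

Even before adaptation, the 5-point stencil doubles the number of calls per input coordinate.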

We could try tinkering with that. Getting it faster that way might be possible without breaking too many tests, though we might also need to relax some of the atol/rtol tolerances.

@willtebbutt
Member

Another thing to consider is that we're currently computing entire vjps (vector-Jacobian products), whereas vjvps (vector-Jacobian-vector products) would be sufficient.

The thing that we really need to test for reverse-mode is that

$$\langle J^\top \bar{y}, \dot{x} \rangle = \langle \bar{y}, J \dot{x} \rangle$$

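rrules are good at computing $\langle J^\top \bar{y}, \dot{x} \rangle$ (one pullback gives $J^\top \bar{y}$, and the inner product with $\dot{x}$ is cheap), while finite differencing is good at computing $\langle \bar{y}, J \dot{x} \rangle$. We're currently using finite differencing to approximate $\langle J^\top \bar{y}, \dot{x} \rangle$, in effect the whole of $J^\top \bar{y}$, which means we make a factor of O(length(ẋ)) more calls to FiniteDifferences than we ought to.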
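For reference, here is the standard argument for why the identity holds and why a single directional derivative suffices on the finite-difference side, written with $f : \mathbb{R}^n \to \mathbb{R}^m$, $J$ its Jacobian at $x$, and a small step $h$ (the 3-point stencil is used here only for brevity):

$$
\langle J^\top \bar{y}, \dot{x} \rangle = (J^\top \bar{y})^\top \dot{x} = \bar{y}^\top J \dot{x} = \langle \bar{y}, J \dot{x} \rangle,
\qquad
J \dot{x} \approx \frac{f(x + h\dot{x}) - f(x - h\dot{x})}{2h}.
$$

The right-hand side costs a fixed number of evaluations of $f$ regardless of $n$, whereas approximating $J^\top \bar{y}$ itself by finite differences needs one stencil per input coordinate, i.e. O($n$) evaluations.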

The only extra thing we should need to make this work is the ability to compute inner products between the outputs of rrules, which probably isn't much harder than comparing the outputs of FiniteDifferences and the rrule, as we do now.
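A minimal sketch of such a vjvp check, assuming plain arrays so that LinearAlgebra.dot is the inner product. The hand-written pullback stands in for an rrule, and xsinx, pullback_xsinx, jvp_fd and vjvp_check are hypothetical names, not ChainRulesTestUtils API; the 5-point stencil mirrors central_fdm(5, 1) but without its adaptation.

```julia
using LinearAlgebra, Random

# Function under test and a hand-written pullback standing in for its rrule:
# d/dx (x * sin(x)) = sin(x) + x * cos(x), applied elementwise.
xsinx(x) = x .* sin.(x)
pullback_xsinx(x) = ȳ -> ȳ .* (sin.(x) .+ x .* cos.(x))

# J*ẋ via a 5-point central stencil along the direction ẋ:
# 4 evaluations of f, independent of length(x).
function jvp_fd(f, x, ẋ; h=1e-3)
    return (f(x - 2h * ẋ) - 8 * f(x - h * ẋ) + 8 * f(x + h * ẋ) - f(x + 2h * ẋ)) / (12h)
end

# Compare ⟨J'ȳ, ẋ⟩ from the pullback with ⟨ȳ, Jẋ⟩ from finite differences.
function vjvp_check(f, pullback, x; rtol=1e-6, rng=Random.default_rng())
    ȳ = randn(rng, size(f(x))...)    # random output cotangent
    ẋ = randn(rng, size(x)...)       # random input tangent
    lhs = dot(pullback(x)(ȳ), ẋ)
    rhs = dot(ȳ, jvp_fd(f, x, ẋ))
    return isapprox(lhs, rhs; rtol=rtol)
end

rng = MersenneTwister(1)
x = randn(rng, 10_000)               # size barely matters: f is called 5 times
good = vjvp_check(xsinx, pullback_xsinx, x; rng=rng)

# A deliberately wrong pullback (drops the x .* cos.(x) term) should fail.
bad_pullback(x) = ȳ -> ȳ .* sin.(x)
bad = vjvp_check(xsinx, bad_pullback, x; rng=rng)

println((good, bad))
```

Because the check collapses everything into one scalar, it only needs a handful of evaluations of f, but a failure reports a single mismatch rather than the offending entries.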

I've been doing this for a while in TemporalGPs.jl and it seems to work really well. Prior to ADIA, I wrote loads of code to hack around ChainRulesTestUtils not being up to what I needed. I didn't want to contribute it back at the time because I wasn't completely sure this was the right way to go about things, but I'm now convinced that it is.

@oxinabox
Member

Yeah, and dot (i.e. the inner product) is something all tangent types should overload.

A problem, maybe, is that if it fails, it won't tell you where it failed, will it?
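As an illustration of what overloading dot could look like for a structured tangent, here is a sketch with a hypothetical AffineTangent type for a (weights, bias) pair; it is not a ChainRulesCore type, and summing the fieldwise inner products is an assumption that matches the Euclidean inner product on the flattened components.

```julia
using LinearAlgebra

# Hypothetical structured tangent for a (weights, bias) pair.
struct AffineTangent
    W::Matrix{Float64}
    b::Vector{Float64}
end

# Inner product: sum of the fieldwise inner products.
LinearAlgebra.dot(a::AffineTangent, c::AffineTangent) = dot(a.W, c.W) + dot(a.b, c.b)

t1 = AffineTangent([1.0 2.0; 3.0 4.0], [1.0, -1.0])
t2 = AffineTangent([0.5 0.0; 0.0 0.5], [2.0, 2.0])

println(dot(t1, t2))   # 2.5 (0.5 + 2.0 from W, 2.0 - 2.0 from b)
```

The sum over fields is exactly why a failing check only says that something is wrong, not which field.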

@willtebbutt
Member

> Yeah, and dot (i.e. the inner product) is something all tangent types should overload.

Indeed.

> A problem, maybe, is that if it fails, it won't tell you where it failed, will it?

Yeah, this is a problem. It does make life a bit trickier when it comes to debugging. I've typically found that you want to retain the ability to (slowly) compute the full vjp for debugging purposes. Fortunately, you don't really need things like the ability to check for equality lying around to do this; you'll always be doing it by eye.
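A sketch of that slow debugging path, assuming plain vectors: approximate the whole of J'ȳ by finite differences (the gradient of x -> dot(ȳ, f(x)), so O(length(x)) evaluations) and compare it entry by entry with the pullback. fd_vjp, xsinx and buggy_pullback are hypothetical names.

```julia
using LinearAlgebra

# Hypothetical debugging helper: J'ȳ by a 5-point stencil, one input
# coordinate at a time, via the scalar function x -> dot(ȳ, f(x)).
function fd_vjp(f, x, ȳ; h=1e-3)
    g = t -> dot(ȳ, f(t))
    out = similar(x, Float64)
    for i in eachindex(x)
        δ = zeros(size(x)); δ[i] = h      # perturb only coordinate i
        out[i] = (g(x - 2 * δ) - 8 * g(x - δ) + 8 * g(x + δ) - g(x + 2 * δ)) / (12h)
    end
    return out
end

xsinx(x) = x .* sin.(x)
buggy_pullback(x) = ȳ -> ȳ .* sin.(x)     # drops the x .* cos.(x) term

x = [0.0, 1.0, 2.0]
ȳ = [1.0, 1.0, 1.0]
ad = buggy_pullback(x)(ȳ)
fd = fd_vjp(xsinx, x, ȳ)

println(findall(abs.(ad .- fd) .> 1e-6))  # [2, 3]: the missing term is zero at x = 0
```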

One of the real benefits of doing things this way is that you can test AD at scale. For example, whereas the current approach really requires small problem sizes, the inner-product approach can handle any problem size for which you're happy to make a small handful of function evaluations. The advantage is less that it's better to test on big problems, and more that it's convenient to be able to test any old problem you have lying around, regardless of its size.
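To see the scaling claim in numbers, this sketch counts calls to a hypothetical counted_sin for one vjvp-style directional derivative versus a full finite-difference vjp, using a 5-point stencil in both cases; CALLS, stencil5 and call_counts are illustrative names only.

```julia
using LinearAlgebra

# Hypothetical call counter wrapped around an elementwise function.
const CALLS = Ref(0)
counted_sin(x) = (CALLS[] += 1; sin.(x))

# 5-point central stencil along a direction d: 4 calls to f.
stencil5(f, x, d, h) = (f(x - 2h * d) - 8 * f(x - h * d) + 8 * f(x + h * d) - f(x + 2h * d)) / (12h)

function call_counts(n)
    x, ȳ, ẋ = randn(n), randn(n), randn(n)

    CALLS[] = 0
    dot(ȳ, stencil5(counted_sin, x, ẋ, 1e-3))      # vjvp: one directional derivative
    vjvp_calls = CALLS[]

    CALLS[] = 0
    for i in 1:n                                    # full vjp: one stencil per input
        e = zeros(n); e[i] = 1.0
        dot(ȳ, stencil5(counted_sin, x, e, 1e-3))
    end
    return vjvp_calls, CALLS[]
end

for n in (10, 100, 1_000)
    println((n, call_counts(n)...))                 # (n, 4, 4n)
end
```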

@willtebbutt
Member

Had a quick stab at doing this in #208 (I won't be pushing it further myself in the immediate future; I just wanted to see what it might look like).

On the surface of it, it doesn't look like we have any substantial practical impediments to doing this, but I've not dug into the details.

@oxinabox oxinabox changed the title "Tests are slow" to "Tests are slow: use vjvp ?" on Aug 22, 2021