Support automatic differentiation of the NN inside the loss function? #150

Closed
rozsasarpi opened this issue Sep 26, 2020 · 1 comment

Comments

@rozsasarpi

I saw that you plan to support automatic differentiation of the NN inside the loss function. Do you have a plan/roadmap for this? I'm interested in looking into how this could be done, so if you have any code snippets and/or notes on this, please share them.

@ChrisRackauckas
Member

The problem is that the most efficient way to do this is to use forward mode inside of the loss function and then reverse over that. Using ForwardDiff in here kind of works, but it requires a few workarounds and is hard to generalize. Zygote does have a forward mode (FluxML/Zygote.jl#503), but it needs a few more things to be fully compatible with the standard NNs people want to use for physics-informed neural networks: FluxML/Zygote.jl#654. We are well aware of this: @Keno has been working on a major improvement to the AD system which will make this a lot better, and @DhairyaLGandhi is aware of this use case.
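For anyone who wants to experiment in the meantime, here is a minimal sketch of that pattern on a toy problem (not the package's actual implementation): ForwardDiff for the derivative of the network with respect to its input, Zygote reverse mode over the parameters. The network size, grid, and cos(x) target are placeholders, and the outer Zygote pass over the ForwardDiff call is exactly where the workarounds mentioned above tend to be needed.

```julia
# Toy sketch of the forward-over-reverse pattern: forward mode for du/dx,
# reverse mode for the parameter gradient. This is the ForwardDiff
# workaround mentioned above, so expect rough edges with fancier layers.
using Flux, ForwardDiff, Zygote

NN = Chain(Dense(1, 16, tanh), Dense(16, 1))   # u(x) ≈ NN([x])[1]
u(x) = NN([x])[1]

# Inner derivative du/dx via forward-mode dual numbers.
dudx(x) = ForwardDiff.derivative(u, x)

# Toy PDE-style residual loss: enforce u'(x) = cos(x) on a grid of points.
xs = range(0, 2pi, length = 32)
pinn_loss() = sum(abs2, dudx(x) - cos(x) for x in xs)

# Outer reverse-mode pass over the network parameters.
grads = Zygote.gradient(pinn_loss, Flux.params(NN))
```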

That said, the reason why it's not a huge issue is that the computational complexity of numerical and forward-mode differentiation is the same; forward mode just decreases the number of primal calculations and allows a bit more SIMD/CSE in some cases. You never see forward mode more than 4x faster than the fastest numerical differentiation schemes, and usually it's closer to 2x. What's essential for performance is the reverse mode over the loss function, which is already there, since that is where the massive complexity change happens. This is why I haven't made a big push to get "something for now and better for later", and am instead just waiting for the big nested AD changes coming later this year. The actual difference to a user will be rather slim (much slimmer than you might suspect), so it's not worth making a fuss about until special compiler tools allow for faster higher-order derivatives (which something like PyTorch doesn't do, so that would be something to write home about).
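To make the complexity claim concrete, here is a tiny standalone illustration on a throwaway scalar function (nothing NN-specific): one forward-mode pass carries the derivative along with the primal value, while a central difference costs two primal evaluations per input direction, so the asymptotic scaling is identical and only a constant factor and the finite-difference error separate them.

```julia
# Both approaches do O(1) extra primal-sized work per input direction;
# forward mode just fuses it into one dual-number pass and avoids the
# finite-difference truncation/roundoff error.
using ForwardDiff

f(x) = sin(3x) + x^2

# Forward mode: a single pass with a dual number.
fwd = ForwardDiff.derivative(f, 1.0)

# Central difference: two primal evaluations.
h = cbrt(eps(Float64))
fd = (f(1.0 + h) - f(1.0 - h)) / (2h)

fwd, fd   # agree up to the finite-difference error
```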

Pinging @KirillZubov since I know he was curious about this detail as well.

But indeed, this is an issue, so thanks for opening it so that we can formally track it. We'll start updating the public on this more often since there's a lot going on here.
