Add Curvlinops backend & add default functorch implementations of many curvature quantities #146

Merged
merged 19 commits into from
Mar 9, 2024


wiseodd
Collaborator

@wiseodd wiseodd commented Mar 2, 2024

Changelog:

  1. Add Curvlinops backend.
  2. Add default functorch implementations of many quantities (in CurvatureInterface itself):
    1. Jacobians
    2. Batch gradients
    3. Diag GGN (both stochastic and exact) and diag EF
    4. Full GGN (both stochastic and exact) and full EF
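
As an illustration of items 2.1 and 2.2, here is a minimal sketch of how per-sample Jacobians and batch gradients can be obtained with `torch.func` (the successor of functorch). It is not the code from this PR: the toy model, the helper names `f` and `nll`, and the tensor shapes are assumptions made for the example.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, jacrev, vmap

torch.manual_seed(0)

# Toy classifier (an assumption for this sketch): 3 inputs, 2 classes.
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 2))
params = {k: v.detach() for k, v in model.named_parameters()}
X = torch.randn(5, 3)
y = torch.randint(0, 2, (5,))
N, C = X.shape[0], 2


def f(p, x):
    # Forward pass for a single input x (no batch dimension).
    return functional_call(model, p, (x.unsqueeze(0),)).squeeze(0)


def nll(p, x, t):
    # Cross-entropy loss of a single (input, target) pair.
    return nn.functional.cross_entropy(f(p, x).unsqueeze(0), t.unsqueeze(0))


# Per-sample Jacobians of the outputs w.r.t. the parameters.
# vmap maps over the batch; jacrev differentiates w.r.t. the first argument.
Js_dict = vmap(jacrev(f), in_dims=(None, 0))(params, X)
Js = torch.cat([Js_dict[k].reshape(N, C, -1) for k in params], dim=-1)  # (N, C, P)

# Per-sample ("batch") gradients of the loss w.r.t. the parameters.
Gs_dict = vmap(grad(nll), in_dims=(None, 0, 0))(params, X, y)
Gs = torch.cat([Gs_dict[k].reshape(N, -1) for k in params], dim=-1)  # (N, P)
```

Both results are flattened so that the last dimension runs over all parameters in `named_parameters()` order, which is the layout the diagonal and full curvature quantities in items 2.3 and 2.4 are typically built from.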

Rationale for (1): BackPACK and ASDL are no longer actively maintained and are hard to extend. Note that we use Curvlinops mainly for KFAC.

Rationale for (2): functorch is a first-class PyTorch module, so it has the highest compatibility with arbitrary PyTorch models. This future-proofs laplace-torch and reduces its dependence on third-party backends.

Together, they make laplace-torch very flexible and compatible with most current (and future) architectures.

With this PR, Laplace(backend=CurvlinopsBackend) supports:

  1. Diag-GGN-exact, diag-GGN-MC, diag-EF for both classification and regression
  2. Kron-GGN-exact, kron-GGN-MC, kron-EF for both classification and regression
  3. Full-GGN-exact, full-GGN-MC, full-EF, full-Hessian-exact for both classification and regression

Note that the previous backends support only subsets of these.
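
To make the naming above concrete, the following sketch computes the exact GGN, a Monte Carlo (MC) GGN, and the empirical Fisher (EF) for a toy softmax classifier, in both full and diagonal form. It uses plain `torch.func` rather than the Curvlinops backend or the laplace-torch API, and the model, helper names, and shapes are assumptions made for the example. The Kronecker-factored (kron) variants are not covered here.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, jacrev, vmap

torch.manual_seed(0)

# Toy softmax classifier (an assumption for this sketch): 3 inputs, 2 classes.
model = nn.Linear(3, 2)
params = {k: v.detach() for k, v in model.named_parameters()}
X = torch.randn(6, 3)
y = torch.randint(0, 2, (6,))
N, C = X.shape[0], 2


def f(p, x):
    # Forward pass for a single input x (no batch dimension).
    return functional_call(model, p, (x.unsqueeze(0),)).squeeze(0)


def nll(p, x, t):
    # Cross-entropy loss of a single (input, target) pair.
    return nn.functional.cross_entropy(f(p, x).unsqueeze(0), t.unsqueeze(0))


def flatten(tree, lead):
    # Concatenate all parameter blocks along a single trailing dimension.
    return torch.cat([tree[k].reshape(*lead, -1) for k in params], dim=-1)


def batch_grads(targets):
    # Per-sample loss gradients, shape (N, P).
    return flatten(vmap(grad(nll), in_dims=(None, 0, 0))(params, X, targets), (N,))


Js = flatten(vmap(jacrev(f), in_dims=(None, 0))(params, X), (N, C))  # (N, C, P)
probs = torch.softmax(functional_call(model, params, (X,)), dim=-1)  # (N, C)

# Hessian of the cross-entropy w.r.t. the logits: diag(p) - p p^T.
H_out = torch.diag_embed(probs) - torch.einsum("nc,nd->ncd", probs, probs)

# Exact GGN: sum_n J_n^T H_n J_n, full matrix and diagonal only.
ggn_full = torch.einsum("ncp,ncd,ndq->pq", Js, H_out, Js)
ggn_diag = torch.einsum("ncp,ncd,ndp->p", Js, H_out, Js)

# Empirical Fisher: outer products of gradients at the observed labels.
Gs = batch_grads(y)
ef_full = torch.einsum("np,nq->pq", Gs, Gs)
ef_diag = (Gs**2).sum(0)

# MC GGN: same as the EF, but with labels sampled from the model's prediction.
y_mc = torch.multinomial(probs, num_samples=1).squeeze(-1)
ggn_mc_diag = (batch_grads(y_mc) ** 2).sum(0)
```

The MC estimate matches the exact GGN only in expectation over the sampled labels, whereas the EF uses the observed labels and is in general a different quantity.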

I'd like to ask @runame to review this, specifically the K-FAC backend proposed here, since he's involved in Curvlinops.

@wiseodd wiseodd added the enhancement New feature or request label Mar 2, 2024
@wiseodd wiseodd requested review from aleximmer and runame March 2, 2024 18:07
@wiseodd wiseodd self-assigned this Mar 2, 2024
@wiseodd wiseodd marked this pull request as ready for review March 3, 2024 19:22
Collaborator

@runame runame left a comment

I reviewed the curvlinops part for now and pushed fixes directly.

I noticed that our tests are not systematic enough and that we are currently missing a few important cases (e.g., KFAC(-expand) does not match between BackPACK and Curvlinops for Conv2d modules). Since this branch still uses the old ASDL version, I think we should fix the tests of the curvature backends after this PR and #144 are merged.

tests/test_curv_backends_backpack.py Outdated
@wiseodd
Collaborator Author

wiseodd commented Mar 8, 2024

@aleximmer please give a final check and feel free to merge!

Owner

@aleximmer aleximmer left a comment

Overall, this looks good to me. Please have a look at the two comments and let me know what you think. After that, it's good to go from my side.

Thanks for taking on the effort and including curvlinops!

laplace/baselaplace.py Outdated
laplace/curvature/curvature.py Outdated
Owner

@aleximmer aleximmer left a comment

Sorry, forgot to add this comment in the previous review.

laplace/curvature/curvature.py Outdated
@wiseodd wiseodd merged commit fd4de8a into main Mar 9, 2024
@wiseodd wiseodd deleted the backend-curvlinops branch March 9, 2024 23:37
@runame runame linked an issue Mar 12, 2024 that may be closed by this pull request
Labels
enhancement New feature or request
Development

Successfully merging this pull request may close these issues.

Add default Jacobians and individual gradients to curvature
3 participants