Add Curvlinops backend & add default functorch implementations of many curvature quantities #146
I reviewed the `curvlinops` part for now and pushed fixes directly.
I noticed that our tests are not systematic enough and that we are currently missing a few important cases (e.g., KFAC(-expand) differs between the backpack and curvlinops backends for `Conv2d` modules). Since this branch still uses the old ASDL version, I think we should fix the tests of the curvature backends after this PR and #144 are merged.
@aleximmer please give a final check and feel free to merge!
Looks overall good to me. Please have a look at the two comments and let me know what you think. After that it's good to go from my side.
Thanks for taking on the effort and including curvlinops!
Sorry, forgot to add this comment in the previous review.
Changelog:
1. Add a Curvlinops backend.
2. Add a `functorch` implementation for many quantities (in `CurvatureInterface` itself).

Rationale for (1): BackPACK & ASDL are not actively supported anymore and are hard to extend. Note that we use Curvlinops mainly for KFACs.

Rationale for (2): `functorch` is a first-class PyTorch module, so it has the highest compatibility with any PyTorch model. This future-proofs `laplace-torch` and reduces its dependence on third-party backends.

Together: They make `laplace-torch` very flexible and compatible with most architectures out there (and in the future).

With this PR, `Laplace(backend=CurvlinopsBackend)` supports:

Note that previous backends only support some subsets of them.
I'd like to ask @runame to review this since he's involved in Curvlinops, specifically the K-FAC backend proposed here.
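
To illustrate the rationale for (2), here is a minimal sketch of how a curvature quantity can be computed with PyTorch's built-in function transforms alone. It is not the code added in this PR: the toy model, the data, and the names `f_single`, `Js`, and `ggn` are hypothetical, and it uses `torch.func` (the in-tree successor of `functorch` in PyTorch 2.x). It assumes a loss whose Hessian w.r.t. the model output is the identity (e.g. 0.5 times the squared error), so the GGN reduces to a sum of `J^T J` terms.

```python
import torch
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)

# Hypothetical toy model and inputs, chosen only for illustration.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 4), torch.nn.Tanh(), torch.nn.Linear(4, 2)
)
X = torch.randn(5, 3)

# Detached parameter tensors, keyed by name.
params = {name: p.detach() for name, p in model.named_parameters()}


def f_single(params, x):
    """Model output for one input x, viewed as a function of the parameters."""
    return functional_call(model, params, (x.unsqueeze(0),)).squeeze(0)


# Per-sample Jacobians w.r.t. each parameter tensor.
# Each entry has shape (N, C, *param_shape).
jac_dict = vmap(jacrev(f_single), in_dims=(None, 0))(params, X)

# Flatten the parameter dimensions and concatenate into one (N, C, P) tensor.
Js = torch.cat([jac_dict[name].flatten(start_dim=2) for name in params], dim=2)

# GGN under the identity-output-Hessian assumption: sum_n J_n^T J_n.
ggn = torch.einsum("ncp,ncq->pq", Js, Js)
```

Because this relies only on the model's forward pass, it needs no module-specific hooks, which is the compatibility argument made above. The cost is that full Jacobians scale with the number of parameters, so structured approximations such as KFAC still need a dedicated backend.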