Adds support for energy-based learning with NLL loss (LEO) #30
Conversation
Looks great! The API looks good -- just a sampling function in the optimizer class. I added some comments to address. The GPU test is failing because the `upper` kwarg for cholesky is not available. We can add support for sparse solvers and optimize performance in later PRs.
Great job on this PR! I see all the functionality is there, but I propose we shuffle some things around a bit. I'm OK leaving that as a separate PR, but we should definitely make issues for the comments so we don't forget.
Regarding the broken test, the `upper=` kwarg for cholesky was only introduced in torch 1.10.0. It used to be in `torch.cholesky`, but it was removed in `torch.linalg.cholesky` in v1.9.0 (see screenshot).
For now, can you update the call to match the one used in v1.9.0 so that the test passes? A sketch of the compatible call is below.
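For reference, a minimal sketch of the v1.9.0-compatible workaround; the variable names are illustrative, not the ones in this PR:

```python
import torch

# Build a symmetric positive-definite matrix for the demo.
A = torch.randn(3, 3)
A = A @ A.T + 3.0 * torch.eye(3)

# torch >= 1.10.0 only (fails on v1.9.0):
# U = torch.linalg.cholesky(A, upper=True)

# v1.9.0-compatible equivalent: take the lower factor and transpose it.
L = torch.linalg.cholesky(A)    # lower-triangular L with A = L @ L.T
U = L.transpose(-2, -1).conj()  # same result as upper=True
```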
…esearch#30)
* add tests for leo with GN/LM optimizers
* add sampler to GN/LM optimizers
* run leo on 2d state estimation, add viz, learning_method options
This PR introduces LEO, learning energy-based models in optimization (https://arxiv.org/abs/2108.02274). LEO is a method to learn models end-to-end within second-order optimizers like Gauss-Newton. The main difference is that instead of unrolling the optimizer and minimizing an MSE tracking loss, LEO minimizes an energy-based NLL loss that does not backpropagate through the optimizer. It instead requires low-energy samples from the optimizer, pushing up the energy of optimizer samples and pushing down the energy of ground-truth samples.
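As a rough illustration of that loss, here is a minimal sketch assuming a hypothetical `energy(x)` callable for the learned cost; the names and the sample-based log-partition approximation are illustrative, not this PR's actual API:

```python
import torch

def leo_nll_loss(energy, x_gt, x_samples):
    """Energy-based NLL: push ground-truth energy down and
    optimizer-sample energy up (samples stand in for the partition term)."""
    e_gt = energy(x_gt)                                      # energy of ground truth
    e_samples = torch.stack([energy(x) for x in x_samples])  # energies of optimizer samples
    # logsumexp over negated sample energies approximates log Z (up to a constant),
    # so minimizing this loss lowers the ground-truth energy and raises sample energies.
    return e_gt + torch.logsumexp(-e_samples, dim=0)
```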
To execute it, run `python examples/state_estimation_2d.py` with `learning_method="leo"`.
This should update the learnable cost weights so that the optimizer trajectory (orange) matches the ground truth trajectory (green).