Implementation of MoMo algorithm #721
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up-to-date status, view the checks section at the bottom of the pull request.
Regarding the second bullet in the original post, I think I have now solved this by using …
Thanks a lot @fabian-sp! This looks great. I'll make a more thorough review once I get the vanilla Polyak SGD working (#718) :-)
Co-authored-by: Fabian Pedregosa <pedregosa@google.com>
hey @fabian-sp, so sorry for the huge delay on this ... Merging the Polyak step-size highlighted some subtle issues that ended up requiring fixes all the way down in pytype ... anyway, we finally got it merged, and now we can focus on MoMo! A few high-level comments before we do a detailed review:
Thanks for all the work! 🙏🏼
Thanks @fabianp, I merged the files. Somehow, after updating to the latest main, if I run the … locally, it fails.
Do you know how to fix this? Maybe some old installation of the optax package, but I thought the test script would install into a new environment anyhow?
you might need to uninstall optax
but yeah, it's strange ...
The GitHub Actions tests are failing now; not sure why, but probably because of changes made in the rest of the package since then?
Okay, so the issue was that Momo needs the loss function value in the update (like Polyak SGD). This seems to be incompatible with the common tests, so I removed it from there for now, and the tests pass locally. Depending on how you solved this for Polyak SGD, we can do the same for Momo. @fabianp @vroulet
so we basically wrote an if/else for polyak_sgd: https://github.com/google-deepmind/optax/blob/main/optax/_src/alias_test.py
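For readers following along, a minimal hypothetical sketch of that kind of special-casing (names and the dummy loss value are illustrative, not the actual alias_test.py code):

```python
import jax.numpy as jnp

def run_one_update(opt_name, opt, grads, state, params):
  # Optimizers in the polyak_sgd family need the current loss value
  # passed to update(); everything else takes the standard arguments.
  if opt_name in ('polyak_sgd', 'momo'):  # hypothetical name list
    extra = {'value': jnp.array(1.0)}  # nonzero dummy loss value
  else:
    extra = {}
  return opt.update(grads, state, params, **extra)
```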
Okay, from my side the changes you requested should be implemented. A minor question: for Polyak-SGD you call the loss value argument `value`; should I use the same name here?
yes, please use `value` so it's consistent with the rest of the optimizers in optax
Thanks again for the contribution! Some minor comments here
optax/contrib/momo.py
    learning_rate: base.ScalarOrSchedule = 1e-2,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    lb: float = 0.0,
how about calling this `f_min`, as with `polyak_sgd`? https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.polyak_sgd
I would prefer the term lower bound, as this is closer to how we describe this quantity in the paper (strictly speaking, you don't need the optimal value for deriving MoMo, only a lower bound). But I see your point about consistent naming...
ok, but let's at least name it `lower_bound`? `lb` is not very descriptive
sure, can do that. Just FYI that I have now also added an option where the lower bound is estimated on the fly, which will make some variable names a bit lengthy :) but at least the function argument should have a descriptive name
yes exactly, having descriptive names in the function signature is the most important. It's fine if private variables have more cryptic names.
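To make the naming discussion concrete, a hypothetical sketch of how the public signature could end up looking; apart from `lower_bound`, the parameter names and defaults here are assumptions, not the merged API:

```python
from optax._src import base

def momo(
    learning_rate: base.ScalarOrSchedule = 1e-2,
    beta: float = 0.9,                # momentum parameter (assumed name)
    lower_bound: float = 0.0,         # descriptive public name replacing `lb`
    adapt_lower_bound: bool = False,  # estimate the lower bound on the fly
) -> base.GradientTransformationExtraArgs:
  ...
```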
  count = jnp.zeros([], jnp.int32)
  return MomoAdamState(exp_avg, exp_avg_sq, barf, gamma, count)


def update_fn(
have you considered writing this optimizer as a chain using the existing `scale_by_adam`? If possible, it would probably result in much shorter (and more reusable) code.
For computing the adaptive learning rate, we need to compute the Adam EMAs and some other quantities based on them. So I thought it would be best to have everything in one function to avoid duplicate computations. But I might be wrong here...
No, I think what you say makes sense
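For context, the chain-based route under discussion would look roughly like the sketch below (the function name is hypothetical). As noted, it falls short for MoMo-Adam: the adaptive step size is computed from the Adam EMAs themselves, which `optax.scale_by_adam` keeps in its internal state rather than exposing between chain stages.

```python
import optax

def momo_adam_via_chain(learning_rate: float, b1: float = 0.9,
                        b2: float = 0.999, eps: float = 1e-8):
  # A plain Adam-style chain; MoMo additionally needs the EMAs to set the
  # step size, so a single fused update_fn was used in the PR instead.
  return optax.chain(
      optax.scale_by_adam(b1=b1, b2=b2, eps=eps),
      optax.scale_by_learning_rate(learning_rate),
  )
```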
Co-authored-by: Fabian Pedregosa <pedregosa@google.com>
Somehow, after applying the formatting suggestions of @fabianp, one test now fails (the injection test for …). I will also now implement the adaptive lower-bound estimation that we proposed in the paper.
Thanks! For the injection test, you might want to try converting the elements of the state into arrays. I vaguely recall having similar problems and that this solved the issue.
Thanks, I did this. But the problem was solved by using a loss value other than zero here: optax/contrib/_common_test.py, line 117 (commit 97128c1).
With zero, it takes no step (as the Polyak step size is zero), and so it compares numerically zero values against each other, which seems to fail.
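For reference, the plain Polyak step size (the MoMo variants add momentum averaging on top; see the paper) makes the failure mode clear:

$$\gamma_k = \frac{f(x_k) - f^\star}{\lVert \nabla f(x_k) \rVert^2}$$

With the test's loss value $f(x_k)$ set to zero and $f^\star = 0$, the numerator vanishes, $\gamma_k = 0$, and the test compares two identical iterates.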
FYI, there are some test failures triggered by the last jax release that are unrelated to this PR (so don't worry about that for now; we're working on fixing them in parallel)
Anything left for me to do, or are you gonna merge after fixing the upstream bug?
Can you update your branch from master? Hopefully the tests will then run again
hey Fabian, sorry for the delay on this one. I was about to merge yesterday, but I noticed what seemed to me like duplicated code between test_momo.py and _common_test.py. In particular, it seems that test_momo.py tests on the same parabola and Rosenbrock function as _common_test.py. Am I missing something?
It's now merged without test_momo.py (it was throwing some errors on our internal tests). Let me know how important it was and we can add it back in a follow-up PR.
Hi Fabian, thanks for the final checks. test_momo.py indeed had the initial tests, but later I switched to the common test structure of _common_test.py, so deleting it should be fine. Thanks again for the interest in MoMo and the final push! 📦
At the suggestion of @fabianp, I implemented the MoMo algorithm. MoMo is essentially a Polyak step size for SGD with momentum and for Adam (see https://arxiv.org/abs/2305.07583).
The Rosenbrock and least squares tests are passing locally.
I still have a few questions, as this is the first time I am implementing something in Optax:
The loss value needs to be passed as an additional argument to `update_fn`. I named this argument `loss` and adapted the tests. But maybe you have a convention for how something like this should be handled (a usage sketch follows below).
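To illustrate the question, a hedged usage sketch, assuming the keyword ends up being `value` (as for `optax.polyak_sgd`) and the entry point `optax.contrib.momo`; both are assumptions at this point in the discussion:

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x, y):
  return jnp.mean((x @ params - y) ** 2)

x, y = jnp.ones((8, 3)), jnp.ones(8)
params = jnp.zeros(3)

opt = optax.contrib.momo(learning_rate=1e-2)  # assumed entry point
state = opt.init(params)

# The loss value is passed alongside the gradients to the update.
value, grads = jax.value_and_grad(loss_fn)(params, x, y)
updates, state = opt.update(grads, state, params, value=value)
params = optax.apply_updates(params, updates)
```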