Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of MoMo algorithm #721

Merged
merged 27 commits into from
Apr 23, 2024
Merged

Conversation

fabian-sp
Copy link
Contributor

@fabian-sp fabian-sp commented Jan 19, 2024

Upon suggestion by @fabianp I implemented the MoMo algorithm. MoMo is esentially a Polyak step size for SGD with momentum and for Adam (see https://arxiv.org/abs/2305.07583).

The Rosenbrock and least squares tests are passing locally.

I have still a few questions as this is the first time I am implementing in Optax:

  • MoMo needs in each iteration the latest batch loss passed into update_fn. I named this argument loss, and adpated the tests. But maybe you have a convention how sth like this would be handled.

Copy link

google-cla bot commented Jan 19, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@fabian-sp
Copy link
Contributor Author

Regarding the second bullet in the original post, I think I have now solved this by using jax.lax.cond.

optax/contrib/momo.py Outdated Show resolved Hide resolved
optax/contrib/momo.py Outdated Show resolved Hide resolved
@fabianp
Copy link
Member

fabianp commented Jan 20, 2024

Thanks a lot @fabian-sp ! This looks great. I'll make a more through review once I get the vanilla Polyak SGD working (#718) :-)

fabian-sp and others added 4 commits January 20, 2024 21:08
Co-authored-by: Fabian Pedregosa <pedregosa@google.com>
Co-authored-by: Fabian Pedregosa <pedregosa@google.com>
@fabianp
Copy link
Member

fabianp commented Mar 26, 2024

hey @fabian-sp , so sorry for the huge delay on this .... Merging the Polyak step-size highlighted some subtle issues that ended all the way in fixing stuff in pytype ... anyway, we got it merged finally, and now we can focus on MoMo!

A few of high level before we do a detailed review:

  1. Since you first coded this, @vroulet added some common tests to the contrib directory (https://github.com/google-deepmind/optax/blob/main/optax/contrib/_common_test.py). Please add your solver to this file (under _OPTIMIZERS_UNDER_TEST)

  2. Please merge the files momo.py and momo_adam.py into a single file, and the same for the *_test.py files.

  3. Please take a look at the implementation of polyak_sgd in https://github.com/google-deepmind/optax/blob/main/optax/_src/alias.py . I suspect (but could be wrong) that a similar structure that splits the computation of the update from that of the step-size would make sense here too.

Thanks for all the work! 🙏🏼

@fabian-sp
Copy link
Contributor Author

Thanks @fabianp, I merged the files. Somehow, after updating to the latest main, if I run locally the test.sh I get the following error

ERROR: Cannot install optax 0.1.9.dev0 (from /home/schaipp/uni/other/optax/dist/optax-0.1.9.dev0.tar.gz) and optax 0.2.2.dev0 (from /home/schaipp/uni/other/optax/dist/optax-0.2.2.dev0.tar.gz) because these package versions have conflicting dependencies.

Do you know how to fix this? Maybe some old installation of the optax package, but I thought the tester script would install in a new environment anyhow?

@fabianp
Copy link
Member

fabianp commented Mar 26, 2024

you might need to uninstall optax pip uninstall optax before running the tests

@fabianp
Copy link
Member

fabianp commented Mar 26, 2024

but yeah, its strange ...

@fabian-sp
Copy link
Contributor Author

fabian-sp commented Mar 26, 2024

pip uninstall didnt solve it. Deleting all files in dist/ did the trick though.

The Github action tests are failing now, not sure why but probably because of changes made in the rest of the package since then?

@fabian-sp
Copy link
Contributor Author

Okay, so the issue was that Momo needs the loss function value in the update (like Polyak SGD). This seems to be incompatible with the common tests. So I removed it for now from there, and tests are passing locally.

Depending on how you solved this for Polyak SGD, we can do it the same for Momo. @fabianp @vroulet

@fabianp
Copy link
Member

fabianp commented Mar 26, 2024

so we basically wrote an if/else for polyak_sgd: https://github.com/google-deepmind/optax/blob/main/optax/_src/alias_test.py

@fabian-sp
Copy link
Contributor Author

Okay, from my side the changes you requested should be implemented. A minor question: for Polyak-SGD you call the loss value argument value, while for MoMo I called it loss. This is mostly because I am used to Pytorch - feel free to change this if you prefer value.

@fabianp
Copy link
Member

fabianp commented Mar 27, 2024

yes, please use value so it's consistent with the rest of optimizers in optax

Copy link
Member

@fabianp fabianp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks again for the contribution! Some minor comments here

optax/contrib/momo.py Outdated Show resolved Hide resolved
optax/contrib/momo.py Outdated Show resolved Hide resolved
optax/contrib/momo.py Outdated Show resolved Hide resolved
optax/contrib/momo.py Show resolved Hide resolved
optax/contrib/momo.py Outdated Show resolved Hide resolved
optax/contrib/momo.py Outdated Show resolved Hide resolved
optax/contrib/momo.py Outdated Show resolved Hide resolved
optax/contrib/momo.py Outdated Show resolved Hide resolved
learning_rate: base.ScalarOrSchedule = 1e-2,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
lb: float = 0.0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about calling this f_min as with polyak_sgd ? https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.polyak_sgd

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer the term lower bound, as this is closer to how we describe this quantity in the paper (strictly speaking, you don't need the optimal value for deriving MoMo, but only a lower bound). But I see your point of consistent naming...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, but let's at least name it lower_bound ? lb is not very descriptive

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, can do that. Just fyi that I now added also an option where the lower bounded is estimated on the fly, which then will make some variable names a bit lengthy :) but at least the function argument should have a descriptive name

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes exactly, having descriptive names in the function signature is the most important. Its fine if private variables have more cyptic names.

count = jnp.zeros([], jnp.int32)
return MomoAdamState(exp_avg, exp_avg_sq, barf, gamma, count)

def update_fn(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have you considered writing this optimizer as a chain using the existing scale_by_adam ? If possible it would probably result in a much shorter (and reusable) code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For computing the adaptive learning rate, we need to compute the Adam EMAs, and some other quantities based on them. So I thought, it would be best to avoid double computations to have all in one function. But I might be wrong here...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I think what you say makes sense

Co-authored-by: Fabian Pedregosa <pedregosa@google.com>
fabian-sp and others added 5 commits April 2, 2024 11:34
Co-authored-by: Fabian Pedregosa <pedregosa@google.com>
Co-authored-by: Fabian Pedregosa <pedregosa@google.com>
@fabian-sp
Copy link
Contributor Author

Somehow after changing the formatting suggestions of @fabianp , now one test fails (the injection test for momo_adam). It's a bit mysterious for me, maybe you can have a look.

I will also now implement the adaptive lower bound estimation, that we proposed in the paper.

@fabianp
Copy link
Member

fabianp commented Apr 2, 2024

Thanks! For the injection test, you might want to try converting the elements of the state into arrays. I vaguely recall having similar problems and that solving the issue

@fabian-sp
Copy link
Contributor Author

Thanks, I did this. But the problem was solved by using a loss value other than zero here:

update_kwargs = {'value': jnp.array(1.)}

With zero, it takes no step (as the Polyak step size is zero), and so it compares numerically zero values against each other, which seems to fail.

@fabianp
Copy link
Member

fabianp commented Apr 4, 2024

FYI there are some test failures triggered by the last jax release that are unrelated to this PR (so don't worry about that for now, we're working on fixing them in parallel)

@fabian-sp
Copy link
Contributor Author

Anything left for me to do, or are you gonna merge after fixing the upstream bug?

@fabianp
Copy link
Member

fabianp commented Apr 10, 2024

Can you update your branch from master? hopefully the tests will then run again

@fabianp
Copy link
Member

fabianp commented Apr 23, 2024

hey Fabian, sorry for the delay on this one. I was about to merge yesterday, but I realized what seemed to me like duplicated code between test_momo.py and _test_common.py . In particular, it seems that test_momo is testing on the same parabola and rosenbrock function that _test_commo.py . Am I missing something?

@copybara-service copybara-service bot merged commit 748ce7f into google-deepmind:main Apr 23, 2024
6 checks passed
@fabianp
Copy link
Member

fabianp commented Apr 23, 2024

It's now merged without test_momo.py (it was throwing some errors on our internal tests). Let me know how important it was and we can add it back in a follow up PR

@fabian-sp
Copy link
Contributor Author

Hi Fabian, thanks for the final checks. The test_momo.py indeed had the initial tests, but later I used the common test structure of _test_common.py. So deleting it should be fine.

Thanks again for the interest in MoMo and the final push! 📦

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants