Implementation of MoMo algorithm #721

Merged (27 commits, Apr 23, 2024)
Changes from 1 commit
added to common tests
fabian-sp committed Mar 27, 2024
commit 008785f71daf9661336993c2f4ec5e36175fd514
26 changes: 20 additions & 6 deletions optax/contrib/_common_test.py
@@ -34,6 +34,8 @@
dict(opt_name='cocob', opt_kwargs=dict(alpha=100.0, eps=1e-8)),
dict(opt_name='dadapt_adamw', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='prodigy', opt_kwargs=dict(learning_rate=1e-1)),
+ dict(opt_name='momo', opt_kwargs=dict(learning_rate=1e-1)),
+ dict(opt_name='momo_adam', opt_kwargs=dict(learning_rate=1e-1)),
)


@@ -42,7 +44,7 @@ def _setup_parabola(dtype):
initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)

- @jax.grad
+ @jax.value_and_grad
def get_updates(params):
return jnp.sum(numerics.abs_sq(params - final_params))

@@ -57,7 +59,7 @@ def _setup_rosenbrock(dtype):
initial_params = jnp.array([0.0, 0.0], dtype=dtype)
final_params = jnp.array([a, a**2], dtype=dtype)

- @jax.grad
+ @jax.value_and_grad
def get_updates(params):
return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq(
params[1] - params[0] ** 2
@@ -79,8 +81,12 @@ def test_optimizers(self, opt_name, opt_kwargs, target, dtype):

@jax.jit
def step(params, state):
- updates = get_updates(params)
- updates, state = opt.update(updates, state, params)
+ loss, updates = get_updates(params)
+ if opt_name in ['momo', 'momo_adam']:
+   update_kwargs = {'loss': loss}
+ else:
+   update_kwargs = {}
+ updates, state = opt.update(updates, state, params, **update_kwargs)
params = update.apply_updates(params, updates)
return params, state
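A minimal sketch of the training-step pattern this test exercises, with illustrative values (it assumes optax.contrib.momo as added in this PR, which reads the current loss via an extra keyword argument to update):

import jax
import jax.numpy as jnp
import optax
from optax import contrib

def loss_fn(params):
  # Illustrative quadratic; the tests use a parabola and Rosenbrock instead.
  return jnp.sum((params - 1.0) ** 2)

opt = contrib.momo(learning_rate=1e-1)
params = jnp.array([-1.0, 10.0, 1.0])
state = opt.init(params)

for _ in range(10):
  loss, grads = jax.value_and_grad(loss_fn)(params)
  # momo consumes the current loss value as an extra keyword to update().
  updates, state = opt.update(grads, state, params, loss=loss)
  params = optax.apply_updates(params, updates)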

@@ -107,12 +113,20 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))]
grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))]

+ if opt_name in ['momo', 'momo_adam']:
+   update_kwargs = {'loss': jnp.array(0.)}
+ else:
+   update_kwargs = {}

state = self.variant(opt.init)(params)
- updates, new_state = self.variant(opt.update)(grads, state, params)
+ updates, new_state = self.variant(opt.update)(
+     grads, state, params, **update_kwargs
+ )

state_inject = self.variant(opt_inject.init)(params)
updates_inject, new_state_inject = self.variant(opt_inject.update)(
- grads, state_inject, params)
+     grads, state_inject, params, **update_kwargs
+ )

with self.subTest('Equality of updates.'):
chex.assert_trees_all_close(updates_inject, updates, rtol=1e-4)
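The same extra argument flows through inject_hyperparams; a short sketch of what the wrapped call above looks like outside the test harness (illustrative shapes and loss value):

import jax.numpy as jnp
import optax
from optax import contrib

opt = optax.inject_hyperparams(contrib.momo)(learning_rate=1e-1)
params = jnp.ones((2, 3))
grads = jnp.ones((2, 3))
state = opt.init(params)
# The extra 'loss' keyword is forwarded to the wrapped momo update.
updates, state = opt.update(grads, state, params, loss=jnp.array(0.))
# The injected hyperparameter stays inspectable (and schedulable) in the state:
print(state.hyperparams['learning_rate'])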
8 changes: 4 additions & 4 deletions optax/contrib/momo.py
@@ -39,7 +39,7 @@ def momo(
learning_rate: base.ScalarOrSchedule = 1.0,
beta: float = 0.9,
lb: float = 0.0,
- weight_decay: float = 0.
+ weight_decay: float = 0.0
) -> base.GradientTransformationExtraArgs:
"""Adaptive Learning Rates for SGD with momentum.

@@ -134,11 +134,11 @@ class MomoAdamState(NamedTuple):


def momo_adam(
- learning_rate: base.ScalarOrSchedule = 1.0,
+ learning_rate: base.ScalarOrSchedule = 1e-2,
betas: tuple[float, float] = (0.9, 0.999),
fabian-sp marked this conversation as resolved.
eps: float = 1e-8,
lb: float = 0.0,
Member: how about calling this f_min as with polyak_sgd? https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.polyak_sgd

Contributor Author: I would prefer the term lower bound, as this is closer to how we describe this quantity in the paper (strictly speaking, you don't need the optimal value to derive MoMo, only a lower bound). But I see your point about consistent naming...

Member: OK, but let's at least name it lower_bound? lb is not very descriptive.

Contributor Author: Sure, can do that. Just FYI, I have now also added an option where the lower bound is estimated on the fly, which will make some variable names a bit lengthy :) but at least the function argument should have a descriptive name.

Member: Yes, exactly. Having descriptive names in the function signature is the most important thing. It's fine if private variables have more cryptic names.
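For context on the naming suggestion, a side-by-side sketch of the two constructors (polyak_sgd's keyword is taken from the docs linked above; the momo keyword is the lb argument as it stands at this commit, before the rename discussed here):

import optax
from optax import contrib

# polyak_sgd calls the known lower bound on the loss 'f_min'.
opt_polyak = optax.polyak_sgd(f_min=0.0)
# At this commit, momo exposes the same quantity as 'lb'.
opt_momo = contrib.momo(learning_rate=1e-1, lb=0.0)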

- weight_decay: float = 0.
+ weight_decay: float = 0.0
) -> base.GradientTransformationExtraArgs:
"""Adaptive Learning Rates for Adam(W).

@@ -175,7 +175,7 @@ def init_fn(params: base.Params) -> MomoAdamState:
exp_avg_sq = tu.tree_map(lambda p: jnp.zeros(p.shape, jnp.float32), params)
barf = 0
gamma = 0
- count = 0
+ count = jnp.zeros([], jnp.int32)
return MomoAdamState(exp_avg, exp_avg_sq, barf, gamma, count)
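A small sketch of the counter convention used here (the rationale is inferred, not stated in the diff): initializing count as a 0-d int32 array keeps the optimizer state a pytree of arrays, and optax ships a helper for overflow-safe increments; whether momo_adam uses that helper is not visible in this hunk.

import jax.numpy as jnp
import optax

count = jnp.zeros([], jnp.int32)           # 0-d int32 array rather than a Python int
count = optax.safe_int32_increment(count)  # increments without int32 overflow wraparound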

def update_fn(
Member: have you considered writing this optimizer as a chain using the existing scale_by_adam? If possible, it would probably result in much shorter (and reusable) code.

Contributor Author: For computing the adaptive learning rate, we need the Adam EMAs and some other quantities based on them. So I thought it would be best to keep everything in one function to avoid double computation. But I might be wrong here...

Member: No, I think what you say makes sense.
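For reference, the chain-based composition the reviewer has in mind looks roughly like the sketch below (standard optax building blocks assembling an AdamW-style update; not the approach taken in this PR, since momo_adam also needs the Adam EMAs to compute its adaptive learning rate, as explained above):

import optax

def adamw_like(learning_rate=1e-2, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.0):
  return optax.chain(
      optax.scale_by_adam(b1=b1, b2=b2, eps=eps),   # maintains the Adam EMAs internally
      optax.add_decayed_weights(weight_decay),      # decoupled weight decay
      optax.scale_by_learning_rate(learning_rate),  # multiplies updates by -learning_rate
  )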

2 changes: 1 addition & 1 deletion optax/contrib/momo_test.py
@@ -91,7 +91,7 @@ class MomoAdamTest(chex.TestCase):
dtype=(jnp.float32,),
)
def test_optimization(self, opt_name, target, dtype):
- opt = getattr(contrib, opt_name)()
+ opt = getattr(contrib, opt_name)(learning_rate=0.1)
initial_params, final_params, get_updates = target(dtype)
@jax.jit
def step(params, state):