Implementation of MoMo algorithm #721

Merged: 27 commits, Apr 23, 2024
Changes from 1 commit
merge files
fabian-sp committed Mar 26, 2024
commit d653da9eb277a7287a0a304163b20741d9e3ba24
6 changes: 2 additions & 4 deletions optax/contrib/__init__.py
@@ -22,10 +22,8 @@
 from optax.contrib.dadapt_adamw import DAdaptAdamWState
 from optax.contrib.mechanic import MechanicState
 from optax.contrib.mechanic import mechanize
-from optax.contrib.momo import momo
-from optax.contrib.momo import MomoState
-from optax.contrib.momo_adam import momo_adam
-from optax.contrib.momo_adam import MomoAdamState
+from optax.contrib.momo import momo, momo_adam
+from optax.contrib.momo import MomoState, MomoAdamState
 from optax.contrib.privacy import differentially_private_aggregate
 from optax.contrib.privacy import DifferentiallyPrivateAggregateState
 from optax.contrib.privacy import dpsgd
120 changes: 120 additions & 0 deletions optax/contrib/momo.py
@@ -123,3 +123,123 @@ def update_fn(
    return p_update, new_state

  return base.GradientTransformationExtraArgs(init_fn, update_fn)
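Both transforms in this file expect the current batch loss as an extra argument to `update`. A minimal training-step sketch, showing my assumption of typical usage under the `GradientTransformationExtraArgs` interface (`opt` would be e.g. `optax.contrib.momo()`; the helper name `momo_train_step` is mine):

```python
import jax


def momo_train_step(opt, opt_state, params, batch, loss_fn):
  # value_and_grad yields the batch loss alongside the gradients,
  # which MoMo needs to compute its Polyak-type step size.
  loss, grads = jax.value_and_grad(loss_fn)(params, batch)
  updates, opt_state = opt.update(grads, opt_state, params, loss=loss)
  # Apply the additive updates (optax.apply_updates does the same).
  params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
  return params, opt_state, loss
```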

class MomoAdamState(NamedTuple):
  """State of the `GradientTransformation` returned by `momo_adam`."""
  exp_avg: base.Updates
  exp_avg_sq: base.Updates
  barf: float
  gamma: float
  count: float
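Reading from `update_fn` below, the scalar fields `barf` and `gamma` are exponential moving averages (my interpretation, not stated in the PR): `barf` tracks the batch losses, and `gamma` tracks the inner product of gradients and parameters. Both use the same one-step EMA recurrence:

```python
def ema_step(prev, new, beta=0.9):
  """One EMA step, as used for barf (losses) and gamma (<grad, params>)."""
  return beta * prev + (1 - beta) * new
```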


[Review thread on the `lb` argument]
Member: How about calling this f_min, as with polyak_sgd? https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.polyak_sgd
Contributor Author: I would prefer the term "lower bound", as this is closer to how we describe this quantity in the paper (strictly speaking, you don't need the optimal value to derive MoMo, only a lower bound). But I see your point about consistent naming...
Member: OK, but let's at least name it lower_bound? lb is not very descriptive.
Contributor Author: Sure, I can do that. Just FYI, I have now also added an option where the lower bound is estimated on the fly, which will make some variable names a bit lengthy :) but at least the function argument should have a descriptive name.
Member: Yes, exactly: having descriptive names in the function signature is the most important thing. It's fine if private variables have more cryptic names.

def momo_adam(
    learning_rate: base.ScalarOrSchedule = 1.0,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    lb: float = 0.0,
    weight_decay: float = 0.
) -> base.GradientTransformationExtraArgs:
  """Adaptive Learning Rates for Adam(W).

  MoMo-Adam typically needs less tuning of `learning_rate`, by exploiting the
  fact that a lower bound of the loss (or the optimal value) is known. For most
  tasks, zero is a lower bound and an accurate estimate of the final loss.

  MoMo performs Adam(W) with a Polyak-type learning rate. The effective step
  size is `min(learning_rate, <adaptive term>)`, where the adaptive term is
  computed on the fly.

  Note that in `update_fn` you need to pass the latest (batch) loss to the
  argument `loss`.

  References:
    [Schaipp et al., 2023](https://arxiv.org/abs/2305.07583)

  Args:
    learning_rate: User-specified learning rate. Recommended to be chosen
      rather large, by default 1.0.
    betas: Adam momentum coefficients (for EMA).
    eps: eps for the underlying Adam optimizer.
    lb: Lower bound of the loss. Zero should be a good choice for many tasks.
    weight_decay: Weight-decay parameter. MoMo-Adam performs weight decay in
      similar fashion to AdamW.

  Returns:
    A `GradientTransformation` object.
  """
  def init_fn(params: base.Params) -> MomoAdamState:
    exp_avg = tu.tree_map(lambda p: jnp.zeros(p.shape), params)
    exp_avg_sq = tu.tree_map(lambda p: jnp.zeros(p.shape, jnp.float32), params)
    barf = 0
    gamma = 0
    count = 0
    return MomoAdamState(exp_avg, exp_avg_sq, barf, gamma, count)

[Review thread on update_fn]
Member: Have you considered writing this optimizer as a chain using the existing scale_by_adam? If possible, it would probably result in much shorter (and reusable) code.
Contributor Author: For computing the adaptive learning rate, we need to compute the Adam EMAs and some other quantities based on them. So I thought it would be best to have everything in one function, to avoid double computation. But I might be wrong here...
Member: No, I think what you say makes sense.

  def update_fn(
      updates: base.Updates,
      state: MomoAdamState,
      params: Optional[base.Params],
      loss: Optional[Array],
  ) -> tuple[base.Updates, MomoAdamState]:
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)
    if loss is None:
      raise ValueError(
          'You need to pass the latest loss value to Momo. '
          'Use `jax.value_and_grad` for this.'
      )
    count = state.count
    beta1, beta2 = betas
    barf = beta1 * state.barf + (1 - beta1) * loss
    exp_avg = tu.tree_map(
        lambda ea, g: beta1 * ea + (1 - beta1) * g,
        state.exp_avg,
        updates,
    )
    exp_avg_sq = tu.tree_map(
        lambda eas, g: beta2 * eas + (1 - beta2) * g * g,
        state.exp_avg_sq,
        updates,
    )
    bc2 = 1 - beta2 ** (count + 1)
    precond = tu.tree_map(lambda eas: eps + jnp.sqrt(eas / bc2), exp_avg_sq)
    exp_avg_weighted = tu.tree_map(
        lambda ea, prec: ea / prec, exp_avg, precond
    )
    exp_avg_norm = tree_utils.tree_vdot(exp_avg, exp_avg_weighted)
    gamma = beta1 * state.gamma + (1 - beta1) * tree_utils.tree_vdot(
        updates, params
    )
    iprod = tree_utils.tree_vdot(exp_avg, params)
    alpha = learning_rate(count) if callable(learning_rate) else learning_rate
    bc1 = 1 - beta1 ** (count + 1)
    t1 = jnp.maximum(
        (1 + alpha * weight_decay) * (barf - bc1 * lb - gamma) + iprod, 0
    ) / exp_avg_norm
    # If the denominator is zero, take no step.
    t1 = cond(
        exp_avg_norm <= jnp.finfo(float).eps,
        lambda: 0.,
        lambda: t1,
    )
    tau = jnp.minimum(alpha / bc1, t1)
    p_update = tu.tree_map(
        lambda ea, prec, p:
        -(alpha * weight_decay) / (1 + alpha * weight_decay) * p
        - tau * ea / prec,
        exp_avg,
        precond,
        params,
    )
    new_state = MomoAdamState(
        exp_avg=exp_avg,
        exp_avg_sq=exp_avg_sq,
        barf=barf,
        gamma=gamma,
        count=utils.safe_int32_increment(count),
    )
    return p_update, new_state

  return base.GradientTransformationExtraArgs(init_fn, update_fn)
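The step-size logic in `update_fn` above can be isolated into a few lines. A plain-Python sketch (my restatement of the `t1`/`tau` computation using the same variable names as the diff; not part of the PR):

```python
def momo_adam_tau(alpha, weight_decay, barf, lb, gamma, iprod,
                  exp_avg_norm, bc1):
  """Effective step size: min(alpha / bc1, adaptive Polyak-type term)."""
  if exp_avg_norm <= 2.2e-16:  # degenerate preconditioned norm: take no step
    return 0.0
  # Adaptive term: positive part of the model decrease estimate, scaled by
  # the preconditioned gradient norm.
  t1 = max((1 + alpha * weight_decay) * (barf - bc1 * lb - gamma) + iprod, 0.0)
  return min(alpha / bc1, t1 / exp_avg_norm)
```

Note how the user-specified `learning_rate` (`alpha`) only acts as a cap: when the adaptive term is large, the step size saturates at `alpha / bc1`.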
148 changes: 0 additions & 148 deletions optax/contrib/momo_adam.py

This file was deleted.

87 changes: 0 additions & 87 deletions optax/contrib/momo_adam_test.py

This file was deleted.
