Add cautious mars, improve test reliability by skipping grad diff for… #2351

Merged · 1 commit · Dec 2, 2024
2 changes: 2 additions & 0 deletions tests/test_optim.py
@@ -300,6 +300,8 @@ def test_optim_factory(optimizer):
lr = (1e-2,) * 4
if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
lr = (1e-3,) * 4
elif optimizer in ('cmars',):
lr = (1e-4,) * 4

try:
if not opt_info.second_order: # basic tests don't support second order right now
7 changes: 7 additions & 0 deletions timm/optim/_optim_factory.py
@@ -572,6 +572,13 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
has_betas=True,
defaults = {'caution': True}
),
OptimInfo(
name='cmars',
opt_class=Mars,
description='Cautious MARS',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='cnadamw',
opt_class=NAdamW,
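For reference, a minimal usage sketch (not part of the diff): once the 'cmars' entry above is registered, it should resolve through timm's create_optimizer_v2 factory like the other cautious variants. The toy model and hyperparameters below are illustrative assumptions, not values from this PR.

import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)

# 'cmars' maps to Mars with caution=True via the OptimInfo registered above
optimizer = create_optimizer_v2(model, opt='cmars', lr=1e-4, weight_decay=0.01)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()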
96 changes: 60 additions & 36 deletions timm/optim/mars.py
@@ -14,38 +14,50 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Optional, Tuple

import torch
from torch.optim.optimizer import Optimizer


def mars_single_tensor(
p,
grad,
exp_avg,
exp_avg_sq,
lr,
weight_decay,
beta1,
beta2,
last_grad,
eps,
step,
gamma,
mars_type,
is_grad_2d,
optimize_1d,
lr_1d_factor,
betas_1d,
from ._types import ParamsT


def _mars_single_tensor_step(
p: torch.Tensor,
grad: torch.Tensor,
exp_avg: torch.Tensor,
exp_avg_sq: torch.Tensor,
lr: float,
weight_decay: float,
beta1: float,
beta2: float,
last_grad: torch.Tensor,
eps: float,
step: int,
gamma: float,
mars_type: str,
is_grad_2d: bool,
optimize_1d: bool,
lr_1d_factor: float,
betas_1d: Tuple[float, float],
caution: bool,
):
# optimize_1d: use MARS for 1d para, not: use AdamW for 1d para
# optimize_1d ==> use MARS for 1d param, else use AdamW
if optimize_1d or is_grad_2d:
one_minus_beta1 = 1. - beta1
c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
c_t_norm = torch.norm(c_t)
if c_t_norm > 1.:
c_t = c_t / c_t_norm
if step == 1:
# this is a timm addition, making first step more consistent when no grad history, otherwise tests fail
c_t = grad
else:
c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
c_t_norm = torch.norm(c_t)
if c_t_norm > 1.:
c_t = c_t / c_t_norm
exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
if caution:
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
if mars_type == "adamw":
exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
bias_correction1 = 1.0 - beta1 ** step
@@ -64,6 +76,10 @@ def mars_single_tensor(
bias_correction1 = 1.0 - beta1_1d ** step
bias_correction2 = 1.0 - beta2_1d ** step
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
if caution:
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
p.add_(update, alpha=-(lr * lr_1d_factor))
return exp_avg, exp_avg_sq
@@ -78,16 +94,17 @@ class Mars(Optimizer):
"""
def __init__(
self,
params,
lr=3e-3,
betas=(0.9, 0.99),
eps=1e-8,
weight_decay=0.,
gamma=0.025,
mars_type="adamw",
optimize_1d=False,
lr_1d_factor=1.0,
betas_1d=None,
params: ParamsT,
lr: float = 3e-3,
betas: Tuple[float, float] = (0.9, 0.99),
eps: float = 1e-8,
weight_decay: float = 0.,
gamma: float = 0.025,
mars_type: str = "adamw",
optimize_1d: bool = False,
lr_1d_factor: float = 1.0,
betas_1d: Optional[Tuple[float, float]] = None,
caution: bool = False
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
@@ -109,9 +126,15 @@ def __init__(
optimize_1d=optimize_1d,
lr_1d_factor=lr_1d_factor,
betas_1d=betas_1d or betas,
caution=caution,
)
super(Mars, self).__init__(params, defaults)

def __setstate__(self, state):
super(Mars, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)

@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@@ -134,7 +157,6 @@ def step(self, closure=None):
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

state = self.state[p]
# ('----- starting a parameter state', state.keys(), 'Length of state', len(state))
# State initialization
if len(state) <= 1:
state['step'] = 0
@@ -155,7 +177,8 @@
beta1, beta2 = group['betas']
is_grad_2d = grad.ndim >= 2

mars_single_tensor(
# FIXME add multi-tensor (if usage warrants), make more standard
_mars_single_tensor_step(
p,
grad,
exp_avg,
@@ -173,6 +196,7 @@
optimize_1d=group['optimize_1d'],
lr_1d_factor=group['lr_1d_factor'],
betas_1d=group['betas_1d'],
caution=group['caution'],
)

state['last_grad'] = grad
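For reference, a self-contained sketch (not part of the diff) of the cautious masking that caution=True enables in both the 2D/MARS and 1D/AdamW branches above: the momentum buffer only contributes where its sign agrees with the current gradient, and the surviving entries are rescaled by the mean of the mask so the overall update magnitude stays roughly constant. Tensor names mirror the optimizer state but are illustrative here.

import torch

def apply_caution(exp_avg: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # 1 where momentum and gradient point the same way, 0 elsewhere
    mask = (exp_avg * grad > 0).to(grad.dtype)
    # rescale survivors; the clamp avoids dividing by ~0 when nearly everything is masked
    mask.div_(mask.mean().clamp_(min=1e-3))
    return exp_avg * mask

exp_avg = torch.tensor([0.5, -0.2, 0.1])
grad = torch.tensor([0.4, 0.3, -0.2])
print(apply_caution(exp_avg, grad))  # tensor([1.5000, 0.0000, 0.0000]) -- only the agreeing entry survives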