
Commit

feat: add weight decay
XuehaiPan committed Aug 26, 2022
1 parent 4a985ea commit cd54216
Showing 8 changed files with 131 additions and 23 deletions.
127 changes: 106 additions & 21 deletions torchopt/_src/alias.py
@@ -39,25 +39,97 @@
from torchopt._src.utils import pytree


def _scale_by_lr(lr: ScalarOrSchedule, maximize=False):
sign = -1 if not maximize else 1
def _flip_sign_and_weight_decay(weight_decay: float = 0.0, maximize=False):
if not 0.0 <= weight_decay: # pylint: disable=unneeded-not
raise ValueError(f"Invalid weight_decay value: {weight_decay}")

if maximize and weight_decay == 0.0:
return base.identity()

def init_fn(_):
return base.EmptyState()

if not maximize: # gradient descent
if weight_decay == 0.0:
# pylint: disable-next=unused-argument
def update_fn(updates, state, *, params=None, inplace=True):
if inplace:

def f(g):
return g.neg_() if g is not None else None

else:

def f(g):
return g.neg() if g is not None else None

updates = pytree.tree_map(f, updates)
return updates, state

else:

def update_fn(updates, state, *, params=None, inplace=True):
assert params is not None, (
"Parameters are required for weight decay. "
"Call `update(updates, state, params=params)` instead."
)

if inplace:

def f(g, p):
return g.neg_().add_(p, alpha=weight_decay) if g is not None else None

else:

def f(g, p):
return g.neg().add(p, alpha=weight_decay) if g is not None else None

updates = pytree.tree_map(f, updates, params)
return updates, state

else: # gradient ascent

def update_fn(updates, state, *, params=None, inplace=True):
assert params is not None, (
"Parameters are required for weight decay. "
"Call `update(updates, state, params=params)` instead."
)

if inplace:

def f(g, p):
return g.add_(p, alpha=weight_decay) if g is not None else None

else:

def f(g, p):
return g.add(p, alpha=weight_decay) if g is not None else None

updates = pytree.tree_map(f, updates, params)
return updates, state

return base.GradientTransformation(init_fn, update_fn)


def _scale_by_lr(lr: ScalarOrSchedule):
if callable(lr):

def schedule_wrapper(count):
def f(scaled_lr):
return sign * scaled_lr
return scaled_lr

return pytree.tree_map(f, lr(count)) # type: ignore

return transform.scale_by_schedule(schedule_wrapper)
return transform.scale(sign * lr)
return transform.scale(lr)


# pylint: disable-next=too-many-arguments
def adam(
lr: ScalarOrSchedule = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
*,
eps_root: float = 0.0,
moment_requires_grad: bool = False,
@@ -81,6 +153,8 @@ def adam(
eps: (float, default: :const:`1e-8`)
A small constant applied to denominator outside of the square root (as in the Adam
paper) to avoid dividing by zero when rescaling.
weight_decay: (float, default: :const:`0.0`)
Weight decay, adds an L2 penalty to the parameters.
eps_root: (float, default: :data:`0.0`)
A small constant applied to denominator inside the square root (as in RMSProp), to avoid
dividing by zero when rescaling. This is needed for example when computing
@@ -106,26 +180,32 @@ def adam(
raise ValueError(f'Invalid beta parameter at index 0: {b1}')
if not 0.0 <= b2 < 1.0:
raise ValueError(f'Invalid beta parameter at index 1: {b2}')
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
# pylint: enable=unneeded-not

adam_inst = (
transform.scale_by_accelerated_adam if use_accelerated_op else transform.scale_by_adam
)
if use_accelerated_op:
adam_scaler = transform.scale_by_accelerated_adam
else:
adam_scaler = transform.scale_by_adam

return combine.chain(
adam_inst(
_flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize),
adam_scaler(
b1=b1,
b2=b2,
eps=eps,
eps_root=eps_root,
moment_requires_grad=moment_requires_grad,
),
_scale_by_lr(lr, maximize=maximize),
_scale_by_lr(lr),
)


def sgd(
lr: ScalarOrSchedule,
momentum: float = 0.0,
weight_decay: float = 0.0,
nesterov: bool = False,
*,
moment_requires_grad: bool = False,
@@ -146,6 +226,8 @@ def sgd(
momentum: (float, default: :const:`0.0`)
The decay rate used by the momentum term. The momentum is not used when it is set to
:const:`0.0`.
weight_decay: (float, default: :const:`0.0`)
Weight decay, adds an L2 penalty to the parameters.
nesterov: (bool, default: :data:`False`)
Whether Nesterov momentum is used.
moment_requires_grad: (bool, default: :data:`False`)
@@ -162,9 +244,12 @@ def sgd(
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= momentum:
raise ValueError(f'Invalid momentum value: {momentum}')
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
# pylint: enable=unneeded-not

return combine.chain(
_flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize),
(
transform.trace(
decay=momentum,
Expand All @@ -174,7 +259,7 @@ def sgd(
if momentum is not None and momentum != 0.0
else base.identity()
),
_scale_by_lr(lr, maximize=maximize),
_scale_by_lr(lr),
)


@@ -183,6 +268,7 @@ def rmsprop(
lr: ScalarOrSchedule = 1e-2,
alpha: float = 0.9,
eps: float = 1e-8,
weight_decay: float = 0.0,
momentum: float = 0.0,
centered: bool = False,
*,
@@ -208,6 +294,8 @@ def rmsprop(
Smoothing constant, the decay used to track the magnitude of previous gradients.
eps: (float, default: :const:`1e-8`)
A small numerical constant to avoid dividing by zero when rescaling.
weight_decay: (float, default: :const:`0.0`)
Weight decay, adds an L2 penalty to the parameters.
momentum: (float, default: :const:`0.0`)
The decay rate used by the momentum term. The momentum is not used when it is set to
:const:`0.0`.
@@ -235,25 +323,22 @@ def rmsprop(
raise ValueError(f'Invalid epsilon value: {eps}')
if not 0.0 <= momentum:
raise ValueError(f'Invalid momentum value: {momentum}')
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
# pylint: enable=unneeded-not

if centered:
return combine.chain(
transform.scale_by_stddev(alpha=alpha, eps=eps, initial_scale=initial_scale),
(
transform.trace(decay=momentum, nesterov=nesterov)
if momentum is not None and momentum != 0.0
else base.identity()
),
_scale_by_lr(lr, maximize=maximize),
)
rmsprop_scaler = transform.scale_by_stddev
else:
rmsprop_scaler = transform.scale_by_rms

return combine.chain(
transform.scale_by_rms(alpha=alpha, eps=eps, initial_scale=initial_scale),
_flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize),
rmsprop_scaler(alpha=alpha, eps=eps, initial_scale=initial_scale),
(
transform.trace(decay=momentum, nesterov=nesterov)
if momentum is not None and momentum != 0.0
else base.identity()
),
_scale_by_lr(lr, maximize=maximize),
_scale_by_lr(lr),
)
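
For reference, a minimal sketch of how the new ``weight_decay`` argument flows through the functional API changed above. The top-level ``torchopt.sgd`` re-export and the ``torchopt.apply_updates`` helper are assumed here rather than shown in this diff; the key point, taken from the assertion added in ``_flip_sign_and_weight_decay``, is that ``params`` must now be passed to ``update`` whenever ``weight_decay > 0``.

```python
import torch
import torchopt  # assumed top-level re-export of the functional ``sgd`` alias

param = torch.tensor([1.0, -2.0])
grad = torch.tensor([0.5, 0.5])

optim = torchopt.sgd(lr=0.1, weight_decay=1e-4)
state = optim.init(param)

# With weight_decay > 0 the parameters are required, otherwise the assertion
# added in ``_flip_sign_and_weight_decay`` fires.
updates, state = optim.update(grad, state, params=param, inplace=False)

# ``apply_updates`` (assumed helper) adds the computed updates to the parameters.
param = torchopt.apply_updates(param, updates)
```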
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/adam.py
@@ -37,6 +37,7 @@ def __init__(
lr: ScalarOrSchedule,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
*,
eps_root: float = 0.0,
maximize: bool = False,
@@ -54,6 +55,8 @@ def __init__(
eps: (float, default: :const:`1e-8`)
A small constant applied to denominator outside of the square root (as in the Adam
paper) to avoid dividing by zero when rescaling.
weight_decay: (float, default: :const:`0.0`)
Weight decay, adds an L2 penalty to the parameters.
eps_root: (float, default: :data:`0.0`)
A small constant applied to denominator inside the square root (as in RMSProp), to
avoid dividing by zero when rescaling. This is needed for example when computing
@@ -69,6 +72,7 @@ def __init__(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
eps_root=eps_root,
moment_requires_grad=False,
maximize=maximize,
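
The object-oriented wrapper gains the same argument. A usage sketch (the ``torchopt.Adam`` export, its ``zero_grad``/``step`` methods, and the toy model are assumptions for illustration, not part of this diff):

```python
import torch
import torchopt

model = torch.nn.Linear(4, 1)
optimizer = torchopt.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Note that the decay term is folded into the gradients before ``scale_by_adam`` in the chain above, i.e. classic L2 regularization rather than decoupled (AdamW-style) weight decay.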
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/meta/adam.py
@@ -37,6 +37,7 @@ def __init__(
lr: ScalarOrSchedule = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
*,
eps_root: float = 0.0,
moment_requires_grad: bool = True,
@@ -55,6 +56,8 @@ def __init__(
eps: (float, default: :const:`1e-8`)
A small constant applied to denominator outside of the square root (as in the Adam
paper) to avoid dividing by zero when rescaling.
weight_decay: (float, default: :const:`0.0`)
Weight decay, adds an L2 penalty to the parameters.
eps_root: (float, default: :data:`0.0`)
A small constant applied to denominator inside the square root (as in RMSProp), to
avoid dividing by zero when rescaling. This is needed for example when computing
@@ -73,6 +76,7 @@ def __init__(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
eps_root=eps_root,
moment_requires_grad=moment_requires_grad,
maximize=maximize,
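
Likewise for the differentiable meta-optimizer. A sketch (``torchopt.MetaAdam`` and its loss-taking ``step`` method are assumed from the surrounding API, not shown in this diff):

```python
import torch
import torchopt

net = torch.nn.Linear(2, 1)
meta_optim = torchopt.MetaAdam(net, lr=1e-2, weight_decay=1e-4)

x = torch.randn(4, 2)
inner_loss = net(x).pow(2).mean()

# Differentiable inner step; the L2 term now enters the inner-loop gradients
# and therefore also the outer (meta) gradients.
meta_optim.step(inner_loss)
```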
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/meta/rmsprop.py
@@ -35,6 +35,7 @@ def __init__(
lr: ScalarOrSchedule = 1e-2,
alpha: float = 0.99,
eps: float = 1e-8,
weight_decay: float = 0.0,
momentum: float = 0.0,
centered: bool = False,
*,
@@ -53,6 +54,8 @@ def __init__(
Smoothing constant, the decay used to track the magnitude of previous gradients.
eps: (float, default: :const:`1e-8`)
A small numerical constant to avoid dividing by zero when rescaling.
weight_decay: (float, default: :const:`0.0`)
Weight decay, adds an L2 penalty to the parameters.
momentum: (float, default: :const:`0.0`)
The decay rate used by the momentum term. The momentum is not used when it is set to
:const:`0.0`.
@@ -74,6 +77,7 @@ def __init__(
lr=lr,
alpha=alpha,
eps=eps,
weight_decay=weight_decay,
momentum=momentum,
centered=centered,
initial_scale=initial_scale,
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/meta/sgd.py
@@ -34,6 +34,7 @@ def __init__(
net: nn.Module,
lr: ScalarOrSchedule,
momentum: float = 0.0,
weight_decay: float = 0.0,
nesterov: bool = False,
moment_requires_grad: bool = True,
maximize: bool = False,
@@ -48,6 +49,8 @@ def __init__(
momentum: (float, default: :const:`0.0`)
The decay rate used by the momentum term. The momentum is not used when it is set to
:const:`0.0`.
weight_decay: (float, default: :const:`0.0`)
Weight decay, adds an L2 penalty to the parameters.
nesterov: (bool, default: :const:`False`)
Whether Nesterov momentum is used.
moment_requires_grad: (bool, default: :data:`True`)
@@ -61,6 +64,7 @@ def __init__(
sgd(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
nesterov=nesterov,
moment_requires_grad=moment_requires_grad,
maximize=maximize,
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/rmsprop.py
@@ -37,6 +37,7 @@ def __init__(
lr: ScalarOrSchedule = 1e-2,
alpha: float = 0.99,
eps: float = 1e-8,
weight_decay: float = 0.0,
momentum: float = 0.0,
centered: bool = False,
*,
@@ -55,6 +56,8 @@ def __init__(
Smoothing constant, the decay used to track the magnitude of previous gradients.
eps: (float, default: :const:`1e-8`)
A small numerical constant to avoid dividing by zero when rescaling.
weight_decay: (float, default: :const:`0.0`)
Weight decay, adds an L2 penalty to the parameters.
momentum: (float, default: :const:`0.0`)
The decay rate used by the momentum term. The momentum is not used when it is set to
:const:`0.0`.
@@ -76,6 +79,7 @@ def __init__(
lr=lr,
alpha=alpha,
eps=eps,
weight_decay=weight_decay,
momentum=momentum,
centered=centered,
initial_scale=initial_scale,
4 changes: 4 additions & 0 deletions torchopt/_src/optimizer/sgd.py
@@ -36,6 +36,7 @@ def __init__(
params: Iterable[torch.Tensor],
lr: ScalarOrSchedule,
momentum: float = 0.0,
weight_decay: float = 0.0,
nesterov: bool = False,
maximize: bool = False,
):
@@ -49,6 +50,8 @@ def __init__(
momentum: (float, default: :const:`0.0`)
The decay rate used by the momentum term. The momentum is not used when it is set to
:const:`0.0`.
weight_decay: (float, default: :const:`0.0`)
Weight decay, adds an L2 penalty to the parameters.
nesterov: (bool, default: :data:`False`)
Whether Nesterov momentum is used.
maximize: (bool, default: :data:`False`)
@@ -59,6 +62,7 @@ def __init__(
sgd(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
nesterov=nesterov,
moment_requires_grad=False,
maximize=maximize,
3 changes: 1 addition & 2 deletions torchopt/_src/transform.py
@@ -69,8 +69,7 @@ def scale(step_size: float) -> base.GradientTransformation:
An ``(init_fn, update_fn)`` tuple.
"""

def init_fn(params):
del params
def init_fn(_):
return ScaleState()

def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument
