diff --git a/torchopt/_src/alias.py b/torchopt/_src/alias.py
index 495673656..d25855b3c 100644
--- a/torchopt/_src/alias.py
+++ b/torchopt/_src/alias.py
@@ -39,18 +39,89 @@ from torchopt._src.utils import pytree
 
 
-def _scale_by_lr(lr: ScalarOrSchedule, maximize=False):
-    sign = -1 if not maximize else 1
+def _flip_sign_and_weight_decay(weight_decay: float = 0.0, maximize=False):
+    if not 0.0 <= weight_decay:  # pylint: disable=unneeded-not
+        raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+    if maximize and weight_decay == 0.0:
+        return base.identity()
+
+    def init_fn(_):
+        return base.EmptyState()
+
+    if not maximize:  # gradient descent
+        if weight_decay == 0.0:
+            # pylint: disable-next=unused-argument
+            def update_fn(updates, state, *, params=None, inplace=True):
+                if inplace:
+
+                    def f(g):
+                        return g.neg_() if g is not None else None
+
+                else:
+
+                    def f(g):
+                        return g.neg() if g is not None else None
+
+                updates = pytree.tree_map(f, updates)
+                return updates, state
+
+        else:
+
+            def update_fn(updates, state, *, params=None, inplace=True):
+                assert params is not None, (
+                    "Parameters are required for weight decay. "
+                    "Call `update(updates, state, params=params)` instead."
+                )
+
+                if inplace:
+
+                    def f(g, p):
+                        return g.neg_().add_(p, alpha=weight_decay) if g is not None else None
+
+                else:
+
+                    def f(g, p):
+                        return g.neg().add_(p, alpha=weight_decay) if g is not None else None
+
+                updates = pytree.tree_map(f, updates, params)
+                return updates, state
+
+    else:  # gradient ascent
+
+        def update_fn(updates, state, *, params=None, inplace=True):
+            assert params is not None, (
+                "Parameters are required for weight decay. "
+                "Call `update(updates, state, params=params)` instead."
+            )
+
+            if inplace:
+
+                def f(g, p):
+                    return g.add_(p, alpha=weight_decay) if g is not None else None
+
+            else:
+
+                def f(g, p):
+                    return g.add(p, alpha=weight_decay) if g is not None else None
+
+            updates = pytree.tree_map(f, updates, params)
+            return updates, state
+
+    return base.GradientTransformation(init_fn, update_fn)
+
+
+def _scale_by_lr(lr: ScalarOrSchedule):
     if callable(lr):
 
         def schedule_wrapper(count):
             def f(scaled_lr):
-                return sign * scaled_lr
+                return scaled_lr
 
             return pytree.tree_map(f, lr(count))  # type: ignore
 
         return transform.scale_by_schedule(schedule_wrapper)
-    return transform.scale(sign * lr)
+    return transform.scale(lr)
 
 
 # pylint: disable-next=too-many-arguments
@@ -58,6 +129,7 @@ def adam(
     lr: ScalarOrSchedule = 1e-3,
     betas: Tuple[float, float] = (0.9, 0.999),
     eps: float = 1e-8,
+    weight_decay: float = 0.0,
     *,
     eps_root: float = 0.0,
     moment_requires_grad: bool = False,
@@ -81,6 +153,8 @@ def adam(
         eps: (float, default: :const:`1e-8`)
             A small constant applied to denominator outside of the square root (as in the Adam
             paper) to avoid dividing by zero when rescaling.
+        weight_decay: (float, default: :const:`0.0`):
+            Weight decay, add L2 penalty to parameters.
         eps_root: (float, default: :data:`0.0`)
             A small constant applied to denominator inside the square root (as in RMSProp), to
             avoid dividing by zero when rescaling. This is needed for example when computing
@@ -106,26 +180,32 @@ def adam(
         raise ValueError(f'Invalid beta parameter at index 0: {b1}')
     if not 0.0 <= b2 < 1.0:
         raise ValueError(f'Invalid beta parameter at index 1: {b2}')
+    if not 0.0 <= weight_decay:
+        raise ValueError(f"Invalid weight_decay value: {weight_decay}")
     # pylint: enable=unneeded-not
 
-    adam_inst = (
-        transform.scale_by_accelerated_adam if use_accelerated_op else transform.scale_by_adam
-    )
+    if use_accelerated_op:
+        adam_scaler = transform.scale_by_accelerated_adam
+    else:
+        adam_scaler = transform.scale_by_adam
+
     return combine.chain(
-        adam_inst(
+        _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize),
+        adam_scaler(
             b1=b1,
             b2=b2,
             eps=eps,
             eps_root=eps_root,
             moment_requires_grad=moment_requires_grad,
         ),
-        _scale_by_lr(lr, maximize=maximize),
+        _scale_by_lr(lr),
     )
 
 
 def sgd(
     lr: ScalarOrSchedule,
     momentum: float = 0.0,
+    weight_decay: float = 0.0,
     nesterov: bool = False,
     *,
     moment_requires_grad: bool = False,
@@ -146,6 +226,8 @@ def sgd(
         momentum: (float, default: :const:`0.0`)
             The decay rate used by the momentum term. The momentum is not used when it is set to
             :const:`0.0`.
+        weight_decay: (float, default: :const:`0.0`):
+            Weight decay, add L2 penalty to parameters.
         nesterov: (bool, default: :data:`False`)
             Whether the nesterov momentum is used.
         moment_requires_grad: (bool, default: :data:`False`)
@@ -162,9 +244,12 @@ def sgd(
         raise ValueError(f'Invalid learning rate: {lr}')
     if not 0.0 <= momentum:
         raise ValueError(f'Invalid momentum value: {momentum}')
+    if not 0.0 <= weight_decay:
+        raise ValueError(f"Invalid weight_decay value: {weight_decay}")
     # pylint: enable=unneeded-not
 
     return combine.chain(
+        _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize),
         (
             transform.trace(
                 decay=momentum,
@@ -174,7 +259,7 @@ def sgd(
             if momentum is not None and momentum != 0.0
             else base.identity()
         ),
-        _scale_by_lr(lr, maximize=maximize),
+        _scale_by_lr(lr),
     )
 
 
@@ -183,6 +268,7 @@ def rmsprop(
     lr: ScalarOrSchedule = 1e-2,
     alpha: float = 0.9,
     eps: float = 1e-8,
+    weight_decay: float = 0.0,
     momentum: float = 0.0,
     centered: bool = False,
     *,
@@ -208,6 +294,8 @@ def rmsprop(
             Smoothing constant, the decay used to track the magnitude of previous gradients.
         eps: (float, default: :const:`1e-8`)
             A small numerical constant to avoid dividing by zero when rescaling.
+        weight_decay: (float, default: :const:`0.0`):
+            Weight decay, add L2 penalty to parameters.
         momentum: (float, default: :const:`0.0`)
             The decay rate used by the momentum term. The momentum is not used when it is set to
             :const:`0.0`.
@@ -235,25 +323,22 @@ def rmsprop(
         raise ValueError(f'Invalid epsilon value: {eps}')
     if not 0.0 <= momentum:
         raise ValueError(f'Invalid momentum value: {momentum}')
+    if not 0.0 <= weight_decay:
+        raise ValueError(f"Invalid weight_decay value: {weight_decay}")
     # pylint: enable=unneeded-not
 
     if centered:
-        return combine.chain(
-            transform.scale_by_stddev(alpha=alpha, eps=eps, initial_scale=initial_scale),
-            (
-                transform.trace(decay=momentum, nesterov=nesterov)
-                if momentum is not None and momentum != 0.0
-                else base.identity()
-            ),
-            _scale_by_lr(lr, maximize=maximize),
-        )
+        rmsprop_scaler = transform.scale_by_stddev
+    else:
+        rmsprop_scaler = transform.scale_by_rms
 
     return combine.chain(
-        transform.scale_by_rms(alpha=alpha, eps=eps, initial_scale=initial_scale),
+        _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize),
+        rmsprop_scaler(alpha=alpha, eps=eps, initial_scale=initial_scale),
         (
             transform.trace(decay=momentum, nesterov=nesterov)
             if momentum is not None and momentum != 0.0
             else base.identity()
         ),
-        _scale_by_lr(lr, maximize=maximize),
+        _scale_by_lr(lr),
     )
diff --git a/torchopt/_src/optimizer/adam.py b/torchopt/_src/optimizer/adam.py
index 9af1769eb..56da62235 100644
--- a/torchopt/_src/optimizer/adam.py
+++ b/torchopt/_src/optimizer/adam.py
@@ -37,6 +37,7 @@ def __init__(
         lr: ScalarOrSchedule,
         betas: Tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
+        weight_decay: float = 0.0,
         *,
         eps_root: float = 0.0,
         maximize: bool = False,
@@ -54,6 +55,8 @@ def __init__(
             eps: (float, default: :const:`1e-8`)
                 A small constant applied to denominator outside of the square root (as in the Adam
                 paper) to avoid dividing by zero when rescaling.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             eps_root: (float, default: :data:`0.0`)
                 A small constant applied to denominator inside the square root (as in RMSProp), to
                 avoid dividing by zero when rescaling. This is needed for example when computing
@@ -69,6 +72,7 @@ def __init__(
                 lr=lr,
                 betas=betas,
                 eps=eps,
+                weight_decay=weight_decay,
                 eps_root=eps_root,
                 moment_requires_grad=False,
                 maximize=maximize,
diff --git a/torchopt/_src/optimizer/meta/adam.py b/torchopt/_src/optimizer/meta/adam.py
index 8ae58a638..7104d0e1e 100644
--- a/torchopt/_src/optimizer/meta/adam.py
+++ b/torchopt/_src/optimizer/meta/adam.py
@@ -37,6 +37,7 @@ def __init__(
         lr: ScalarOrSchedule = 1e-3,
         betas: Tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
+        weight_decay: float = 0.0,
         *,
         eps_root: float = 0.0,
         moment_requires_grad: bool = True,
@@ -55,6 +56,8 @@ def __init__(
             eps: (float, default: :const:`1e-8`)
                 A small constant applied to denominator outside of the square root (as in the Adam
                 paper) to avoid dividing by zero when rescaling.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             eps_root: (float, default: :data:`0.0`)
                 A small constant applied to denominator inside the square root (as in RMSProp), to
                 avoid dividing by zero when rescaling. This is needed for example when computing
@@ -73,6 +76,7 @@ def __init__(
                 lr=lr,
                 betas=betas,
                 eps=eps,
+                weight_decay=weight_decay,
                 eps_root=eps_root,
                 moment_requires_grad=moment_requires_grad,
                 maximize=maximize,
diff --git a/torchopt/_src/optimizer/meta/rmsprop.py b/torchopt/_src/optimizer/meta/rmsprop.py
index 3b27fbdd5..d2fb5490f 100644
--- a/torchopt/_src/optimizer/meta/rmsprop.py
+++ b/torchopt/_src/optimizer/meta/rmsprop.py
@@ -35,6 +35,7 @@ def __init__(
         lr: ScalarOrSchedule = 1e-2,
         alpha: float = 0.99,
         eps: float = 1e-8,
+        weight_decay: float = 0.0,
         momentum: float = 0.0,
         centered: bool = False,
         *,
@@ -53,6 +54,8 @@ def __init__(
                 Smoothing constant, the decay used to track the magnitude of previous gradients.
             eps: (float, default: :const:`1e-8`)
                 A small numerical constant to avoid dividing by zero when rescaling.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             momentum: (float, default: :const:`0.0`)
                 The decay rate used by the momentum term. The momentum is not used when it is set
                 to :const:`0.0`.
@@ -74,6 +77,7 @@ def __init__(
                 lr=lr,
                 alpha=alpha,
                 eps=eps,
+                weight_decay=weight_decay,
                 momentum=momentum,
                 centered=centered,
                 initial_scale=initial_scale,
diff --git a/torchopt/_src/optimizer/meta/sgd.py b/torchopt/_src/optimizer/meta/sgd.py
index b4458372b..1aab903d9 100644
--- a/torchopt/_src/optimizer/meta/sgd.py
+++ b/torchopt/_src/optimizer/meta/sgd.py
@@ -34,6 +34,7 @@ def __init__(
         net: nn.Module,
         lr: ScalarOrSchedule,
         momentum: float = 0.0,
+        weight_decay: float = 0.0,
         nesterov: bool = False,
         moment_requires_grad: bool = True,
         maximize: bool = False,
@@ -48,6 +49,8 @@ def __init__(
             momentum: (float, default: :const:`0.0`)
                 The decay rate used by the momentum term. The momentum is not used when it is set
                 to :const:`0.0`.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             nesterov: (bool, default: :const:`False`)
                 Whether the nesterov momentum is used.
             moment_requires_grad: (bool, default: :data:`True`)
@@ -61,6 +64,7 @@ def __init__(
             sgd(
                 lr=lr,
                 momentum=momentum,
+                weight_decay=weight_decay,
                 nesterov=nesterov,
                 moment_requires_grad=moment_requires_grad,
                 maximize=maximize,
diff --git a/torchopt/_src/optimizer/rmsprop.py b/torchopt/_src/optimizer/rmsprop.py
index f8e0612fd..8d2a66a3f 100644
--- a/torchopt/_src/optimizer/rmsprop.py
+++ b/torchopt/_src/optimizer/rmsprop.py
@@ -37,6 +37,7 @@ def __init__(
         lr: ScalarOrSchedule = 1e-2,
         alpha: float = 0.99,
         eps: float = 1e-8,
+        weight_decay: float = 0.0,
         momentum: float = 0.0,
         centered: bool = False,
         *,
@@ -55,6 +56,8 @@ def __init__(
                 Smoothing constant, the decay used to track the magnitude of previous gradients.
             eps: (float, default: :const:`1e-8`)
                 A small numerical constant to avoid dividing by zero when rescaling.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             momentum: (float, default: :const:`0.0`)
                 The decay rate used by the momentum term. The momentum is not used when it is set
                 to :const:`0.0`.
@@ -76,6 +79,7 @@ def __init__(
                 lr=lr,
                 alpha=alpha,
                 eps=eps,
+                weight_decay=weight_decay,
                 momentum=momentum,
                 centered=centered,
                 initial_scale=initial_scale,
diff --git a/torchopt/_src/optimizer/sgd.py b/torchopt/_src/optimizer/sgd.py
index da62495bb..9cfd608a1 100644
--- a/torchopt/_src/optimizer/sgd.py
+++ b/torchopt/_src/optimizer/sgd.py
@@ -36,6 +36,7 @@ def __init__(
         params: Iterable[torch.Tensor],
         lr: ScalarOrSchedule,
         momentum: float = 0.0,
+        weight_decay: float = 0.0,
         nesterov: bool = False,
         maximize: bool = False,
     ):
@@ -49,6 +50,8 @@ def __init__(
             momentum: (float, default: :const:`0.0`)
                 The decay rate used by the momentum term. The momentum is not used when it is set
                 to :const:`0.0`.
+            weight_decay: (float, default: :const:`0.0`):
+                Weight decay, add L2 penalty to parameters.
             nesterov: (bool, default: :data:`False`)
                 Whether the nesterov momentum is used.
             maximize: (bool, default: :data:`False`)
@@ -59,6 +62,7 @@ def __init__(
             sgd(
                 lr=lr,
                 momentum=momentum,
+                weight_decay=weight_decay,
                 nesterov=nesterov,
                 moment_requires_grad=False,
                 maximize=maximize,
diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py
index 98a1cd583..af969d614 100644
--- a/torchopt/_src/transform.py
+++ b/torchopt/_src/transform.py
@@ -69,8 +69,7 @@ def scale(step_size: float) -> base.GradientTransformation:
         An ``(init_fn, update_fn)`` tuple.
     """
 
-    def init_fn(params):
-        del params
+    def init_fn(_):
         return ScaleState()
 
     def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
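
The following is a minimal usage sketch, not part of the patch, showing how the new weight_decay argument flows through the functional API changed above. It assumes the usual top-level re-exports (torchopt.adam, torchopt.apply_updates); the tensor values and hyper-parameters are made up for illustration. Because _flip_sign_and_weight_decay now needs the parameters to apply the L2 term, update must be called with params=params, as the added assertion message states.

# Illustrative usage sketch only; not part of the patch.
import torch
import torchopt

params = (torch.tensor([1.0, 2.0], requires_grad=True),)

# `weight_decay` is the new argument added by this patch (toy value for illustration).
optimizer = torchopt.adam(lr=1e-3, weight_decay=1e-4)
opt_state = optimizer.init(params)

loss = (params[0] ** 2).sum()
grads = torch.autograd.grad(loss, params)

# With weight_decay > 0, `update` requires the parameters to add the L2 penalty.
updates, opt_state = optimizer.update(grads, opt_state, params=params, inplace=False)
new_params = torchopt.apply_updates(params, updates, inplace=False)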