✨ Minor fixes to AdamWCD, new AdamCD
emaballarin committed May 21, 2020
1 parent e7ec832 commit 35c0249
Showing 6 changed files with 198 additions and 5 deletions.
2 changes: 1 addition & 1 deletion LICENSE
@@ -176,7 +176,7 @@

END OF TERMS AND CONDITIONS

- Copyright 2020- Emanuele Ballarin <emanuele@ballarin.cc> and affiliates.
+ *Copyright 2020- Emanuele Ballarin <emanuele@ballarin.cc> and affiliates.*

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
2 changes: 2 additions & 0 deletions pyromaniac/optim/pyro/__init__.py
@@ -1,6 +1,8 @@
from pyromaniac.optim.pyro.adamwcd import AdamWCD
from pyromaniac.optim.pyro.adamcd import AdamCD


__all__ = [
"AdamWCD",
"AdamCD",
]
14 changes: 14 additions & 0 deletions pyromaniac/optim/pyro/adamcd.py
@@ -0,0 +1,14 @@
# Copyright (c) 2017- Uber Technologies, Inc.
# Copyright (c) 2020- AI-CPS@UniTS
# Copyright (c) 2020- Emanuele Ballarin <emanuele@ballarin.cc>
# SPDX-License-Identifier: Apache-2.0

from pyro.optim.optim import PyroOptim
from pyromaniac.optim.torch import AdamCD as pt_AdamCD


def AdamCD(optim_args):
    """
    Wraps :class:`pyromaniac.optim.torch.adamcd.AdamCD` with :class:`~pyro.optim.optim.PyroOptim`.
    """
    return PyroOptim(pt_AdamCD, optim_args)
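
For orientation, a minimal usage sketch of the wrapper above in a toy Pyro SVI loop. The model/guide pair and the data are invented for illustration; only the import path pyromaniac.optim.pyro.AdamCD and the optim_args dict come from this commit:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyromaniac.optim.pyro import AdamCD

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    q_loc = pyro.param("q_loc", torch.tensor(0.0))
    pyro.sample("loc", dist.Normal(q_loc, 1.0))

# The optim_args dict is forwarded through PyroOptim to the underlying torch AdamCD
optim = AdamCD({"lr": 1e-3, "clip_norm": 10.0, "lrd": 0.999, "lrd_epoch": 10})

svi = SVI(model, guide, optim, loss=Trace_ELBO())
data = torch.randn(100) + 3.0
for _ in range(1000):
    svi.step(data)
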
2 changes: 2 additions & 0 deletions pyromaniac/optim/torch/__init__.py
@@ -1,3 +1,5 @@
from .adamwcd import AdamWCD
from .adamcd import AdamCD

del adamwcd
del adamcd
173 changes: 173 additions & 0 deletions pyromaniac/optim/torch/adamcd.py
@@ -0,0 +1,173 @@
# Copyright (c) 2016- Facebook, Inc
# Copyright (c) 2017- Uber Technologies, Inc.
# Copyright (c) 2020- AI-CPS@UniTS
# Copyright (c) 2020- Emanuele Ballarin <emanuele@ballarin.cc>
# SPDX-License-Identifier: Apache-2.0

import math
import torch
from torch.optim.optimizer import Optimizer, required


class AdamCD(Optimizer):
    r"""Implements the AdamCD algorithm.
    The underlying Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    AdamCD further introduces gradient clipping and periodic learning rate decay (applied every ``lrd_epoch`` optimizer steps), similarly to the `Uber Pyro ClippedAdam variant of Adam`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        clip_norm (float, optional): magnitude to which each gradient component is clipped (default: 10.0)
        lrd (float, optional): multiplicative factor by which the learning rate decays (default: 1.0)
        lrd_epoch (int, optional): number of optimizer steps between successive learning rate decay steps (default: 1)
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Uber Pyro ClippedAdam variant of Adam:
        http://docs.pyro.ai/en/latest/_modules/pyro/optim/clipped_adam.html
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        clip_norm=10.0,
        lrd=1.0,
        lrd_epoch: int = 1,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= clip_norm:
            raise ValueError("Invalid clip_norm value: {}".format(clip_norm))
        if not 0.0 <= lrd:
            raise ValueError("Invalid lrd value: {}".format(lrd))
        if not 1 <= lrd_epoch:
            raise ValueError("Invalid lrd_epoch value: {}".format(lrd_epoch))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            clip_norm=clip_norm,
            lrd=lrd,
            lrd_epoch=lrd_epoch,
        )
        self.lrd_epoch_persist = lrd_epoch
        super(AdamCD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamCD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            # Apply a decay factor to learning rate, but only every lrd_epoch steps
            group["lrd_epoch"] -= 1
            if group["lrd_epoch"] == 0:
                group["lr"] *= group["lrd"]
                group["lrd_epoch"] = self.lrd_epoch_persist

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                # Clamp the gradient
                grad.clamp_(-group["clip_norm"], group["clip_norm"])

                if grad.is_sparse:
                    raise RuntimeError(
                        "AdamCD does not support sparse gradients, please consider SparseAdam instead"
                    )
                amsgrad = group["amsgrad"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                if group["weight_decay"] != 0:
                    grad = grad.add(p, alpha=group["weight_decay"])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
                        group["eps"]
                    )
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
                        group["eps"]
                    )

                step_size = group["lr"] / bias_correction1

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
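
Used directly as a PyTorch optimizer, AdamCD behaves as a drop-in replacement for torch.optim.Adam with two extras: gradients are clamped elementwise to [-clip_norm, clip_norm] before the update, and the learning rate is multiplied by lrd once every lrd_epoch calls to step(). A small self-contained sketch; the toy model and data are invented for illustration:

import torch
from pyromaniac.optim.torch import AdamCD

model = torch.nn.Linear(10, 1)
loss_fn = torch.nn.MSELoss()
opt = AdamCD(model.parameters(), lr=1e-3, clip_norm=5.0, lrd=0.99, lrd_epoch=100)

x, y = torch.randn(256, 10), torch.randn(256, 1)
for step in range(1000):
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()  # clamps gradients, decays lr every 100 steps, then applies the Adam update
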
10 changes: 6 additions & 4 deletions pyromaniac/optim/torch/adamwcd.py
@@ -29,15 +29,18 @@ class AdamWCD(Optimizer):
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        clip_norm (float, optional): magnitude to which each gradient component is clipped (default: 10.0)
        lrd (float, optional): multiplicative factor by which the learning rate decays (default: 1.0)
        lrd_epoch (int, optional): number of optimizer steps between successive learning rate decay steps (default: 1)
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    .. _Uber Pyro ClippedAdam variant of Adam:
        http://docs.pyro.ai/en/latest/_modules/pyro/optim/clipped_adam.html
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
@@ -46,7 +49,6 @@ def __init__(
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
-        #weight_decay=1e-2,
        weight_decay=0,
        amsgrad=False,
        clip_norm=10.0,
@@ -102,7 +104,7 @@ def step(self, closure=None):

        for group in self.param_groups:

-            # Apply a decay factor to learning rate, but only every ldr_epoch steps
+            # Apply a decay factor to learning rate, but only every lrd_epoch steps
            group["lrd_epoch"] -= 1
            if group["lrd_epoch"] == 0:
                group["lr"] *= group["lrd"]
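
For reference, the decay rule implemented by the counter logic above (shared by AdamWCD and AdamCD): after t calls to step(), a parameter group's learning rate equals lr * lrd ** (t // lrd_epoch). A tiny illustrative helper, not part of the package:

def effective_lr(lr: float, lrd: float, lrd_epoch: int, t: int) -> float:
    """Learning rate after t optimizer steps, mirroring the lrd_epoch counter."""
    return lr * lrd ** (t // lrd_epoch)

# e.g. lr=1e-3, lrd=0.9, lrd_epoch=10: after 25 steps the lr has decayed twice
assert abs(effective_lr(1e-3, 0.9, 10, 25) - 1e-3 * 0.9 ** 2) < 1e-15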
