Skip to content

Commit

Permalink
Add adam optimizers and finegrained control of wd (facebookresearch#151)
Browse files Browse the repository at this point in the history
Summary:
The update supports Adam and AdamW optimizers in
addition to default SGD optimizer. It now also supports
customizing weight decay for layernorm and bias params.

Pull Request resolved: facebookresearch#151

Reviewed By: pdollar

Differential Revision: D29521923

Pulled By: Tete-Xiao

fbshipit-source-id: 0cd5326eadba51b3a175d62a48fcd332d757b3b6
  • Loading branch information
Tete Xiao authored and facebook-github-bot committed Jul 1, 2021
1 parent 39f2dc2 commit 74df325
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 17 deletions.
22 changes: 22 additions & 0 deletions pycls/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,23 @@
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0


# -------------------------------- Layer norm options -------------------------------- #
_C.LN = CfgNode()

# LN epsilon
_C.LN.EPS = 1e-5

# Use a different weight decay for LN layers
_C.LN.USE_CUSTOM_WEIGHT_DECAY = False
_C.LN.CUSTOM_WEIGHT_DECAY = 0.0


# -------------------------------- Optimizer options --------------------------------- #
_C.OPTIM = CfgNode()

# Type of optimizer select from {'sgd', 'adam', 'adamw'}
_C.OPTIM.OPTIMIZER = "sgd"

# Learning rate ranges from BASE_LR to MIN_LR*BASE_LR according to the LR_POLICY
_C.OPTIM.BASE_LR = 0.1
_C.OPTIM.MIN_LR = 0.0
Expand All @@ -221,9 +235,17 @@
# Nesterov momentum
_C.OPTIM.NESTEROV = True

# Betas (for Adam/AdamW optimizer)
_C.OPTIM.BETA1 = 0.9
_C.OPTIM.BETA2 = 0.999

# L2 regularization
_C.OPTIM.WEIGHT_DECAY = 5e-4

# Use a different weight decay for all biases (excluding those in BN/LN layers)
_C.OPTIM.BIAS_USE_CUSTOM_WEIGHT_DECAY = False
_C.OPTIM.BIAS_CUSTOM_WEIGHT_DECAY = 0.0

# Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
_C.OPTIM.WARMUP_FACTOR = 0.1

Expand Down
55 changes: 38 additions & 17 deletions pycls/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,45 @@ def construct_optimizer(model):
when the learning rate is changed there is no need to perform the
momentum correction by scaling V (unlike in the Caffe2 case).
"""
if cfg.BN.USE_CUSTOM_WEIGHT_DECAY:
# Apply different weight decay to Batchnorm and non-batchnorm parameters.
p_bn = [p for n, p in model.named_parameters() if "bn" in n]
p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n]
optim_params = [
{"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY},
{"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
]
# Split parameters into types and get weight decay for each type
optim, wd, params = cfg.OPTIM, cfg.OPTIM.WEIGHT_DECAY, [[], [], [], []]
for n, p in model.named_parameters():
ks = [k for (k, x) in enumerate(["bn", "ln", "bias", ""]) if x in n]
params[ks[0]].append(p)
wds = [
cfg.BN.CUSTOM_WEIGHT_DECAY if cfg.BN.USE_CUSTOM_WEIGHT_DECAY else wd,
cfg.LN.CUSTOM_WEIGHT_DECAY if cfg.LN.USE_CUSTOM_WEIGHT_DECAY else wd,
optim.BIAS_CUSTOM_WEIGHT_DECAY if optim.BIAS_USE_CUSTOM_WEIGHT_DECAY else wd,
wd,
]
param_wds = [{"params": p, "weight_decay": w} for (p, w) in zip(params, wds) if p]
# Set up optimizer
if optim.OPTIMIZER == "sgd":
optimizer = torch.optim.SGD(
param_wds,
lr=optim.BASE_LR,
momentum=optim.MOMENTUM,
weight_decay=wd,
dampening=optim.DAMPENING,
nesterov=optim.NESTEROV,
)
elif optim.OPTIMIZER == "adam":
optimizer = torch.optim.Adam(
param_wds,
lr=optim.BASE_LR,
betas=(optim.BETA1, optim.BETA2),
weight_decay=wd,
)
elif optim.OPTIMIZER == "adamw":
optimizer = torch.optim.AdamW(
param_wds,
lr=optim.BASE_LR,
betas=(optim.BETA1, optim.BETA2),
weight_decay=wd,
)
else:
optim_params = model.parameters()
return torch.optim.SGD(
optim_params,
lr=cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV,
)
raise NotImplementedError
return optimizer


def lr_fun_steps(cur_epoch):
Expand Down

0 comments on commit 74df325

Please sign in to comment.