add full-model gradient clipping to optimizer (facebookresearch#287)
ppwwyyxx authored Nov 15, 2020
1 parent 4e1a928 commit a54b778
Showing 1 changed file with 25 additions and 34 deletions.
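The commit replaces DETR's custom __init__/run_step overrides, which called torch.nn.utils.clip_grad_norm_ by hand inside the training step, with an optimizer-level wrapper built in build_optimizer (the code comment notes that detectron2's own clipping utilities did not offer full-model clipping at the time). As orientation before the diff, here is a minimal sketch of what "full-model" gradient clipping does; the model, data, and clip value are hypothetical stand-ins, not taken from the commit:

# Minimal sketch (not part of the commit): "full-model" clipping computes ONE
# global norm over every parameter's gradient and rescales them all together,
# instead of clipping each parameter or parameter group separately.
import torch

model = torch.nn.Linear(10, 2)       # hypothetical stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
clip_norm_val = 0.01                 # plays the role of SOLVER.CLIP_GRADIENTS.CLIP_VALUE

x, y = torch.randn(4, 10), torch.randn(4, 2)
loss = torch.nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
# One clip over all parameters jointly, right before the update.
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm_val)
optimizer.step()

The diff below achieves the same effect by subclassing the optimizer so the clip happens inside step().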
d2/train_net.py
@@ -6,6 +6,7 @@
 """
 import os
 import sys
+import itertools
 
 # fmt: off
 sys.path.insert(1, os.path.join(sys.path[0], '..'))
@@ -21,7 +22,7 @@
 from detectron2.checkpoint import DetectionCheckpointer
 from detectron2.config import get_cfg
 from detectron2.data import MetadataCatalog, build_detection_train_loader
-from detectron2.engine import AutogradProfiler, DefaultTrainer, default_argument_parser, default_setup, launch
+from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
 from detectron2.evaluation import COCOEvaluator, verify_results
 
 from detectron2.solver.build import maybe_add_gradient_clipping
@@ -32,37 +33,6 @@ class Trainer(DefaultTrainer):
     Extension of the Trainer class adapted to DETR.
     """
 
-    def __init__(self, cfg):
-        """
-        Args:
-            cfg (CfgNode):
-        """
-        self.clip_norm_val = 0.0
-        if cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
-            if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
-                self.clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
-        super().__init__(cfg)
-
-    def run_step(self):
-        assert self.model.training, "[Trainer] model was changed to eval mode!"
-        start = time.perf_counter()
-        data = next(self._data_loader_iter)
-        data_time = time.perf_counter() - start
-
-        loss_dict = self.model(data)
-        losses = sum(loss_dict.values())
-        self._detect_anomaly(losses, loss_dict)
-
-        metrics_dict = loss_dict
-        metrics_dict["data_time"] = data_time
-        self._write_metrics(metrics_dict)
-
-        self.optimizer.zero_grad()
-        losses.backward()
-        if self.clip_norm_val > 0.0:
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm_val)
-        self.optimizer.step()
-
     @classmethod
     def build_evaluator(cls, cfg, dataset_name, output_folder=None):
         """
@@ -100,11 +70,32 @@ def build_optimizer(cls, cfg, model):
                 lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER
             params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
 
+        def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class
+            # detectron2 doesn't have full model gradient clipping now
+            clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
+            enable = (
+                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
+                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
+                and clip_norm_val > 0.0
+            )
+
+            class FullModelGradientClippingOptimizer(optim):
+                def step(self, closure=None):
+                    all_params = itertools.chain(*[x["params"] for x in self.param_groups])
+                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
+                    super().step(closure=closure)
+
+            return FullModelGradientClippingOptimizer if enable else optim
+
         optimizer_type = cfg.SOLVER.OPTIMIZER
         if optimizer_type == "SGD":
-            optimizer = torch.optim.SGD(params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM)
+            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
+                params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
+            )
         elif optimizer_type == "ADAMW":
-            optimizer = torch.optim.AdamW(params, cfg.SOLVER.BASE_LR)
+            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
+                params, cfg.SOLVER.BASE_LR
+            )
         else:
             raise NotImplementedError(f"no optimizer type {optimizer_type}")
         if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
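For clarity, a self-contained sketch of the wrapper pattern that maybe_add_full_model_gradient_clipping applies: subclass the optimizer class and clip the global gradient norm inside step(). The class name, model, learning rate, and clip value below are illustrative stand-ins (the commit builds the subclass dynamically for whichever optimizer cfg.SOLVER.OPTIMIZER selects):

# Illustrative sketch, not code from the commit.
import itertools
import torch

clip_norm_val = 0.01  # would come from cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE

class FullModelGradientClippingAdamW(torch.optim.AdamW):
    def step(self, closure=None):
        # Chain every parameter from every param group and clip their joint norm.
        all_params = itertools.chain(*[g["params"] for g in self.param_groups])
        torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
        super().step(closure=closure)

model = torch.nn.Linear(10, 2)  # hypothetical stand-in model
opt = FullModelGradientClippingAdamW(model.parameters(), lr=1e-4)

loss = model(torch.randn(4, 10)).sum()
opt.zero_grad()
loss.backward()
opt.step()  # clipping happens inside step(), so no run_step override is needed

Because clipping now happens inside step(), DefaultTrainer's stock run_step can be used unchanged, which is why the __init__ and run_step overrides were deleted. For clip types other than "full_model", the branch truncated at the end of the hunk presumably falls back to detectron2's maybe_add_gradient_clipping, which the file imports near the top.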
