
Commit

Merge pull request #135 from simonJJJ/dev_frozen_bn
fix compatibility of checkpointing and freeze-resnet
logicwong authored Jun 21, 2022
2 parents d217524 + 5c80a8f commit a742ad2
Showing 2 changed files with 85 additions and 10 deletions.
80 changes: 80 additions & 0 deletions models/ofa/frozen_bn.py
@@ -0,0 +1,80 @@
# Modified from detectron2: https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py#L13
import torch
from torch import nn
from torch.nn import functional as F


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    It contains non-trainable buffers called
    "weight", "bias", "running_mean", and "running_var",
    initialized to perform identity transformation.
    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.
    Other pre-trained backbone models may contain all 4 parameters.
    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions;
            # this silences the missing-key warnings.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
15 changes: 5 additions & 10 deletions models/ofa/unify_transformer.py
@@ -34,6 +34,7 @@

 from .unify_transformer_layer import TransformerEncoderLayer, TransformerDecoderLayer
 from .resnet import ResNet
+from .frozen_bn import FrozenBatchNorm2d
 
 
 DEFAULT_MAX_SOURCE_POSITIONS = 1024
@@ -427,7 +428,10 @@ def __init__(self, args, dictionary, embed_tokens):
if getattr(args, "sync_bn", False):
norm_layer = BatchNorm2d
else:
norm_layer = None
if getattr(args, "freeze_resnet", False):
norm_layer = FrozenBatchNorm2d
else:
norm_layer = None

if args.resnet_type == 'resnet101':
self.embed_images = ResNet([3, 4, 23], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate)
@@ -498,15 +502,6 @@ def __init__(self, args, dictionary, embed_tokens):
self.register_buffer("image_rp_bucket", image_rp_bucket)
self.entangle_position_embedding = args.entangle_position_embedding

def train(self, mode=True):
super(TransformerEncoder, self).train(mode)
if getattr(self.args, "freeze_resnet", False):
for m in self.embed_images.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
m.weight.requires_grad = False
m.bias.requires_grad = False

def build_encoder_layer(self, args, drop_path_rate=0.0):
layer = TransformerEncoderLayer(args, drop_path_rate=drop_path_rate)
checkpoint = getattr(args, "checkpoint_activations", False)

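Not part of the commit: with this change, freezing is expressed through the norm layer handed to the ResNet constructor instead of the removed train()-time override; per the commit message, that is what restores compatibility with activation checkpointing. Below is a standalone sketch of the new selection branch, with assumed import paths, torch.nn.BatchNorm2d standing in for whatever BatchNorm2d refers to in the real file, and an args namespace carrying the flags used in the diff.

from argparse import Namespace

from torch.nn import BatchNorm2d

from models.ofa.frozen_bn import FrozenBatchNorm2d
from models.ofa.resnet import ResNet


def pick_norm_layer(args):
    """Mirror the branch added to TransformerEncoder.__init__ in the hunk above."""
    if getattr(args, "sync_bn", False):
        return BatchNorm2d
    if getattr(args, "freeze_resnet", False):
        return FrozenBatchNorm2d
    return None  # presumably ResNet then falls back to its own default norm layer


args = Namespace(sync_bn=False, freeze_resnet=True,
                 resnet_type='resnet101', resnet_drop_path_rate=0.0)
embed_images = ResNet([3, 4, 23], norm_layer=pick_norm_layer(args),
                      drop_path_rate=args.resnet_drop_path_rate)
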