Merge pull request #135 from simonJJJ/dev_frozen_bn
fix compatibility of checkpointing and freeze-resnet
Showing 2 changed files with 85 additions and 10 deletions.
@@ -0,0 +1,80 @@
# Modified from detectron2: https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py#L13
import torch
from torch import nn
from torch.nn import functional as F


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.
    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.
    Other pre-trained backbone models may contain all 4 parameters.
    The forward is implemented by `F.batch_norm(..., training=False)`.
    """
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)
    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions.
            # This will silence the warnings.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )
    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)