Skip to content

Commit 45e027c

Browse files
authored
[BC-breaking] Change default eps value of FrozenBN (#2933)
* Change default eps value of FrozenBN. * Update the unit-tests.` * Update the expected values. * Revert the expected value and use original eps=0 value for flaky tests. * Post init change of eps. * Styles.
1 parent 455cd57 commit 45e027c

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

test/test_models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import numpy as np
77
from torchvision import models
88
import unittest
9-
import traceback
109
import random
1110

11+
from torchvision.ops.misc import FrozenBatchNorm2d
12+
1213

1314
def set_rng_seed(seed):
1415
torch.manual_seed(seed)
@@ -149,6 +150,10 @@ def _test_detection_model(self, name, dev):
149150
if "retinanet" in name:
150151
kwargs["score_thresh"] = 0.013
151152
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
153+
if "keypointrcnn" in name or "retinanet" in name:
154+
for module in model.modules():
155+
if isinstance(module, FrozenBatchNorm2d):
156+
module.eps = 0
152157
model.eval().to(device=dev)
153158
input_shape = (3, 300, 300)
154159
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests

test/test_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,10 +623,10 @@ def test_frozenbatchnorm2d_eps(self):
623623
running_var=torch.rand(sample_size[1]),
624624
num_batches_tracked=torch.tensor(100))
625625

626-
# Check that default eps is zero for backward-compatibility
626+
# Check that default eps is equal to the one of BN
627627
fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
628628
fbn.load_state_dict(state_dict, strict=False)
629-
bn = torch.nn.BatchNorm2d(sample_size[1], eps=0).eval()
629+
bn = torch.nn.BatchNorm2d(sample_size[1]).eval()
630630
bn.load_state_dict(state_dict)
631631
# Difference is expected to fall in an acceptable range
632632
self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))

torchvision/ops/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class FrozenBatchNorm2d(torch.nn.Module):
5151
def __init__(
5252
self,
5353
num_features: int,
54-
eps: float = 0.,
54+
eps: float = 1e-5,
5555
n: Optional[int] = None,
5656
):
5757
# n=None for backward-compatibility

0 commit comments

Comments
 (0)