Skip to content

Commit ab66c93

Browse files
committed
Revert the expected value and use original eps=0 value for flaky tests.
1 parent 2f92072 commit ab66c93

File tree

4 files changed

+19
-2
lines changed

4 files changed

+19
-2
lines changed
-1.09 KB
Binary file not shown.
0 Bytes
Binary file not shown.

test/test_models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,19 @@ def _test_detection_model(self, name, dev):
148148
kwargs = {}
149149
if "retinanet" in name:
150150
kwargs["score_thresh"] = 0.013
151-
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
151+
152+
# Workaround for flaky tests
153+
from torchvision.ops.misc import FrozenBatchNorm2d
154+
class OverwriteEPS:
155+
def __enter__(self):
156+
self.default = FrozenBatchNorm2d._DEFAULT_EPS
157+
FrozenBatchNorm2d._DEFAULT_EPS = 0.0
158+
159+
def __exit__(self, type, value, traceback):
160+
FrozenBatchNorm2d._DEFAULT_EPS = self.default
161+
162+
with OverwriteEPS():
163+
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
152164
model.eval().to(device=dev)
153165
input_shape = (3, 300, 300)
154166
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests

torchvision/ops/misc.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,18 @@ class FrozenBatchNorm2d(torch.nn.Module):
4747
BatchNorm2d where the batch statistics and the affine parameters
4848
are fixed
4949
"""
50+
_DEFAULT_EPS = 1e-5
5051

5152
def __init__(
5253
self,
5354
num_features: int,
54-
eps: float = 1e-5,
55+
eps: float = None,
5556
n: Optional[int] = None,
5657
):
58+
# eps=None for unit-test stability
59+
if eps is None:
60+
eps = FrozenBatchNorm2d._DEFAULT_EPS
61+
5762
# n=None for backward-compatibility
5863
if n is not None:
5964
warnings.warn("`n` argument is deprecated and has been renamed `num_features`",

0 commit comments

Comments
 (0)