Skip to content

Commit b18a475

Browse files
authored
Adds Anchor tests with ground-truth outputs (#2983)
* Add AnchorGenerator with ground-truth outputs * Minor fixes
1 parent 4c11218 commit b18a475

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

test/test_models_detection_anchor_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import OrderedDict
12
import torch
23
import unittest
34
from torchvision.models.detection.anchor_utils import AnchorGenerator
@@ -13,3 +14,42 @@ def test_incorrect_anchors(self):
1314
image_list = ImageList(image1, [(800, 800)])
1415
feature_maps = [torch.randn(1, 50)]
1516
self.assertRaises(ValueError, anc, image_list, feature_maps)
17+
18+
def _init_test_anchor_generator(self):
19+
anchor_sizes = tuple((x,) for x in [32, 64, 128])
20+
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
21+
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
22+
23+
return anchor_generator
24+
25+
def get_features(self, images):
26+
s0, s1 = images.shape[-2:]
27+
features = [
28+
('0', torch.rand(2, 8, s0 // 4, s1 // 4)),
29+
('1', torch.rand(2, 16, s0 // 8, s1 // 8)),
30+
('2', torch.rand(2, 32, s0 // 16, s1 // 16)),
31+
]
32+
features = OrderedDict(features)
33+
return features
34+
35+
def test_anchor_generator(self):
36+
images = torch.randn(2, 3, 16, 32)
37+
features = self.get_features(images)
38+
features = list(features.values())
39+
image_shapes = [i.shape[-2:] for i in images]
40+
images = ImageList(images, image_shapes)
41+
42+
model = self._init_test_anchor_generator()
43+
model.eval()
44+
anchors = model(images, features)
45+
46+
# Compute target anchors numbers
47+
grid_sizes = [f.shape[-2:] for f in features]
48+
num_anchors_estimated = 0
49+
for sizes, num_anchors_per_loc in zip(grid_sizes, model.num_anchors_per_location()):
50+
num_anchors_estimated += sizes[0] * sizes[1] * num_anchors_per_loc
51+
52+
self.assertEqual(num_anchors_estimated, 126)
53+
self.assertEqual(len(anchors), 2)
54+
self.assertEqual(tuple(anchors[0].shape), (num_anchors_estimated, 4))
55+
self.assertEqual(tuple(anchors[1].shape), (num_anchors_estimated, 4))

0 commit comments

Comments
 (0)