1
+ from collections import OrderedDict
1
2
import torch
2
3
import unittest
3
4
from torchvision .models .detection .anchor_utils import AnchorGenerator
@@ -13,3 +14,42 @@ def test_incorrect_anchors(self):
13
14
image_list = ImageList (image1 , [(800 , 800 )])
14
15
feature_maps = [torch .randn (1 , 50 )]
15
16
self .assertRaises (ValueError , anc , image_list , feature_maps )
17
+
18
+ def _init_test_anchor_generator (self ):
19
+ anchor_sizes = tuple ((x ,) for x in [32 , 64 , 128 ])
20
+ aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
21
+ anchor_generator = AnchorGenerator (anchor_sizes , aspect_ratios )
22
+
23
+ return anchor_generator
24
+
25
+ def get_features (self , images ):
26
+ s0 , s1 = images .shape [- 2 :]
27
+ features = [
28
+ ('0' , torch .rand (2 , 8 , s0 // 4 , s1 // 4 )),
29
+ ('1' , torch .rand (2 , 16 , s0 // 8 , s1 // 8 )),
30
+ ('2' , torch .rand (2 , 32 , s0 // 16 , s1 // 16 )),
31
+ ]
32
+ features = OrderedDict (features )
33
+ return features
34
+
35
+ def test_anchor_generator (self ):
36
+ images = torch .randn (2 , 3 , 16 , 32 )
37
+ features = self .get_features (images )
38
+ features = list (features .values ())
39
+ image_shapes = [i .shape [- 2 :] for i in images ]
40
+ images = ImageList (images , image_shapes )
41
+
42
+ model = self ._init_test_anchor_generator ()
43
+ model .eval ()
44
+ anchors = model (images , features )
45
+
46
+ # Compute target anchors numbers
47
+ grid_sizes = [f .shape [- 2 :] for f in features ]
48
+ num_anchors_estimated = 0
49
+ for sizes , num_anchors_per_loc in zip (grid_sizes , model .num_anchors_per_location ()):
50
+ num_anchors_estimated += sizes [0 ] * sizes [1 ] * num_anchors_per_loc
51
+
52
+ self .assertEqual (num_anchors_estimated , 126 )
53
+ self .assertEqual (len (anchors ), 2 )
54
+ self .assertEqual (tuple (anchors [0 ].shape ), (num_anchors_estimated , 4 ))
55
+ self .assertEqual (tuple (anchors [1 ].shape ), (num_anchors_estimated , 4 ))
0 commit comments