proposal_layer.py
#!/usr/bin/env python
# Copyright 2018 houjingyong@gmail.com
# MIT Licence
from __future__ import absolute_import

import torch
import torch.nn as nn

from config import cfg
from generate_anchors import AnchorGenerator
from bbox_transform import bbox_transform_inv, clip_boxes

class ProposalLayer(nn.Module):
    """
    Outputs object detection proposals by applying estimated bounding-box
    transformations to a set of regular boxes (called "anchors").
    """
    def __init__(self, num_anchors_per_frame, min_box_size, max_box_size):
        super(ProposalLayer, self).__init__()
        self.anchor_generator = AnchorGenerator(num_anchors_per_frame, min_box_size, max_box_size)
        self.num_anchors_per_frame = num_anchors_per_frame

    def forward(self, bbox_deltas):
        # bbox_deltas is expected to carry num_anchors_per_frame * feature_len
        # anchor slots along dim 1, each with two regression targets; the first
        # dimension is the batch size (number of utterances).
        batch_size = bbox_deltas.size(0)
        # Integer division keeps feature_len an int under Python 3.
        feature_len = bbox_deltas.size(1) // self.num_anchors_per_frame
        # Anchors for a single utterance of this length.
        anchors_per_utt = self.anchor_generator.get_anchors_per_utt(feature_len)
        # Broadcast the same anchors to every utterance in the batch.
        anchors = anchors_per_utt.view(1, self.num_anchors_per_frame * feature_len, 2).expand(
            batch_size, self.num_anchors_per_frame * feature_len, 2)
        # reshape returns a new tensor, so the result must be assigned.
        bbox_deltas = bbox_deltas.reshape(batch_size, self.num_anchors_per_frame * feature_len, 2)
        # Apply the predicted deltas to the anchors to obtain the proposals.
        proposals = bbox_transform_inv(anchors, bbox_deltas)
        anchors_per_utt = anchors_per_utt.view(self.num_anchors_per_frame * feature_len, 2)
        return anchors_per_utt, proposals
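

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): it only exercises
    # calls already present above and assumes the shape implied by forward(),
    # i.e. bbox_deltas of shape (batch_size, num_anchors_per_frame * feature_len, 2).
    # The anchor count and box-size bounds below are made-up illustrative values.
    num_anchors_per_frame, feature_len, batch_size = 4, 100, 2
    layer = ProposalLayer(num_anchors_per_frame, min_box_size=10, max_box_size=100)
    bbox_deltas = torch.randn(batch_size, num_anchors_per_frame * feature_len, 2)
    anchors_per_utt, proposals = layer(bbox_deltas)
    # anchors_per_utt: (num_anchors_per_frame * feature_len, 2) anchors shared by
    # every utterance; proposals: per-utterance boxes after applying the deltas.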