-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathrefnet.py
118 lines (96 loc) · 4.44 KB
/
refnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torch.nn as nn
import numpy as np
import sys
import os
sys.path.append(os.path.join(os.getcwd(), "lib")) # HACK add the lib folder
from models.backbone_module import Pointnet2Backbone
from models.voting_module import VotingModule
from models.proposal_module import ProposalModule
from models.lang_module import LangModule
from models.match_module import MatchModule
class RefNet(nn.Module):
    """3D detection network with an optional language-matching head.

    The detection branch runs a point-cloud backbone, a Hough voting module
    and a proposal module. Unless ``no_reference`` is set, a language module
    additionally encodes the description and a matching module relates the
    generated proposals to it.
    """

    def __init__(self, num_class, num_heading_bin, num_size_cluster, mean_size_arr,
                 input_feature_dim=0, num_proposal=128, vote_factor=1, sampling="vote_fps",
                 use_lang_classifier=True, use_bidir=False, no_reference=False,
                 emb_size=300, hidden_size=256):
        """Build the detection submodules and, optionally, the language/matching ones.

        Args:
            num_class: number of classes; passed to the proposal and language modules.
            num_heading_bin: passed to the proposal module.
            num_size_cluster: passed to the proposal module; must equal
                ``mean_size_arr.shape[0]``.
            mean_size_arr: array with one row per size cluster (presumably the
                mean box size of each cluster); passed to the proposal module.
            input_feature_dim: extra per-point feature channels for the backbone.
            num_proposal: number of proposals; passed to the proposal and
                matching modules.
            vote_factor: passed to the voting module.
            sampling: sampling strategy name passed to the proposal module.
            use_lang_classifier: passed to the language module.
            use_bidir: passed to the language module; when True the matching
                module expects a language feature of size ``2 * hidden_size``.
            no_reference: when True, the language and matching modules are not
                built and ``forward`` runs detection only.
            emb_size: embedding size passed to the language module.
            hidden_size: hidden size passed to the language module.

        Raises:
            ValueError: if ``mean_size_arr.shape[0] != num_size_cluster``.
        """
        super().__init__()
        self.num_class = num_class
        self.num_heading_bin = num_heading_bin
        self.num_size_cluster = num_size_cluster
        self.mean_size_arr = mean_size_arr
        # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
        if mean_size_arr.shape[0] != self.num_size_cluster:
            raise ValueError(
                f"mean_size_arr has {mean_size_arr.shape[0]} rows but "
                f"num_size_cluster is {self.num_size_cluster}"
            )
        self.input_feature_dim = input_feature_dim
        self.num_proposal = num_proposal
        self.vote_factor = vote_factor
        self.sampling = sampling
        self.use_lang_classifier = use_lang_classifier
        self.use_bidir = use_bidir
        self.no_reference = no_reference

        # --------- PROPOSAL GENERATION ---------
        # Backbone point feature learning
        self.backbone_net = Pointnet2Backbone(input_feature_dim=self.input_feature_dim)

        # Hough voting (256 = seed feature size expected by the voting module)
        self.vgen = VotingModule(self.vote_factor, 256)

        # Vote aggregation and object proposal
        self.proposal = ProposalModule(num_class, num_heading_bin, num_size_cluster,
                                       mean_size_arr, num_proposal, sampling)

        if not no_reference:
            # --------- LANGUAGE ENCODING ---------
            # Encode the input descriptions into vectors
            # (including attention and language classification)
            self.lang = LangModule(num_class, use_lang_classifier, use_bidir,
                                   emb_size, hidden_size)

            # --------- PROPOSAL MATCHING ---------
            # A bidirectional encoder concatenates both directions, doubling
            # the language feature size the matcher receives.
            self.match = MatchModule(num_proposals=num_proposal,
                                     lang_size=(1 + int(self.use_bidir)) * hidden_size)

    def forward(self, data_dict):
        """Run detection and, unless ``no_reference`` is set, language matching.

        Args:
            data_dict: dict threaded through every stage. Per the original
                docs it carries ``point_clouds`` — a (B, N, 3 + input_feature_dim)
                tensor whose points are laid out as (x, y, z, features...) —
                and ``lang_feat`` for the language branch. The dict returned by
                the backbone must contain ``fp2_xyz``, ``fp2_features`` and
                ``fp2_inds``.

        Returns:
            dict: the dictionary returned by the last stage. It additionally
            holds ``seed_inds``, ``seed_xyz``, ``seed_features``, ``vote_xyz``
            and ``vote_features``, plus whatever the proposal (and, if enabled,
            language and matching) modules add.
        """
        # --------- BACKBONE ---------
        data_dict = self.backbone_net(data_dict)

        # --------- HOUGH VOTING ---------
        xyz = data_dict["fp2_xyz"]
        features = data_dict["fp2_features"]
        data_dict["seed_inds"] = data_dict["fp2_inds"]
        data_dict["seed_xyz"] = xyz
        data_dict["seed_features"] = features

        xyz, features = self.vgen(xyz, features)
        # L2-normalise the vote features along dim 1. ``normalize`` divides by
        # max(norm, eps), so an all-zero vector stays zero instead of becoming
        # NaN, as a plain division by its norm would.
        features = torch.nn.functional.normalize(features, p=2, dim=1)
        data_dict["vote_xyz"] = xyz
        data_dict["vote_features"] = features

        # --------- PROPOSAL GENERATION ---------
        data_dict = self.proposal(xyz, features, data_dict)

        if self.no_reference:
            # Detection-only model: the language/matching modules were never built.
            return data_dict

        # --------- LANGUAGE ENCODING ---------
        data_dict = self.lang(data_dict)

        # --------- PROPOSAL MATCHING ---------
        data_dict = self.match(data_dict)

        return data_dict