@@ -46,8 +46,46 @@ def forward(self, x):
         return x


+class STNkd(nn.Module):
+    def __init__(self, k=64):
+        super(STNkd, self).__init__()
+        self.conv1 = torch.nn.Conv1d(k, 64, 1)
+        self.conv2 = torch.nn.Conv1d(64, 128, 1)
+        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
+        self.fc1 = nn.Linear(1024, 512)
+        self.fc2 = nn.Linear(512, 256)
+        self.fc3 = nn.Linear(256, k * k)
+        self.relu = nn.ReLU()
+
+        self.bn1 = nn.BatchNorm1d(64)
+        self.bn2 = nn.BatchNorm1d(128)
+        self.bn3 = nn.BatchNorm1d(1024)
+        self.bn4 = nn.BatchNorm1d(512)
+        self.bn5 = nn.BatchNorm1d(256)
+
+        self.k = k
+
+    def forward(self, x):
+        batchsize = x.size()[0]
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = F.relu(self.bn3(self.conv3(x)))
+        x = torch.max(x, 2, keepdim=True)[0]
+        x = x.view(-1, 1024)
+
+        x = F.relu(self.bn4(self.fc1(x)))
+        x = F.relu(self.bn5(self.fc2(x)))
+        x = self.fc3(x)
+
+        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat(batchsize, 1)
+        if x.is_cuda:
+            iden = iden.cuda()
+        x = x + iden
+        x = x.view(-1, self.k, self.k)
+        return x
+
 class PointNetfeat(nn.Module):
-    def __init__(self, global_feat = True):
+    def __init__(self, global_feat = True, feature_transform = False):
         super(PointNetfeat, self).__init__()
         self.stn = STN3d()
         self.conv1 = torch.nn.Conv1d(3, 64, 1)
@@ -57,7 +95,9 @@ def __init__(self, global_feat = True):
         self.bn2 = nn.BatchNorm1d(128)
         self.bn3 = nn.BatchNorm1d(1024)
         self.global_feat = global_feat
-
+        self.feature_transform = feature_transform
+        if self.feature_transform:
+            self.fstn = STNkd(k=64)

     def forward(self, x):
         n_pts = x.size()[2]
@@ -66,21 +106,31 @@ def forward(self, x):
         x = torch.bmm(x, trans)
         x = x.transpose(2, 1)
         x = F.relu(self.bn1(self.conv1(x)))
+
+        if self.feature_transform:
+            trans_feat = self.fstn(x)
+            x = x.transpose(2, 1)
+            x = torch.bmm(x, trans_feat)
+            x = x.transpose(2, 1)
+        else:
+            trans_feat = None
+
         pointfeat = x
         x = F.relu(self.bn2(self.conv2(x)))
         x = self.bn3(self.conv3(x))
         x = torch.max(x, 2, keepdim=True)[0]
         x = x.view(-1, 1024)
         if self.global_feat:
-            return x, trans
+            return x, trans, trans_feat
         else:
             x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
-            return torch.cat([x, pointfeat], 1), trans
+            return torch.cat([x, pointfeat], 1), trans, trans_feat

 class PointNetCls(nn.Module):
-    def __init__(self, k = 2):
+    def __init__(self, k = 2, feature_transform = False):
         super(PointNetCls, self).__init__()
-        self.feat = PointNetfeat(global_feat = True)
+        self.feature_transform = feature_transform
+        self.feat = PointNetfeat(global_feat = True, feature_transform = feature_transform)
         self.fc1 = nn.Linear(1024, 512)
         self.fc2 = nn.Linear(512, 256)
         self.fc3 = nn.Linear(256, k)
@@ -90,17 +140,18 @@ def __init__(self, k = 2):
         self.relu = nn.ReLU()

     def forward(self, x):
-        x, trans = self.feat(x)
+        x, trans, trans_feat = self.feat(x)
         x = F.relu(self.bn1(self.fc1(x)))
         x = F.relu(self.bn2(self.dropout(self.fc2(x))))
         x = self.fc3(x)
-        return F.log_softmax(x, dim=1), trans
+        return F.log_softmax(x, dim=1), trans, trans_feat

 class PointNetDenseCls(nn.Module):
-    def __init__(self, k = 2):
+    def __init__(self, k = 2, feature_transform = False):
         super(PointNetDenseCls, self).__init__()
         self.k = k
-        self.feat = PointNetfeat(global_feat = False)
+        self.feature_transform = feature_transform
+        self.feat = PointNetfeat(global_feat = False, feature_transform = feature_transform)
         self.conv1 = torch.nn.Conv1d(1088, 512, 1)
         self.conv2 = torch.nn.Conv1d(512, 256, 1)
         self.conv3 = torch.nn.Conv1d(256, 128, 1)
@@ -112,35 +163,50 @@ def __init__(self, k = 2):
     def forward(self, x):
         batchsize = x.size()[0]
         n_pts = x.size()[2]
-        x, trans = self.feat(x)
+        x, trans, trans_feat = self.feat(x)
         x = F.relu(self.bn1(self.conv1(x)))
         x = F.relu(self.bn2(self.conv2(x)))
         x = F.relu(self.bn3(self.conv3(x)))
         x = self.conv4(x)
         x = x.transpose(2, 1).contiguous()
         x = F.log_softmax(x.view(-1, self.k), dim=-1)
         x = x.view(batchsize, n_pts, self.k)
-        return x, trans
+        return x, trans, trans_feat

+def feature_transform_reguliarzer(trans):
+    d = trans.size()[1]
+    batchsize = trans.size()[0]
+    I = torch.eye(d)[None, :, :]
+    if trans.is_cuda:
+        I = I.cuda()
+    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))  # mean ||A A^T - I|| per batch element
+    return loss

 if __name__ == '__main__':
     sim_data = Variable(torch.rand(32, 3, 2500))
     trans = STN3d()
     out = trans(sim_data)
     print('stn', out.size())
-
+    print('loss', feature_transform_reguliarzer(out))
+
+    sim_data_64d = Variable(torch.rand(32, 64, 2500))
+    trans = STNkd(k=64)
+    out = trans(sim_data_64d)
+    print('stn64d', out.size())
+    print('loss', feature_transform_reguliarzer(out))
+
     pointfeat = PointNetfeat(global_feat=True)
-    out, _ = pointfeat(sim_data)
+    out, _, _ = pointfeat(sim_data)
     print('global feat', out.size())

     pointfeat = PointNetfeat(global_feat=False)
-    out, _ = pointfeat(sim_data)
+    out, _, _ = pointfeat(sim_data)
     print('point feat', out.size())

     cls = PointNetCls(k=5)
-    out, _ = cls(sim_data)
+    out, _, _ = cls(sim_data)
     print('class', out.size())

     seg = PointNetDenseCls(k=3)
-    out, _ = seg(sim_data)
+    out, _, _ = seg(sim_data)
     print('seg', out.size())
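
Usage sketch (not part of the diff above): one way the new feature_transform flag and the feature_transform_reguliarzer penalty might be wired into a classification training step. The pointnet.model import path, the Adam settings, the 16-class example, and the 0.001 regularization weight are illustrative assumptions, not something this change defines.

    # Minimal, hypothetical training-step wiring for the updated API.
    import torch.nn.functional as F
    import torch.optim as optim

    from pointnet.model import PointNetCls, feature_transform_reguliarzer  # assumed module path

    classifier = PointNetCls(k=16, feature_transform=True)     # 16 classes chosen only as an example
    optimizer = optim.Adam(classifier.parameters(), lr=0.001)  # assumed optimizer and learning rate

    def train_step(points, target):
        # points: (B, 3, N) point clouds, target: (B,) class labels -- assumed shapes.
        optimizer.zero_grad()
        pred, trans, trans_feat = classifier(points)   # forward now returns both transform matrices
        loss = F.nll_loss(pred, target)                # pred is already log_softmax output
        if trans_feat is not None:
            # Orthogonality penalty on the 64x64 feature transform; 0.001 weight is an assumption.
            loss = loss + 0.001 * feature_transform_reguliarzer(trans_feat)
        loss.backward()
        optimizer.step()
        return loss.item()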