19
19
20
20
21
21
class STN3d (nn .Module ):
22
- def __init__ (self , num_points = 2500 ):
22
+ def __init__ (self ):
23
23
super (STN3d , self ).__init__ ()
24
- self .num_points = num_points
25
24
self .conv1 = torch .nn .Conv1d (3 , 64 , 1 )
26
25
self .conv2 = torch .nn .Conv1d (64 , 128 , 1 )
27
26
self .conv3 = torch .nn .Conv1d (128 , 1024 , 1 )
28
- self .mp1 = torch .nn .MaxPool1d (num_points )
29
27
self .fc1 = nn .Linear (1024 , 512 )
30
28
self .fc2 = nn .Linear (512 , 256 )
31
29
self .fc3 = nn .Linear (256 , 9 )
@@ -43,7 +41,7 @@ def forward(self, x):
43
41
x = F .relu (self .bn1 (self .conv1 (x )))
44
42
x = F .relu (self .bn2 (self .conv2 (x )))
45
43
x = F .relu (self .bn3 (self .conv3 (x )))
46
- x = self . mp1 ( x )
44
+ x = torch . max ( x , 2 , keepdim = True )[ 0 ]
47
45
x = x .view (- 1 , 1024 )
48
46
49
47
x = F .relu (self .bn4 (self .fc1 (x )))
@@ -59,20 +57,19 @@ def forward(self, x):
59
57
60
58
61
59
class PointNetfeat (nn .Module ):
62
- def __init__ (self , num_points = 2500 , global_feat = True ):
60
+ def __init__ (self , global_feat = True ):
63
61
super (PointNetfeat , self ).__init__ ()
64
- self .stn = STN3d (num_points = num_points )
62
+ self .stn = STN3d ()
65
63
self .conv1 = torch .nn .Conv1d (3 , 64 , 1 )
66
64
self .conv2 = torch .nn .Conv1d (64 , 128 , 1 )
67
65
self .conv3 = torch .nn .Conv1d (128 , 1024 , 1 )
68
66
self .bn1 = nn .BatchNorm1d (64 )
69
67
self .bn2 = nn .BatchNorm1d (128 )
70
68
self .bn3 = nn .BatchNorm1d (1024 )
71
- self .mp1 = torch .nn .MaxPool1d (num_points )
72
- self .num_points = num_points
73
69
self .global_feat = global_feat
74
70
def forward (self , x ):
75
71
batchsize = x .size ()[0 ]
72
+ n_pts = x .size ()[2 ]
76
73
trans = self .stn (x )
77
74
x = x .transpose (2 ,1 )
78
75
x = torch .bmm (x , trans )
@@ -81,19 +78,18 @@ def forward(self, x):
81
78
pointfeat = x
82
79
x = F .relu (self .bn2 (self .conv2 (x )))
83
80
x = self .bn3 (self .conv3 (x ))
84
- x = self . mp1 ( x )
81
+ x = torch . max ( x , 2 , keepdim = True )[ 0 ]
85
82
x = x .view (- 1 , 1024 )
86
83
if self .global_feat :
87
84
return x , trans
88
85
else :
89
- x = x .view (- 1 , 1024 , 1 ).repeat (1 , 1 , self . num_points )
86
+ x = x .view (- 1 , 1024 , 1 ).repeat (1 , 1 , n_pts )
90
87
return torch .cat ([x , pointfeat ], 1 ), trans
91
88
92
89
class PointNetCls (nn .Module ):
93
- def __init__ (self , num_points = 2500 , k = 2 ):
90
+ def __init__ (self , k = 2 ):
94
91
super (PointNetCls , self ).__init__ ()
95
- self .num_points = num_points
96
- self .feat = PointNetfeat (num_points , global_feat = True )
92
+ self .feat = PointNetfeat (global_feat = True )
97
93
self .fc1 = nn .Linear (1024 , 512 )
98
94
self .fc2 = nn .Linear (512 , 256 )
99
95
self .fc3 = nn .Linear (256 , k )
@@ -105,14 +101,13 @@ def forward(self, x):
105
101
x = F .relu (self .bn1 (self .fc1 (x )))
106
102
x = F .relu (self .bn2 (self .fc2 (x )))
107
103
x = self .fc3 (x )
108
- return F .log_softmax (x , dim = - 1 ), trans
104
+ return F .log_softmax (x , dim = 0 ), trans
109
105
110
106
class PointNetDenseCls (nn .Module ):
111
- def __init__ (self , num_points = 2500 , k = 2 ):
107
+ def __init__ (self , k = 2 ):
112
108
super (PointNetDenseCls , self ).__init__ ()
113
- self .num_points = num_points
114
109
self .k = k
115
- self .feat = PointNetfeat (num_points , global_feat = False )
110
+ self .feat = PointNetfeat (global_feat = False )
116
111
self .conv1 = torch .nn .Conv1d (1088 , 512 , 1 )
117
112
self .conv2 = torch .nn .Conv1d (512 , 256 , 1 )
118
113
self .conv3 = torch .nn .Conv1d (256 , 128 , 1 )
@@ -123,14 +118,15 @@ def __init__(self, num_points = 2500, k = 2):
123
118
124
119
def forward (self , x ):
125
120
batchsize = x .size ()[0 ]
121
+ n_pts = x .size ()[2 ]
126
122
x , trans = self .feat (x )
127
123
x = F .relu (self .bn1 (self .conv1 (x )))
128
124
x = F .relu (self .bn2 (self .conv2 (x )))
129
125
x = F .relu (self .bn3 (self .conv3 (x )))
130
126
x = self .conv4 (x )
131
127
x = x .transpose (2 ,1 ).contiguous ()
132
128
x = F .log_softmax (x .view (- 1 ,self .k ), dim = - 1 )
133
- x = x .view (batchsize , self . num_points , self .k )
129
+ x = x .view (batchsize , n_pts , self .k )
134
130
return x , trans
135
131
136
132
0 commit comments