Skip to content

Commit 0a6b08b

Browse files
committed
allow inputs with 3 channels (use_xyz)
1 parent 20af906 commit 0a6b08b

File tree

6 files changed

+94
-17
lines changed

6 files changed

+94
-17
lines changed

models/pointnet2_msg_cls.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def model_fn(model, data, epoch=0, eval=False):
4040

4141
class Pointnet2MSG(nn.Module):
4242

43-
def __init__(self, num_classes, input_channels=3):
43+
def __init__(self, num_classes, input_channels=3, use_xyz=True):
4444
super().__init__()
4545

4646
self.SA_modules = nn.ModuleList()
@@ -50,7 +50,8 @@ def __init__(self, num_classes, input_channels=3):
5050
radii=[0.1, 0.2, 0.4],
5151
nsamples=[32, 64, 128],
5252
mlps=[[input_channels, 64], [input_channels, 128],
53-
[input_channels, 128]]
53+
[input_channels, 128]],
54+
use_xyz=use_xyz
5455
)
5556
)
5657

@@ -61,11 +62,13 @@ def __init__(self, num_classes, input_channels=3):
6162
radii=[0.2, 0.4, 0.8],
6263
nsamples=[16, 32, 64],
6364
mlps=[[input_channels, 128], [input_channels, 256],
64-
[input_channels, 256]]
65+
[input_channels, 256]],
6566
)
6667
)
6768
self.SA_modules.append(
68-
PointnetSAModule(mlp=[128 + 256 + 256, 256, 512, 1024])
69+
PointnetSAModule(
70+
mlp=[128 + 256 + 256, 256, 512, 1024],
71+
)
6972
)
7073

7174
self.FC_layer = nn.Sequential(
@@ -108,3 +111,19 @@ def forward(self, xyz, points=None):
108111
loss.backward()
109112
print(loss.data[0])
110113
optimizer.step()
114+
115+
# With with use_xyz=False
116+
inputs = torch.randn(B, N, 3).cuda()
117+
labels = torch.from_numpy(np.random.randint(0, 3, size=B)).cuda()
118+
model = Pointnet2MSG(3, input_channels=3, use_xyz=False)
119+
model.cuda()
120+
121+
optimizer = optim.Adam(model.parameters(), lr=1e-2)
122+
123+
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
124+
for _ in range(20):
125+
optimizer.zero_grad()
126+
_, loss, _ = model_fn(model, (inputs, labels))
127+
loss.backward()
128+
print(loss.data[0])
129+
optimizer.step()

models/pointnet2_msg_sem.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def model_fn(model, data, epoch=0, eval=False):
3939

4040
class Pointnet2MSG(nn.Module):
4141

42-
def __init__(self, num_classes, input_channels=9):
42+
def __init__(self, num_classes, input_channels=9, use_xyz=True):
4343
super().__init__()
4444

4545
self.SA_modules = nn.ModuleList()
@@ -49,7 +49,8 @@ def __init__(self, num_classes, input_channels=9):
4949
npoint=1024,
5050
radii=[0.05, 0.1],
5151
nsamples=[16, 32],
52-
mlps=[[c_in, 16, 16, 32], [c_in, 32, 32, 64]]
52+
mlps=[[c_in, 16, 16, 32], [c_in, 32, 32, 64]],
53+
use_xyz=use_xyz
5354
)
5455
)
5556
c_out_0 = 32 + 64
@@ -60,7 +61,8 @@ def __init__(self, num_classes, input_channels=9):
6061
npoint=256,
6162
radii=[0.1, 0.2],
6263
nsamples=[16, 32],
63-
mlps=[[c_in, 64, 64, 128], [c_in, 64, 96, 128]]
64+
mlps=[[c_in, 64, 64, 128], [c_in, 64, 96, 128]],
65+
# use_xyz=use_xyz
6466
)
6567
)
6668
c_out_1 = 128 + 128
@@ -71,7 +73,8 @@ def __init__(self, num_classes, input_channels=9):
7173
npoint=64,
7274
radii=[0.2, 0.4],
7375
nsamples=[16, 32],
74-
mlps=[[c_in, 128, 196, 256], [c_in, 128, 196, 256]]
76+
mlps=[[c_in, 128, 196, 256], [c_in, 128, 196, 256]],
77+
# use_xyz=use_xyz
7578
)
7679
)
7780
c_out_2 = 256 + 256
@@ -82,14 +85,15 @@ def __init__(self, num_classes, input_channels=9):
8285
npoint=16,
8386
radii=[0.4, 0.8],
8487
nsamples=[16, 32],
85-
mlps=[[c_in, 256, 256, 512], [c_in, 256, 384, 512]]
88+
mlps=[[c_in, 256, 256, 512], [c_in, 256, 384, 512]],
89+
# use_xyz=use_xyz
8690
)
8791
)
8892
c_out_3 = 512 + 512
8993

9094
self.FP_modules = nn.ModuleList()
9195
self.FP_modules.append(
92-
PointnetFPModule(mlp=[256 + input_channels, 128, 128])
96+
PointnetFPModule(mlp=[256 + (input_channels if use_xyz else 0), 128, 128])
9397
)
9498
self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_0, 256, 256]))
9599
self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_1, 512, 512]))
@@ -143,3 +147,20 @@ def forward(self, xyz, points=None):
143147
loss.backward()
144148
print(loss.data[0])
145149
optimizer.step()
150+
151+
# with use_xyz=False
152+
inputs = torch.randn(B, N, 3).cuda()
153+
labels = torch.from_numpy(np.random.randint(0, 3,
154+
size=B * N)).view(B, N).cuda()
155+
model = Pointnet2MSG(3, input_channels=3, use_xyz=False)
156+
model.cuda()
157+
158+
optimizer = optim.Adam(model.parameters(), lr=1e-2)
159+
160+
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
161+
for _ in range(20):
162+
optimizer.zero_grad()
163+
_, loss, _ = model_fn(model, (inputs, labels))
164+
loss.backward()
165+
print(loss.data[0])
166+
optimizer.step()

models/pointnet2_ssg_cls.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def model_fn(model, data, epoch=0, eval=False):
4040

4141
class Pointnet2SSG(nn.Module):
4242

43-
def __init__(self, num_classes, input_channels=3):
43+
def __init__(self, num_classes, input_channels=3, use_xyz=True):
4444
super().__init__()
4545

4646
self.SA_modules = nn.ModuleList()
@@ -49,7 +49,8 @@ def __init__(self, num_classes, input_channels=3):
4949
npoint=512,
5050
radius=0.2,
5151
nsample=64,
52-
mlp=[input_channels, 64, 64, 128]
52+
mlp=[input_channels, 64, 64, 128],
53+
use_xyz=use_xyz
5354
)
5455
)
5556
self.SA_modules.append(
@@ -99,3 +100,19 @@ def forward(self, xyz, points=None):
99100
loss.backward()
100101
print(loss.data[0])
101102
optimizer.step()
103+
104+
# use_xyz=False
105+
inputs = torch.randn(B, N, 3).cuda()
106+
labels = torch.from_numpy(np.random.randint(0, 3, size=B)).cuda()
107+
model = Pointnet2SSG(3, input_channels=3, use_xyz=False)
108+
model.cuda()
109+
110+
optimizer = optim.Adam(model.parameters(), lr=1e-2)
111+
112+
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
113+
for _ in range(20):
114+
optimizer.zero_grad()
115+
_, loss, _ = model_fn(model, (inputs, labels))
116+
loss.backward()
117+
print(loss.data[0])
118+
optimizer.step()

models/pointnet2_ssg_sem.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def model_fn(model, data, epoch=0, eval=False):
3939

4040
class Pointnet2SSG(nn.Module):
4141

42-
def __init__(self, num_classes, input_channels=3):
42+
def __init__(self, num_classes, input_channels=3, use_xyz=True):
4343
super().__init__()
4444

4545
self.SA_modules = nn.ModuleList()
@@ -48,7 +48,8 @@ def __init__(self, num_classes, input_channels=3):
4848
npoint=1024,
4949
radius=0.1,
5050
nsample=32,
51-
mlp=[input_channels, 32, 32, 64]
51+
mlp=[input_channels, 32, 32, 64],
52+
use_xyz=use_xyz
5253
)
5354
)
5455
self.SA_modules.append(
@@ -69,7 +70,7 @@ def __init__(self, num_classes, input_channels=3):
6970

7071
self.FP_modules = nn.ModuleList()
7172
self.FP_modules.append(
72-
PointnetFPModule(mlp=[128 + input_channels, 128, 128, 128])
73+
PointnetFPModule(mlp=[128 + (input_channels if use_xyz else 0), 128, 128, 128])
7374
)
7475
self.FP_modules.append(PointnetFPModule(mlp=[256 + 64, 256, 128]))
7576
self.FP_modules.append(PointnetFPModule(mlp=[256 + 128, 256, 256]))
@@ -121,3 +122,22 @@ def forward(self, xyz, points=None):
121122
loss.backward()
122123
print(loss.data[0])
123124
optimizer.step()
125+
126+
127+
# try with use_xyz=False too
128+
inputs = torch.randn(B, N, 3).cuda()
129+
labels = torch.from_numpy(np.random.randint(0, 3,
130+
size=B * N)).view(B, N).cuda()
131+
model = Pointnet2SSG(3, input_channels=3, use_xyz=False)
132+
model.cuda()
133+
134+
optimizer = optim.Adam(model.parameters(), lr=1e-2)
135+
136+
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
137+
138+
for _ in range(20):
139+
optimizer.zero_grad()
140+
_, loss, _ = model_fn(model, (inputs, labels))
141+
loss.backward()
142+
print(loss.data[0])
143+
optimizer.step()

train_cls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def parse_args():
115115

116116
tb_log.configure('runs/{}'.format(args.run_name))
117117

118-
model = Pointnet(input_channels=3, num_classes=40)
118+
model = Pointnet(input_channels=3, num_classes=40, use_xyz=False)
119119
model.cuda()
120120
optimizer = optim.Adam(
121121
model.parameters(), lr=args.lr, weight_decay=args.weight_decay

train_sem_seg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@
104104
shuffle=True
105105
)
106106

107-
model = Pointnet(num_classes=13)
107+
model = Pointnet(num_classes=13, use_xyz=False)
108108
model.cuda()
109109
optimizer = optim.Adam(
110110
model.parameters(), lr=args.lr, weight_decay=args.weight_decay

0 commit comments

Comments
 (0)