Skip to content

Commit

Permalink
fixed missing file bugs, error uploaded folder
Browse files Browse the repository at this point in the history
  • Loading branch information
ma-xu committed Mar 11, 2022
1 parent ac97c8d commit b5ebdad
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 229 deletions.
185 changes: 0 additions & 185 deletions part_segmentation/data.py

This file was deleted.

File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def __init__(self, num_classes=50,points=2048, embed_dim=64, groups=1, res_expan
self.stages = len(pre_blocks)
self.class_num = num_classes
self.points = points
self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
self.embedding = ConvBNReLU1D(6, embed_dim, bias=bias, activation=activation)
assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \
"Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
self.local_grouper_list = nn.ModuleList()
Expand Down Expand Up @@ -401,14 +401,14 @@ def __init__(self, num_classes=50,points=2048, embed_dim=64, groups=1, res_expan
self.classifier = nn.Sequential(
nn.Conv1d(gmp_dim+cls_dim+de_dims[-1], 128, 1, bias=bias),
nn.BatchNorm1d(128),
self.act,
nn.Dropout(),
nn.Conv1d(128, num_classes, 1, bias=bias)
)
self.en_dims = en_dims

def forward(self, x, cls_label):
def forward(self, x, norm_plt, cls_label):
xyz = x.permute(0, 2, 1)
x = torch.cat([x,norm_plt],dim=1)
x = self.embedding(x) # B,D,N

xyz_list = [xyz] # [B, N, 3]
Expand Down Expand Up @@ -440,8 +440,8 @@ def forward(self, x, cls_label):
cls_token = self.cls_map(cls_label.unsqueeze(dim=-1)) # [b, cls_dim, 1]
x = torch.cat([x, global_context.repeat([1, 1, x.shape[-1]]), cls_token.repeat([1, 1, x.shape[-1]])], dim=1)
x = self.classifier(x)
# x = F.log_softmax(x, dim=1)
# x = x.permute(0, 2, 1)
x = F.log_softmax(x, dim=1)
x = x.permute(0, 2, 1)
return x


Expand All @@ -459,6 +459,6 @@ def pointMLP(num_classes=50, **kwargs) -> PointMLP:
norm = torch.rand(2, 3, 2048)
cls_label = torch.rand([2, 16])
print("===> testing modelD ...")
model = model31G(50)
model = pointMLP(50)
out = model(data, cls_label) # [2,2048,50]
print(out.shape)
38 changes: 0 additions & 38 deletions part_segmentation/util.py

This file was deleted.

Empty file.
Loading

0 comments on commit b5ebdad

Please sign in to comment.