Skip to content

Commit

Permalink
feature transform regularizer fix
Browse files Browse the repository at this point in the history
  • Loading branch information
fxia22 committed Apr 16, 2019
1 parent 9890413 commit 8e72856
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pointnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def feature_transform_reguliarzer(trans):
I = torch.eye(d)[None, :, :]
if trans.is_cuda:
I = I.cuda()
loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1) - I), dim=(1,2)))
loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
return loss

if __name__ == '__main__':
Expand All @@ -195,7 +195,7 @@ def feature_transform_reguliarzer(trans):
out = trans(sim_data_64d)
print('stn64d', out.size())
print('loss', feature_transform_reguliarzer(out))

pointfeat = PointNetfeat(global_feat=True)
out, _, _ = pointfeat(sim_data)
print('global feat', out.size())
Expand Down

0 comments on commit 8e72856

Please sign in to comment.