diff --git a/pointnet.py b/pointnet.py index 37091cd84..f6a18e4ae 100644 --- a/pointnet.py +++ b/pointnet.py @@ -101,7 +101,7 @@ def forward(self, x): x = F.relu(self.bn1(self.fc1(x))) x = F.relu(self.bn2(self.fc2(x))) x = self.fc3(x) - return F.log_softmax(x, dim=0), trans + return F.log_softmax(x, dim=1), trans class PointNetDenseCls(nn.Module): def __init__(self, k = 2):