Skip to content

Commit 9b237f5

Browse files
authored
Fix num points (#64)
* Fix number of points * Better fix
1 parent 3492c15 commit 9b237f5

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

pointnet2/data/ModelNet40Loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def __init__(self, num_points, transforms=None, train=True, download=True):
5050

5151
subprocess.check_call(shlex.split("rm {}".format(zipfile)))
5252

53-
self.train, self.num_points = train, num_points
53+
self.train = train
54+
self.set_num_points(num_points)
5455
if self.train:
5556
self.files = _get_data_files(os.path.join(self.data_dir, "train_files.txt"))
5657
else:
@@ -68,7 +69,7 @@ def __init__(self, num_points, transforms=None, train=True, download=True):
6869
self.randomize()
6970

7071
def __getitem__(self, idx):
71-
pt_idxs = np.arange(0, self.actual_number_of_points)
72+
pt_idxs = np.arange(0, self.num_points)
7273
np.random.shuffle(pt_idxs)
7374

7475
current_points = self.points[idx, pt_idxs].copy()
@@ -84,7 +85,6 @@ def __len__(self):
8485

8586
def set_num_points(self, pts):
8687
self.num_points = pts
87-
self.actual_number_of_points = pts
8888

8989
def randomize(self):
9090
pass
@@ -103,7 +103,7 @@ def randomize(self):
103103
d_utils.PointcloudJitter(),
104104
]
105105
)
106-
dset = ModelNet40Cls(16, "./", train=True, transforms=transforms)
106+
dset = ModelNet40Cls(16, train=True, transforms=transforms)
107107
print(dset[0][0])
108108
print(dset[0][1])
109109
print(len(dset))

0 commit comments

Comments
 (0)