@@ -50,7 +50,8 @@ def __init__(self, num_points, transforms=None, train=True, download=True):
50
50
51
51
subprocess .check_call (shlex .split ("rm {}" .format (zipfile )))
52
52
53
- self .train , self .num_points = train , num_points
53
+ self .train = train
54
+ self .set_num_points (num_points )
54
55
if self .train :
55
56
self .files = _get_data_files (os .path .join (self .data_dir , "train_files.txt" ))
56
57
else :
@@ -68,7 +69,7 @@ def __init__(self, num_points, transforms=None, train=True, download=True):
68
69
self .randomize ()
69
70
70
71
def __getitem__ (self , idx ):
71
- pt_idxs = np .arange (0 , self .actual_number_of_points )
72
+ pt_idxs = np .arange (0 , self .num_points )
72
73
np .random .shuffle (pt_idxs )
73
74
74
75
current_points = self .points [idx , pt_idxs ].copy ()
@@ -84,7 +85,6 @@ def __len__(self):
84
85
85
86
def set_num_points (self , pts ):
86
87
self .num_points = pts
87
- self .actual_number_of_points = pts
88
88
89
89
def randomize (self ):
90
90
pass
@@ -103,7 +103,7 @@ def randomize(self):
103
103
d_utils .PointcloudJitter (),
104
104
]
105
105
)
106
- dset = ModelNet40Cls (16 , "./" , train = True , transforms = transforms )
106
+ dset = ModelNet40Cls (16 , train = True , transforms = transforms )
107
107
print (dset [0 ][0 ])
108
108
print (dset [0 ][1 ])
109
109
print (len (dset ))
0 commit comments