-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathcreate_modelnet40_small.py
More file actions
68 lines (49 loc) · 1.93 KB
/
create_modelnet40_small.py
File metadata and controls
68 lines (49 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python
import os
import h5py
import numpy as np
# Seed NumPy's global RNG once at import time so the np.random.shuffle calls
# in main() produce the same subsample on every run of this script.
np.random.seed(123)
def main(split_size, num_classes=40):
    """Write a per-class subsample of the ModelNet40 train-minus-valid set.

    Reads every HDF5 file listed in ``train_minus_valid_files.txt``,
    concatenates their ``data`` and ``label`` datasets, keeps the first
    ``int(count * split_size)`` shuffled samples of each class, and writes
    the result to ``ply_data_trainminusval_split_{split_size}.h5``. A
    one-line ``train_minus_valid_split_{split_size}_files.txt`` pointing at
    that file is written next to it. Array shapes and the label range are
    printed at the end.

    Shuffling uses NumPy's global random state, so the selection depends on
    the seed and on how many times this function has already been called.

    Args:
        split_size: Fraction of each class's samples to keep.
        num_classes: Number of class ids to scan (``0 .. num_classes - 1``).
            Defaults to 40.

    Raises:
        ValueError: If the list file contains no HDF5 paths.
    """
    modelnet40_dir = "./data/modelnet40_ply_hdf5_2048/"
    file_list_path = os.path.join(
        modelnet40_dir, "train_minus_valid_files.txt")
    split_list_path = os.path.join(
        modelnet40_dir, f"train_minus_valid_split_{split_size}_files.txt")
    split_h5_path = os.path.join(
        modelnet40_dir, f"ply_data_trainminusval_split_{split_size}.h5")

    with open(file_list_path, "r") as f:
        # Drop blank lines (e.g. a trailing newline) so they are never
        # handed to h5py as an empty path.
        source_paths = [line.strip() for line in f if line.strip()]
    if not source_paths:
        raise ValueError(f"no HDF5 paths listed in {file_list_path}")

    data = []
    labels = []
    for source_path in source_paths:
        # The context manager closes each HDF5 handle as soon as its arrays
        # have been copied into memory by the [:] reads.
        with h5py.File(source_path, "r") as source_h5:
            data.append(source_h5["data"][:])
            labels.append(source_h5["label"][:])
    data = np.concatenate(data)
    labels = np.concatenate(labels)

    train_data = []
    train_label = []
    for class_id in range(num_classes):
        # First-axis indices of every sample labelled with this class.
        cls_inds = np.where(labels == class_id)[0]
        num_train = int(len(cls_inds) * split_size)
        # Integer-array indexing returns a copy, so shuffling it in place
        # leaves `data` untouched.
        cls_data = data[cls_inds]
        np.random.shuffle(cls_data)
        train_data.append(cls_data[:num_train])
        train_label += [class_id] * num_train
    train_data = np.concatenate(train_data)
    # Store labels as a column, one row per kept sample.
    train_label = np.array(train_label).reshape(-1, 1)

    with open(split_list_path, "w") as f:
        f.write(split_h5_path + "\n")

    with h5py.File(split_h5_path, "w") as f:
        f.create_dataset("data", data=train_data)
        f.create_dataset("label", data=train_label)

    print('data: {}'.format(data.shape))
    print('train data: {}'.format(train_data.shape))
    print('min_label: {}'.format(labels.min()))
    print('max_label: {}'.format(labels.max()))
if __name__ == "__main__":
    # Generate both subsets in a fixed order; the order matters because
    # main() draws from the shared, seeded NumPy random state.
    # NOTE(review): the 0.8 divisor presumably rescales the fractions
    # relative to the full training set - confirm against how
    # train_minus_valid_files.txt was produced.
    for fraction in (0.5 / 0.8, 0.25 / 0.8):
        main(fraction)