-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
datapipe.py
109 lines (80 loc) · 3.51 KB
/
datapipe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# In this example, you will find data loading implementations using PyTorch
# DataPipes (https://pytorch.org/data/) across various tasks:
# (1) molecular graph data loading pipe
# (2) mesh/point cloud data loading pipe
# In particular, we make use of PyG's built-in DataPipes, e.g., for batching
# multiple PyG data objects together or for converting SMILES strings into
# molecular graph representations. We also showcase how to write your own
# DataPipe (i.e. for loading and parsing mesh data into PyG data objects).
import argparse
import os.path as osp
import time
import torch
from torchdata.datapipes.iter import FileLister, FileOpener, IterDataPipe
from torch_geometric.data import Data, download_url, extract_zip
def molecule_datapipe() -> IterDataPipe:
    """Build a DataPipe streaming molecular graphs from MoleculeNet's HIV set.

    Downloads ``HIV.csv`` (once) into ``../data`` relative to this file, then
    chains: CSV rows as dicts -> SMILES-to-graph conversion (label taken from
    the ``HIV_active`` column) -> in-memory caching of the resulting graphs.
    """
    base_url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets'
    data_dir = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
    csv_path = download_url(f'{base_url}/HIV.csv', data_dir)

    pipe = FileOpener([csv_path])
    pipe = pipe.parse_csv_as_dict().parse_smiles(target_key='HIV_active')
    # Keep parsed graph objects in memory so later epochs skip re-parsing:
    return pipe.in_memory_cache()
@torch.utils.data.functional_datapipe('read_mesh')
class MeshOpener(IterDataPipe):
    """A custom DataPipe that loads mesh files into PyG ``Data`` objects.

    Registered as the functional form ``.read_mesh()`` on upstream DataPipes.
    Expects the upstream pipe to yield file-system paths to mesh files that
    ``meshio`` can read (e.g. ModelNet ``*.off`` files).
    """
    def __init__(self, dp: IterDataPipe):
        super().__init__()
        self.dp = dp  # Upstream pipe yielding mesh file paths.

    def __iter__(self):
        import meshio  # Deferred: only needed when the pipe is iterated.

        for path in self.dp:
            # Category is the filename minus its trailing '_<index>' part.
            # rsplit keeps multi-word categories such as 'night_stand' intact;
            # the previous split('_')[0] truncated them (-> 'night').
            category = osp.basename(path).rsplit('_', 1)[0]

            mesh = meshio.read(path)
            pos = torch.from_numpy(mesh.points).to(torch.float)
            # First cell block holds the triangle faces; transpose to [3, F]:
            face = torch.from_numpy(mesh.cells[0].data).t().contiguous()

            yield Data(pos=pos, face=face, category=category)
def mesh_datapipe() -> IterDataPipe:
    """Build a DataPipe streaming ModelNet10 training meshes as k-NN graphs.

    Downloads and extracts ModelNet10 (once) into ``../data`` relative to this
    file, lists all training ``*.off`` meshes, parses them via the custom
    ``read_mesh`` pipe, caches the graphs, samples 1024 surface points per
    mesh, and connects each point to its 8 nearest neighbors.
    """
    base_url = 'http://vision.princeton.edu/projects/2014/3DShapeNets'
    data_dir = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
    zip_path = download_url(f'{base_url}/ModelNet10.zip', data_dir)

    dataset_dir = osp.join(data_dir, 'ModelNet10')
    if not osp.exists(dataset_dir):  # Extract only on first run.
        extract_zip(zip_path, dataset_dir)

    def is_train(path: str) -> bool:
        # ModelNet10 stores splits in 'train'/'test' subdirectories.
        return 'train' in path

    pipe = FileLister([dataset_dir], masks='*.off', recursive=True)
    pipe = pipe.filter(is_train).read_mesh()
    pipe = pipe.in_memory_cache()  # Cache graph instances in-memory.
    pipe = pipe.sample_points(1024)  # Use PyG transforms from here.
    return pipe.knn_graph(k=8)
# Registry mapping each supported --task name to its DataPipe factory.
DATAPIPES = {
    'molecule': molecule_datapipe,
    'mesh': mesh_datapipe,
}
if __name__ == '__main__':
    # Select which example pipeline to run from the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', default='molecule', choices=DATAPIPES.keys())
    args = parser.parse_args()

    datapipe = DATAPIPES[args.task]()

    print('Example output:')
    print(next(iter(datapipe)))

    # Shuffling + Batching support:
    datapipe = datapipe.shuffle().batch_graphs(batch_size=32)

    # The first epoch will take longer than the remaining ones, since the
    # in-memory cache is only populated during the first full pass.
    for message in ('Iterating over all data...',
                    'Iterating over all data a second time...'):
        print(message)
        start = time.perf_counter()
        for _ in datapipe:
            pass
        print(f'Done! [{time.perf_counter() - start:.2f}s]')