Open
Description
import os
import os.path as osp
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from tqdm.auto import tqdm

from ogb.io.read_graph_raw import read_csv_graph_raw, read_csv_heterograph_raw, read_binary_graph_raw, read_binary_heterograph_raw
def process_one_graph(graph, additional_node_files, additional_edge_files):
    """Build a PyG ``Data`` object from one raw homogeneous graph dict.

    Every entry that gets converted is deleted from ``graph`` right after use:
    ``num_nodes`` and ``edge_index`` always, ``edge_feat`` / ``node_feat`` only
    when they are not ``None``, and each requested additional key. Arrays are
    wrapped with ``torch.from_numpy``, so the resulting tensors share memory
    with the source numpy arrays.

    Args:
        graph: Raw graph dict; must contain ``num_nodes``, ``edge_index``,
            ``edge_feat`` and ``node_feat`` (the feature entries may be ``None``)
            plus every key named in the two additional-key lists.
        additional_node_files: Extra node-level keys copied onto the result.
        additional_edge_files: Extra edge-level keys copied onto the result.

    Returns:
        The populated ``Data`` instance.
    """
    data = Data()
    data.num_nodes = graph['num_nodes']
    data.edge_index = torch.from_numpy(graph['edge_index'])
    for consumed in ('num_nodes', 'edge_index'):
        del graph[consumed]

    edge_feat = graph['edge_feat']
    if edge_feat is not None:
        data.edge_attr = torch.from_numpy(edge_feat)
        del graph['edge_feat']

    node_feat = graph['node_feat']
    if node_feat is not None:
        data.x = torch.from_numpy(node_feat)
        del graph['node_feat']

    # Node-level extras are attached before edge-level extras, one key at a time.
    for key_group in (additional_node_files, additional_edge_files):
        for extra_key in key_group:
            data[extra_key] = torch.from_numpy(graph[extra_key])
            del graph[extra_key]

    return data
def process_one_heterograph(graph, additional_node_files, additional_edge_files):
    """Build a PyG ``Data`` object from one raw heterogeneous graph dict.

    Per-type mappings (keyed by node type or edge triplet) are converted
    entry by entry with ``torch.from_numpy`` and stored as dict attributes.
    Each converted mapping is deleted from ``graph`` afterwards:
    ``edge_index_dict`` always, ``edge_feat_dict`` / ``node_feat_dict`` only
    when they are not ``None``, and each requested additional key.
    ``num_nodes_dict`` is stored on the result but left in ``graph``.

    Args:
        graph: Raw heterograph dict; must contain ``num_nodes_dict``,
            ``edge_index_dict``, ``edge_feat_dict`` and ``node_feat_dict`` (the
            feature entries may be ``None``) plus every additional key.
        additional_node_files: Extra per-node-type keys copied onto the result.
        additional_edge_files: Extra per-edge-triplet keys copied onto the result.

    Returns:
        The populated ``Data`` instance.
    """
    def tensorize(array_by_type):
        # Same keys, each numpy array wrapped (zero-copy) as a torch tensor.
        return {type_key: torch.from_numpy(array) for type_key, array in array_by_type.items()}

    hetero = Data()
    num_nodes_by_type = graph['num_nodes_dict']
    hetero.__num_nodes__ = num_nodes_by_type
    hetero.num_nodes_dict = num_nodes_by_type

    # Edge connectivity per (src_type, relation, dst_type) triplet.
    hetero.edge_index_dict = tensorize(graph['edge_index_dict'])
    del graph['edge_index_dict']

    edge_feats = graph['edge_feat_dict']
    if edge_feats is not None:
        hetero.edge_attr_dict = tensorize(edge_feats)
        del graph['edge_feat_dict']

    node_feats = graph['node_feat_dict']
    if node_feats is not None:
        hetero.x_dict = tensorize(node_feats)
        del graph['node_feat_dict']

    # Node-level extras are attached before edge-level extras, one key at a time.
    for key_group in (additional_node_files, additional_edge_files):
        for extra_key in key_group:
            hetero[extra_key] = tensorize(graph[extra_key])
            del graph[extra_key]

    return hetero
def process_graphs_in_parallel(graph_list, additional_node_files, additional_edge_files, num_workers=10, hetero_flag=False, chunksize=None):
    """Convert raw OGB graph dicts into PyG ``Data`` objects, optionally in worker processes.

    Args:
        graph_list: Iterable of raw graph dicts produced by the
            ``ogb.io.read_graph_raw`` readers.
        additional_node_files: Extra node-level keys to attach to each graph.
        additional_edge_files: Extra edge-level keys to attach to each graph.
        num_workers: Number of worker processes.
            ``0`` converts in the current process. The per-graph work is only
            ``torch.from_numpy`` wrapping, which is often cheaper than pickling
            each graph to a worker and the result back. It also avoids holding
            both the raw arrays and copied tensors in memory.
            ``None`` lets ``ProcessPoolExecutor`` choose its default.
            Negative values are rejected by ``ProcessPoolExecutor``
            (``ValueError``).
        hetero_flag: Use ``process_one_heterograph`` when True, otherwise
            ``process_one_graph``.
        chunksize: Graphs sent to a worker per task. ``None`` picks a size that
            gives each worker about four chunks. This avoids one IPC round trip
            per graph.

    Returns:
        List of converted graphs in the same order as ``graph_list``.

    Note:
        The per-graph converters delete the keys they consume from each dict.
        With ``num_workers=0`` this mutates the caller's dicts. With worker
        processes only the pickled copies are modified.
    """
    # Materialize once so any iterable works and the length is known up front.
    graph_list = list(graph_list)
    if not graph_list:
        # Nothing to convert: don't spin up a process pool.
        return []

    process_func = process_one_heterograph if hetero_flag else process_one_graph

    if num_workers == 0:
        return [
            process_func(graph, additional_node_files, additional_edge_files)
            for graph in tqdm(graph_list, desc="Processing graphs")
        ]

    if chunksize is None:
        # Mirror ProcessPoolExecutor's default worker count when num_workers is None.
        worker_count = num_workers if num_workers is not None else (os.cpu_count() or 1)
        chunksize = max(1, len(graph_list) // (4 * max(1, worker_count)))

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Executor.map preserves input order and re-raises worker exceptions on iteration.
        results = executor.map(
            process_func,
            graph_list,
            repeat(additional_node_files),
            repeat(additional_edge_files),
            chunksize=chunksize,
        )
        return list(tqdm(results, desc="Processing graphs in parallel", total=len(graph_list)))
def read_graph_pyg(raw_dir, add_inverse_edge=False, additional_node_files=None, additional_edge_files=None, binary=False, num_workers=10):
    """Read a homogeneous OGB dataset from ``raw_dir`` and convert it to PyG graphs.

    Args:
        raw_dir: Directory holding the raw dataset. It is an npz archive when
            ``binary`` is True and CSV files otherwise.
        add_inverse_edge: Forwarded to the raw reader.
        additional_node_files: Extra node-level keys to load and attach.
            ``None`` (the default) means none. They are passed to the CSV
            reader; the binary reader is called without them.
        additional_edge_files: Extra edge-level keys; handled like
            ``additional_node_files``.
        binary: Read the binary (npz) format instead of CSV.
        num_workers: Number of worker processes for the conversion, forwarded
            to ``process_graphs_in_parallel``.

    Returns:
        List of PyG ``Data`` objects, one per raw graph, in reader order.
    """
    # None sentinels instead of shared mutable [] defaults.
    if additional_node_files is None:
        additional_node_files = []
    if additional_edge_files is None:
        additional_edge_files = []

    if binary:
        # npz
        graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_graph_raw(
            raw_dir, add_inverse_edge,
            additional_node_files=additional_node_files,
            additional_edge_files=additional_edge_files)

    print('Converting graphs into PyG objects...')
    print(f'The total length of graph_list is: {len(graph_list)}')
    return process_graphs_in_parallel(
        graph_list, additional_node_files, additional_edge_files,
        num_workers=num_workers, hetero_flag=False)
def read_heterograph_pyg(raw_dir, add_inverse_edge=False, additional_node_files=None, additional_edge_files=None, binary=False, num_workers=10):
    """Read a heterogeneous OGB dataset from ``raw_dir`` and convert it to PyG graphs.

    Args:
        raw_dir: Directory holding the raw dataset. It is an npz archive when
            ``binary`` is True and CSV files otherwise.
        add_inverse_edge: Forwarded to the raw reader.
        additional_node_files: Extra per-node-type keys to load and attach.
            ``None`` (the default) means none. They are passed to the CSV
            reader; the binary reader is called without them.
        additional_edge_files: Extra per-edge-triplet keys; handled like
            ``additional_node_files``.
        binary: Read the binary (npz) format instead of CSV.
        num_workers: Number of worker processes for the conversion, forwarded
            to ``process_graphs_in_parallel``.

    Returns:
        List of PyG ``Data`` objects carrying per-type dict attributes, one per
        raw graph, in reader order.
    """
    # None sentinels instead of shared mutable [] defaults.
    if additional_node_files is None:
        additional_node_files = []
    if additional_edge_files is None:
        additional_edge_files = []

    if binary:
        # npz
        graph_list = read_binary_heterograph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_heterograph_raw(
            raw_dir, add_inverse_edge,
            additional_node_files=additional_node_files,
            additional_edge_files=additional_edge_files)

    print('Converting graphs into PyG objects...')
    print(f'The total length of graph_list is: {len(graph_list)}')
    return process_graphs_in_parallel(
        graph_list, additional_node_files, additional_edge_files,
        num_workers=num_workers, hetero_flag=True)
if __name__ == '__main__':
    # Running this module directly does nothing; its functions are meant to be imported.
    ...
Metadata
Assignees
Labels: No labels
Activity