Skip to content

Support parallel graph conversion in read_graph_pyg (ogb/io/read_graph_pyg.py) #496

Open
@brysonwx

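Description

Proposal: convert the raw graphs returned by the `read_*_raw` readers into PyG `Data` objects with a process pool instead of a sequential loop. Suggested replacement for ogb/io/read_graph_pyg.py: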

@brysonwx
import os
import os.path as osp
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from tqdm.auto import tqdm

from ogb.io.read_graph_raw import read_csv_graph_raw, read_csv_heterograph_raw, read_binary_graph_raw, read_binary_heterograph_raw


def process_one_graph(graph, additional_node_files, additional_edge_files):
    """Build a PyG ``Data`` object from one raw homogeneous graph dict.

    ``graph`` is read for ``'num_nodes'``, ``'edge_index'``, ``'edge_feat'``,
    ``'node_feat'`` and every key named in ``additional_node_files`` /
    ``additional_edge_files``; array values go through ``torch.from_numpy``.

    Side effect: each key that is copied onto the result is deleted from
    ``graph`` (feature entries that are ``None`` are left in place).

    Returns:
        The populated ``Data`` object.
    """
    data = Data()
    data.num_nodes = graph['num_nodes']
    data.edge_index = torch.from_numpy(graph['edge_index'])
    for consumed in ('num_nodes', 'edge_index'):
        del graph[consumed]

    # Optional features: copied (and removed from ``graph``) only when present.
    for raw_key, attr_name in (('edge_feat', 'edge_attr'), ('node_feat', 'x')):
        features = graph[raw_key]
        if features is not None:
            setattr(data, attr_name, torch.from_numpy(features))
            del graph[raw_key]

    # Extra per-node keys first, then extra per-edge keys, stored under their own names.
    for extra_keys in (additional_node_files, additional_edge_files):
        for key in extra_keys:
            data[key] = torch.from_numpy(graph[key])
            del graph[key]

    return data


def process_one_heterograph(graph, additional_node_files, additional_edge_files):
    """Build a PyG ``Data`` object from one raw heterogeneous graph dict.

    Each per-type mapping in ``graph`` (edge connectivity, optional edge/node
    features, and the extra keys named in ``additional_node_files`` /
    ``additional_edge_files``) is converted value-by-value with
    ``torch.from_numpy`` and stored on the result.

    Side effect: every converted mapping is deleted from ``graph``;
    ``'num_nodes_dict'`` is copied but left in place, as are ``None`` feature
    entries.

    Returns:
        The populated ``Data`` object.
    """
    def to_tensors(mapping):
        # Same keys, each array value converted with torch.from_numpy.
        return {name: torch.from_numpy(array) for name, array in mapping.items()}

    data = Data()

    num_nodes_dict = graph['num_nodes_dict']
    data.__num_nodes__ = num_nodes_dict
    data.num_nodes_dict = num_nodes_dict

    # Edge connectivity is mandatory.
    data.edge_index_dict = to_tensors(graph['edge_index_dict'])
    del graph['edge_index_dict']

    edge_feats = graph['edge_feat_dict']
    if edge_feats is not None:
        data.edge_attr_dict = to_tensors(edge_feats)
        del graph['edge_feat_dict']

    node_feats = graph['node_feat_dict']
    if node_feats is not None:
        data.x_dict = to_tensors(node_feats)
        del graph['node_feat_dict']

    # Extra per-node-type keys first, then extra per-edge-type keys.
    for extra_keys in (additional_node_files, additional_edge_files):
        for key in extra_keys:
            data[key] = to_tensors(graph[key])
            del graph[key]

    return data


def process_graphs_in_parallel(graph_list, additional_node_files, additional_edge_files, num_workers=10, hetero_flag=False):
    """Convert raw graph dicts into PyG ``Data`` objects, optionally with a process pool.

    Args:
        graph_list: sequence of raw graph dicts (as returned by the
            ``read_*_raw`` readers); its length is used for progress/chunking.
        additional_node_files: extra node-level keys copied onto each graph.
        additional_edge_files: extra edge-level keys copied onto each graph.
        num_workers: number of worker processes. ``None`` lets
            ``ProcessPoolExecutor`` choose; a value ``<= 1`` converts
            sequentially in the current process.
        hetero_flag: use ``process_one_heterograph`` instead of
            ``process_one_graph``.

    Returns:
        list of converted graphs, in the same order as ``graph_list``.

    Raises:
        Any exception raised by the per-graph converter is re-raised here.
    """
    process_func = process_one_heterograph if hetero_flag else process_one_graph
    total = len(graph_list)

    if num_workers is not None and num_workers <= 1:
        # Sequential path: no pickling round-trip, and the converters' in-place
        # ``del graph[...]`` calls shrink ``graph_list`` in this process.
        return [
            process_func(graph, additional_node_files, additional_edge_files)
            for graph in tqdm(graph_list, desc="Processing graphs", total=total)
        ]

    # In the pool path each worker receives a pickled copy of its graph, so the
    # converters' deletions do not free memory held by ``graph_list`` here.
    worker_count = num_workers if num_workers is not None else (os.cpu_count() or 1)
    # Send several graphs per task: one task per (often tiny) graph makes the
    # inter-process overhead dominate the actual conversion work.
    chunksize = max(1, total // (worker_count * 4))

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # ``map`` yields results in input order and re-raises worker errors.
        results = executor.map(
            process_func,
            graph_list,
            repeat(additional_node_files),
            repeat(additional_edge_files),
            chunksize=chunksize,
        )
        return list(tqdm(results, desc="Processing graphs in parallel", total=total))


def read_graph_pyg(raw_dir, add_inverse_edge=False, additional_node_files=None, additional_edge_files=None, binary=False, *, num_workers=10):
    """Read raw homogeneous graphs from ``raw_dir`` and convert them to PyG objects.

    Args:
        raw_dir: directory handed to the raw reader.
        add_inverse_edge: forwarded to the raw reader.
        additional_node_files: extra node-level keys to copy onto each graph
            (default: none). Forwarded to the reader only in CSV mode.
        additional_edge_files: extra edge-level keys (same rules as above).
        binary: read with ``read_binary_graph_raw`` (npz) instead of
            ``read_csv_graph_raw`` (csv).
        num_workers: forwarded to ``process_graphs_in_parallel``.

    Returns:
        list of PyG ``Data`` objects, one per raw graph.
    """
    # ``None`` sentinels instead of list literals: a ``[]`` default is one list
    # object shared by every call.
    if additional_node_files is None:
        additional_node_files = []
    if additional_edge_files is None:
        additional_edge_files = []

    if binary:
        # npz
        graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_graph_raw(
            raw_dir, add_inverse_edge, additional_node_files=additional_node_files, additional_edge_files=additional_edge_files)

    print('Converting graphs into PyG objects...')
    print(f'The total length of graph_list is: {len(graph_list)}')

    return process_graphs_in_parallel(
        graph_list, additional_node_files, additional_edge_files,
        num_workers=num_workers, hetero_flag=False)


def read_heterograph_pyg(raw_dir, add_inverse_edge=False, additional_node_files=None, additional_edge_files=None, binary=False, *, num_workers=10):
    """Read raw heterogeneous graphs from ``raw_dir`` and convert them to PyG objects.

    Args:
        raw_dir: directory handed to the raw reader.
        add_inverse_edge: forwarded to the raw reader.
        additional_node_files: extra per-node-type keys to copy onto each graph
            (default: none). Forwarded to the reader only in CSV mode.
        additional_edge_files: extra per-edge-type keys (same rules as above).
        binary: read with ``read_binary_heterograph_raw`` (npz) instead of
            ``read_csv_heterograph_raw`` (csv).
        num_workers: forwarded to ``process_graphs_in_parallel``.

    Returns:
        list of PyG ``Data`` objects, one per raw heterograph.
    """
    # ``None`` sentinels instead of list literals: a ``[]`` default is one list
    # object shared by every call.
    if additional_node_files is None:
        additional_node_files = []
    if additional_edge_files is None:
        additional_edge_files = []

    if binary:
        # npz
        graph_list = read_binary_heterograph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_heterograph_raw(
            raw_dir, add_inverse_edge, additional_node_files=additional_node_files, additional_edge_files=additional_edge_files)

    print('Converting graphs into PyG objects...')
    print(f'The total length of graph_list is: {len(graph_list)}')

    return process_graphs_in_parallel(
        graph_list, additional_node_files, additional_edge_files,
        num_workers=num_workers, hetero_flag=True)

if __name__ == '__main__':
    # Running this file directly does nothing; it only defines conversion functions.
    ...

Activity

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions