Enhancing GraphGym Documentation #7885

Merged · 12 commits · Sep 1, 2023
18 changes: 12 additions & 6 deletions examples/contrib/pgm_explainer_graph_classification.py
@@ -5,8 +5,8 @@
 import os.path as osp
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn import Linear, ReLU, Sequential
 
 import torch_geometric.transforms as T
 from torch_geometric.contrib.explain import PGMExplainer
@@ -37,15 +37,21 @@ def normalized_cut_2d(edge_index, pos):
     return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))
 
 
-class Net(nn.Module):
+class Net(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        nn1 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(),
-                            nn.Linear(25, d.num_features * 32))
+        nn1 = Sequential(
+            Linear(2, 25),
+            ReLU(),
+            Linear(25, d.num_features * 32),
+        )
         self.conv1 = NNConv(d.num_features, 32, nn1, aggr='mean')
 
-        nn2 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(),
-                            nn.Linear(25, 32 * 64))
+        nn2 = Sequential(
+            Linear(2, 25),
+            ReLU(),
+            Linear(25, 32 * 64),
+        )
         self.conv2 = NNConv(32, 64, nn2, aggr='mean')
 
         self.fc1 = torch.nn.Linear(64, 128)
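The reshaped `Sequential` blocks keep the edge-network contract intact: `NNConv` expects the edge network to emit `in_channels * out_channels` values per edge, which it reshapes into a per-edge weight matrix. A minimal, self-contained sketch with hypothetical sizes standing in for `d.num_features`:

```python
import torch
from torch.nn import Linear, ReLU, Sequential

from torch_geometric.nn import NNConv

# The last Linear must emit in_channels * out_channels features, because
# NNConv reshapes them into a [in_channels, out_channels] weight per edge.
in_channels, out_channels = 1, 32  # hypothetical sizes
nn1 = Sequential(Linear(2, 25), ReLU(), Linear(25, in_channels * out_channels))
conv = NNConv(in_channels, out_channels, nn1, aggr='mean')

x = torch.randn(4, in_channels)                 # 4 nodes
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_attr = torch.randn(edge_index.size(1), 2)  # 2-dim pseudo-coordinates
out = conv(x, edge_index, edge_attr)            # -> shape [4, 32]
```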
@@ -69,7 +69,8 @@ def run_training_proc(
         backend='nccl',  # or choose 'gloo' if 'nccl' is not supported.
         rank=current_ctx.rank,
         world_size=current_ctx.world_size,
-        init_method='tcp://{}:{}'.format(master_addr, training_pg_master_port))
+        init_method=f'tcp://{master_addr}:{training_pg_master_port}',
+    )
 
     # Create distributed neighbor loader for training.
     # We replace PyG's NeighborLoader with GLT's DistNeighborLoader.
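The f-string is a drop-in replacement for the old `str.format` call. A single-process sketch of the same `init_process_group` pattern, assuming a loopback master address and the CPU-friendly `gloo` backend:

```python
import torch.distributed as dist

# Placeholder values for a one-process group; real runs get these per rank.
master_addr, training_pg_master_port = '127.0.0.1', 29500
dist.init_process_group(
    backend='gloo',  # 'nccl' requires GPUs; 'gloo' also works on CPU.
    rank=0,
    world_size=1,
    init_method=f'tcp://{master_addr}:{training_pg_master_port}',
)
dist.destroy_process_group()
```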
9 changes: 4 additions & 5 deletions examples/infomax_inductive.py
@@ -1,7 +1,6 @@
 import os.path as osp
 
 import torch
-import torch.nn as nn
 from tqdm import tqdm
 
 from torch_geometric.datasets import Reddit
@@ -19,7 +18,7 @@
                               num_workers=12)
 
 
-class Encoder(nn.Module):
+class Encoder(torch.nn.Module):
     def __init__(self, in_channels, hidden_channels):
         super().__init__()
         self.convs = torch.nn.ModuleList([
@@ -30,9 +29,9 @@ def __init__(self, in_channels, hidden_channels):

         self.activations = torch.nn.ModuleList()
         self.activations.extend([
-            nn.PReLU(hidden_channels),
-            nn.PReLU(hidden_channels),
-            nn.PReLU(hidden_channels)
+            torch.nn.PReLU(hidden_channels),
+            torch.nn.PReLU(hidden_channels),
+            torch.nn.PReLU(hidden_channels)
         ])
 
     def forward(self, x, edge_index, batch_size):
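One detail the rename preserves: `torch.nn.PReLU(hidden_channels)` learns a separate negative slope per channel rather than one shared scalar, which is why the channel count is passed in. A two-line check:

```python
import torch

act = torch.nn.PReLU(512)  # one learnable slope per hidden channel
print(act.weight.shape)    # torch.Size([512])
```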
7 changes: 3 additions & 4 deletions examples/kuzu/papers_100M/train.py
@@ -4,7 +4,6 @@
 import kuzu
 import pandas as pd
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from tqdm import tqdm
 
@@ -60,13 +59,13 @@
 )
 
 
-class GraphSAGE(nn.Module):
+class GraphSAGE(torch.nn.Module):
     def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                  dropout=0.2):
         super().__init__()
 
-        self.convs = nn.ModuleList()
-        self.norms = nn.ModuleList()
+        self.convs = torch.nn.ModuleList()
+        self.norms = torch.nn.ModuleList()
 
         self.convs.append(SAGEConv(in_channels, hidden_channels))
         self.norms.append(BatchNorm(hidden_channels))
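For context, a minimal, self-contained sketch of the conv/norm stacking pattern used above, with hypothetical sizes and random data standing in for the Kùzu-backed loader:

```python
import torch

from torch_geometric.nn import BatchNorm, SAGEConv

# One layer of the stack: a SAGEConv followed by its BatchNorm.
convs = torch.nn.ModuleList([SAGEConv(128, 64)])
norms = torch.nn.ModuleList([BatchNorm(64)])

x = torch.randn(10, 128)                 # 10 nodes, 128 input features
edge_index = torch.randint(0, 10, (2, 40))
x = norms[0](convs[0](x, edge_index))    # -> shape [10, 64]
```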
18 changes: 12 additions & 6 deletions examples/mnist_nn_conv.py
@@ -1,8 +1,8 @@
 import os.path as osp
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn import Linear, ReLU, Sequential
 
 import torch_geometric.transforms as T
 from torch_geometric.datasets import MNISTSuperpixels
@@ -35,15 +35,21 @@ def normalized_cut_2d(edge_index, pos):
     return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))
 
 
-class Net(nn.Module):
+class Net(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        nn1 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(),
-                            nn.Linear(25, d.num_features * 32))
+        nn1 = Sequential(
+            Linear(2, 25),
+            ReLU(),
+            Linear(25, d.num_features * 32),
+        )
         self.conv1 = NNConv(d.num_features, 32, nn1, aggr='mean')
 
-        nn2 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(),
-                            nn.Linear(25, 32 * 64))
+        nn2 = Sequential(
+            Linear(2, 25),
+            ReLU(),
+            Linear(25, 32 * 64),
+        )
         self.conv2 = NNConv(32, 64, nn2, aggr='mean')
 
         self.fc1 = torch.nn.Linear(64, 128)
4 changes: 2 additions & 2 deletions torch_geometric/graphgym/config.py
@@ -452,8 +452,8 @@ def set_cfg(cfg):
 def assert_cfg(cfg):
     r"""Checks config values, do necessary post processing to the configs"""
     if cfg.dataset.task not in ['node', 'edge', 'graph', 'link_pred']:
-        raise ValueError('Task {} not supported, must be one of node, '
-                         'edge, graph, link_pred'.format(cfg.dataset.task))
+        raise ValueError(f"Task '{cfg.dataset.task}' not supported. Must be "
+                         f"one of node, edge, graph, link_pred")
     if 'classification' in cfg.dataset.task_type and cfg.model.loss_fun == \
             'mse':
         cfg.model.loss_fun = 'cross_entropy'
24 changes: 16 additions & 8 deletions torch_geometric/graphgym/contrib/layer/generalconv.py
@@ -1,5 +1,4 @@
 import torch
-import torch.nn as nn
 from torch.nn import Parameter
 
 from torch_geometric.graphgym.config import cfg
@@ -9,8 +8,7 @@


 class GeneralConvLayer(MessagePassing):
-    r"""General GNN layer
-    """
+    r"""A general GNN layer."""
     def __init__(self, in_channels, out_channels, improved=False, cached=False,
                  bias=True, **kwargs):
         super().__init__(aggr=cfg.gnn.agg, **kwargs)
@@ -128,13 +126,23 @@ def __init__(self, in_channels, out_channels, edge_dim, improved=False,
         self.msg_direction = cfg.gnn.msg_direction
 
         if self.msg_direction == 'single':
-            self.linear_msg = nn.Linear(in_channels + edge_dim, out_channels,
-                                        bias=False)
+            self.linear_msg = torch.nn.Linear(
+                in_channels + edge_dim,
+                out_channels,
+                bias=False,
+            )
         else:
-            self.linear_msg = nn.Linear(in_channels * 2 + edge_dim,
-                                        out_channels, bias=False)
+            self.linear_msg = torch.nn.Linear(
+                in_channels * 2 + edge_dim,
+                out_channels,
+                bias=False,
+            )
         if cfg.gnn.self_msg == 'concat':
-            self.linear_self = nn.Linear(in_channels, out_channels, bias=False)
+            self.linear_self = torch.nn.Linear(
+                in_channels,
+                out_channels,
+                bias=False,
+            )
 
         if bias:
             self.bias = Parameter(torch.empty(out_channels))
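The two `msg_direction` branches differ only in the width of the message input. A small sketch of the dimension bookkeeping (sizes are hypothetical, and treating the `else` branch as a "both directions" mode is an assumption from its `in_channels * 2` input):

```python
import torch

in_channels, out_channels, edge_dim = 16, 32, 8

# 'single': message is built from the source node feature plus edge feature.
single_msg = torch.nn.Linear(in_channels + edge_dim, out_channels, bias=False)
# else: source and target node features are concatenated with the edge feature.
both_msg = torch.nn.Linear(in_channels * 2 + edge_dim, out_channels, bias=False)

x_i = torch.randn(5, in_channels)     # target node features, one row per edge
x_j = torch.randn(5, in_channels)     # source node features, one row per edge
edge_attr = torch.randn(5, edge_dim)
out = both_msg(torch.cat([x_i, x_j, edge_attr], dim=-1))  # -> shape [5, 32]
```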
11 changes: 6 additions & 5 deletions torch_geometric/graphgym/init.py
@@ -1,4 +1,4 @@
-import torch.nn as nn
+import torch
 
 
 def init_weights(m):
@@ -9,11 +9,12 @@ def init_weights(m):
         m (nn.Module): PyTorch module
 
     """
-    if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+    if (isinstance(m, torch.nn.BatchNorm2d)
+            or isinstance(m, torch.nn.BatchNorm1d)):
         m.weight.data.fill_(1.0)
         m.bias.data.zero_()
-    elif isinstance(m, nn.Linear):
-        m.weight.data = nn.init.xavier_uniform_(
-            m.weight.data, gain=nn.init.calculate_gain('relu'))
+    elif isinstance(m, torch.nn.Linear):
+        m.weight.data = torch.nn.init.xavier_uniform_(
+            m.weight.data, gain=torch.nn.init.calculate_gain('relu'))
         if m.bias is not None:
             m.bias.data.zero_()
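`init_weights` is designed to be mapped over a model with `Module.apply`, which visits every submodule recursively. A short usage sketch, assuming the function above is imported from `torch_geometric.graphgym.init`:

```python
import torch

from torch_geometric.graphgym.init import init_weights

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.BatchNorm1d(16),
    torch.nn.Linear(16, 4),
)
# Linear layers get Xavier-initialized weights; BatchNorm layers get
# weight=1, bias=0, per the branches above.
model.apply(init_weights)
```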
33 changes: 19 additions & 14 deletions torch_geometric/graphgym/loader.py
@@ -1,3 +1,4 @@
+import os.path as osp
 from typing import Callable
 
 import torch
@@ -58,7 +59,7 @@ def load_pyg(name, dataset_dir):
     Returns: PyG dataset object
 
     """
-    dataset_dir = '{}/{}'.format(dataset_dir, name)
+    dataset_dir = osp.join(dataset_dir, name)
     if name in ['Cora', 'CiteSeer', 'PubMed']:
         dataset = Planetoid(dataset_dir, name)
     elif name[:3] == 'TU_':
@@ -87,7 +88,7 @@ def load_pyg(name, dataset_dir):
     elif name == 'QM7b':
         dataset = QM7b(dataset_dir)
     else:
-        raise ValueError('{} not support'.format(name))
+        raise ValueError(f"'{name}' not supported")
 
     return dataset

@@ -194,7 +195,7 @@ def load_dataset():
     elif format == 'OGB':
         dataset = load_ogb(name.replace('_', '-'), dataset_dir)
     else:
-        raise ValueError('Unknown data format: {}'.format(format))
+        raise ValueError(f"Unknown data format '{format}'")
     return dataset


@@ -290,19 +291,23 @@ def get_loader(dataset, sampler, batch_size, shuffle=True):
             pin_memory=True,
             persistent_workers=pw)
     elif sampler == "cluster":
-        loader_train = \
-            ClusterLoader(dataset[0],
-                          num_parts=cfg.train.train_parts,
-                          save_dir="{}/{}".format(cfg.dataset.dir,
-                                                  cfg.dataset.name.replace(
-                                                      "-", "_")),
-                          batch_size=batch_size, shuffle=shuffle,
-                          num_workers=cfg.num_workers,
-                          pin_memory=True,
-                          persistent_workers=pw)
+        loader_train = ClusterLoader(
+            dataset[0],
+            num_parts=cfg.train.train_parts,
+            save_dir=osp.join(
+                cfg.dataset.dir,
+                cfg.dataset.name.replace("-", "_"),
+            ),
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=cfg.num_workers,
+            pin_memory=True,
+            persistent_workers=pw,
+        )
 
     else:
-        raise NotImplementedError("%s sampler is not implemented!" % sampler)
+        raise NotImplementedError(f"'{sampler}' is not implemented")
 
     return loader_train


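The `osp.join` substitutions in this file are behavior-preserving on POSIX paths while staying portable across platforms. A quick illustration:

```python
import os.path as osp

dataset_dir, name = 'datasets', 'ogbn-products'
# '{}/{}'.format(dataset_dir, name) hard-codes '/'; osp.join uses the
# platform separator and does not double one that is already present.
print(osp.join(dataset_dir, name.replace('-', '_')))  # datasets/ogbn_products
```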
8 changes: 3 additions & 5 deletions torch_geometric/graphgym/loss.py
@@ -1,5 +1,4 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 
 import torch_geometric.graphgym.register as register
@@ -17,8 +16,8 @@ def compute_loss(pred, true):
     Returns: Loss, normalized prediction score
 
     """
-    bce_loss = nn.BCEWithLogitsLoss(reduction=cfg.model.size_average)
-    mse_loss = nn.MSELoss(reduction=cfg.model.size_average)
+    bce_loss = torch.nn.BCEWithLogitsLoss(reduction=cfg.model.size_average)
+    mse_loss = torch.nn.MSELoss(reduction=cfg.model.size_average)
 
     # default manipulation for pred and true
     # can be skipped if special loss computation is needed
@@ -44,5 +43,4 @@ def compute_loss(pred, true):
         true = true.float()
         return mse_loss(pred, true), pred
     else:
-        raise ValueError('Loss func {} not supported'.format(
-            cfg.model.loss_fun))
+        raise ValueError(f"Loss function '{cfg.model.loss_fun}' not supported")
16 changes: 8 additions & 8 deletions torch_geometric/graphgym/models/act.py
@@ -1,35 +1,35 @@
-import torch.nn as nn
+import torch
 
 from torch_geometric.graphgym.config import cfg
 from torch_geometric.graphgym.register import register_act
 
 
 def relu():
-    return nn.ReLU(inplace=cfg.mem.inplace)
+    return torch.nn.ReLU(inplace=cfg.mem.inplace)
 
 
 def selu():
-    return nn.SELU(inplace=cfg.mem.inplace)
+    return torch.nn.SELU(inplace=cfg.mem.inplace)
 
 
 def prelu():
-    return nn.PReLU()
+    return torch.nn.PReLU()
 
 
 def elu():
-    return nn.ELU(inplace=cfg.mem.inplace)
+    return torch.nn.ELU(inplace=cfg.mem.inplace)
 
 
 def lrelu_01():
-    return nn.LeakyReLU(0.1, inplace=cfg.mem.inplace)
+    return torch.nn.LeakyReLU(0.1, inplace=cfg.mem.inplace)
 
 
 def lrelu_025():
-    return nn.LeakyReLU(0.25, inplace=cfg.mem.inplace)
+    return torch.nn.LeakyReLU(0.25, inplace=cfg.mem.inplace)
 
 
 def lrelu_05():
-    return nn.LeakyReLU(0.5, inplace=cfg.mem.inplace)
+    return torch.nn.LeakyReLU(0.5, inplace=cfg.mem.inplace)
 
 
 if cfg is not None:
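The truncated tail of `act.py` registers each factory via `register_act`. As a hedged sketch of how a custom activation would plug into the same registry (the `lrelu_03` variant is hypothetical, not part of this PR):

```python
import torch

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act


def lrelu_03():
    # Hypothetical extra variant, mirroring the factories above.
    return torch.nn.LeakyReLU(0.3, inplace=cfg.mem.inplace)


if cfg is not None:
    register_act('lrelu_03', lrelu_03)
```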