Skip to content

Commit

Permalink
Enhancing GraphGym Documentation (#7885)
Browse files Browse the repository at this point in the history
Part of #5132.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Sep 1, 2023
1 parent a52af69 commit 7a6052c
Show file tree
Hide file tree
Showing 19 changed files with 419 additions and 294 deletions.
18 changes: 12 additions & 6 deletions examples/contrib/pgm_explainer_graph_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Sequential

import torch_geometric.transforms as T
from torch_geometric.contrib.explain import PGMExplainer
Expand Down Expand Up @@ -37,15 +37,21 @@ def normalized_cut_2d(edge_index, pos):
return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))


class Net(nn.Module):
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
nn1 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(),
nn.Linear(25, d.num_features * 32))
nn1 = Sequential(
Linear(2, 25),
ReLU(),
Linear(25, d.num_features * 32),
)
self.conv1 = NNConv(d.num_features, 32, nn1, aggr='mean')

nn2 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(),
nn.Linear(25, 32 * 64))
nn2 = Sequential(
Linear(2, 25),
ReLU(),
Linear(25, 32 * 64),
)
self.conv2 = NNConv(32, 64, nn2, aggr='mean')

self.fc1 = torch.nn.Linear(64, 128)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def run_training_proc(
backend='nccl', # or choose 'gloo' if 'nccl' is not supported.
rank=current_ctx.rank,
world_size=current_ctx.world_size,
init_method='tcp://{}:{}'.format(master_addr, training_pg_master_port))
init_method=f'tcp://{master_addr}:{training_pg_master_port}',
)

# Create distributed neighbor loader for training.
# We replace PyG's NeighborLoader with GLT's DistNeighborLoader.
Expand Down
9 changes: 4 additions & 5 deletions examples/infomax_inductive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os.path as osp

import torch
import torch.nn as nn
from tqdm import tqdm

from torch_geometric.datasets import Reddit
Expand All @@ -19,7 +18,7 @@
num_workers=12)


class Encoder(nn.Module):
class Encoder(torch.nn.Module):
def __init__(self, in_channels, hidden_channels):
super().__init__()
self.convs = torch.nn.ModuleList([
Expand All @@ -30,9 +29,9 @@ def __init__(self, in_channels, hidden_channels):

self.activations = torch.nn.ModuleList()
self.activations.extend([
nn.PReLU(hidden_channels),
nn.PReLU(hidden_channels),
nn.PReLU(hidden_channels)
torch.nn.PReLU(hidden_channels),
torch.nn.PReLU(hidden_channels),
torch.nn.PReLU(hidden_channels)
])

def forward(self, x, edge_index, batch_size):
Expand Down
7 changes: 3 additions & 4 deletions examples/kuzu/papers_100M/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import kuzu
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

Expand Down Expand Up @@ -60,13 +59,13 @@
)


class GraphSAGE(nn.Module):
class GraphSAGE(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
dropout=0.2):
super().__init__()

self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
self.convs = torch.nn.ModuleList()
self.norms = torch.nn.ModuleList()

self.convs.append(SAGEConv(in_channels, hidden_channels))
self.bns.append(BatchNorm(hidden_channels))
Expand Down
18 changes: 12 additions & 6 deletions examples/mnist_nn_conv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Sequential

import torch_geometric.transforms as T
from torch_geometric.datasets import MNISTSuperpixels
Expand Down Expand Up @@ -35,15 +35,21 @@ def normalized_cut_2d(edge_index, pos):
return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))


class Net(nn.Module):
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
nn1 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(),
nn.Linear(25, d.num_features * 32))
nn1 = Sequential(
Linear(2, 25),
ReLU(),
Linear(25, d.num_features * 32),
)
self.conv1 = NNConv(d.num_features, 32, nn1, aggr='mean')

nn2 = nn.Sequential(nn.Linear(2, 25), nn.ReLU(),
nn.Linear(25, 32 * 64))
nn2 = Sequential(
Linear(2, 25),
ReLU(),
Linear(25, 32 * 64),
)
self.conv2 = NNConv(32, 64, nn2, aggr='mean')

self.fc1 = torch.nn.Linear(64, 128)
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/graphgym/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,8 @@ def set_cfg(cfg):
def assert_cfg(cfg):
r"""Checks config values, do necessary post processing to the configs"""
if cfg.dataset.task not in ['node', 'edge', 'graph', 'link_pred']:
raise ValueError('Task {} not supported, must be one of node, '
'edge, graph, link_pred'.format(cfg.dataset.task))
raise ValueError(f"Task '{cfg.dataset.task}' not supported. Must be "
f"one of node, edge, graph, link_pred")
if 'classification' in cfg.dataset.task_type and cfg.model.loss_fun == \
'mse':
cfg.model.loss_fun = 'cross_entropy'
Expand Down
24 changes: 16 additions & 8 deletions torch_geometric/graphgym/contrib/layer/generalconv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
import torch.nn as nn
from torch.nn import Parameter

from torch_geometric.graphgym.config import cfg
Expand All @@ -9,8 +8,7 @@


class GeneralConvLayer(MessagePassing):
r"""General GNN layer
"""
r"""A general GNN layer."""
def __init__(self, in_channels, out_channels, improved=False, cached=False,
bias=True, **kwargs):
super().__init__(aggr=cfg.gnn.agg, **kwargs)
Expand Down Expand Up @@ -128,13 +126,23 @@ def __init__(self, in_channels, out_channels, edge_dim, improved=False,
self.msg_direction = cfg.gnn.msg_direction

if self.msg_direction == 'single':
self.linear_msg = nn.Linear(in_channels + edge_dim, out_channels,
bias=False)
self.linear_msg = torch.nn.Linear(
in_channels + edge_dim,
out_channels,
bias=False,
)
else:
self.linear_msg = nn.Linear(in_channels * 2 + edge_dim,
out_channels, bias=False)
self.linear_msg = torch.nn.Linear(
in_channels * 2 + edge_dim,
out_channels,
bias=False,
)
if cfg.gnn.self_msg == 'concat':
self.linear_self = nn.Linear(in_channels, out_channels, bias=False)
self.linear_self = torch.nn.Linear(
in_channels,
out_channels,
bias=False,
)

if bias:
self.bias = Parameter(torch.empty(out_channels))
Expand Down
11 changes: 6 additions & 5 deletions torch_geometric/graphgym/init.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import torch.nn as nn
import torch


def init_weights(m):
Expand All @@ -9,11 +9,12 @@ def init_weights(m):
m (nn.Module): PyTorch module
"""
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
if (isinstance(m, torch.nn.BatchNorm2d)
or isinstance(m, torch.nn.BatchNorm1d)):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data = nn.init.xavier_uniform_(
m.weight.data, gain=nn.init.calculate_gain('relu'))
elif isinstance(m, torch.nn.Linear):
m.weight.data = torch.nn.init.xavier_uniform_(
m.weight.data, gain=torch.nn.init.calculate_gain('relu'))
if m.bias is not None:
m.bias.data.zero_()
33 changes: 19 additions & 14 deletions torch_geometric/graphgym/loader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os.path as osp
from typing import Callable

import torch
Expand Down Expand Up @@ -58,7 +59,7 @@ def load_pyg(name, dataset_dir):
Returns: PyG dataset object
"""
dataset_dir = '{}/{}'.format(dataset_dir, name)
dataset_dir = osp.join(dataset_dir, name)
if name in ['Cora', 'CiteSeer', 'PubMed']:
dataset = Planetoid(dataset_dir, name)
elif name[:3] == 'TU_':
Expand Down Expand Up @@ -87,7 +88,7 @@ def load_pyg(name, dataset_dir):
elif name == 'QM7b':
dataset = QM7b(dataset_dir)
else:
raise ValueError('{} not support'.format(name))
raise ValueError(f"'{name}' not support")

return dataset

Expand Down Expand Up @@ -194,7 +195,7 @@ def load_dataset():
elif format == 'OGB':
dataset = load_ogb(name.replace('_', '-'), dataset_dir)
else:
raise ValueError('Unknown data format: {}'.format(format))
raise ValueError(f"Unknown data format '{format}'")
return dataset


Expand Down Expand Up @@ -290,19 +291,23 @@ def get_loader(dataset, sampler, batch_size, shuffle=True):
pin_memory=True,
persistent_workers=pw)
elif sampler == "cluster":
loader_train = \
ClusterLoader(dataset[0],
num_parts=cfg.train.train_parts,
save_dir="{}/{}".format(cfg.dataset.dir,
cfg.dataset.name.replace(
"-", "_")),
batch_size=batch_size, shuffle=shuffle,
num_workers=cfg.num_workers,
pin_memory=True,
persistent_workers=pw)
loader_train = ClusterLoader(
dataset[0],
num_parts=cfg.train.train_parts,
save_dir=osp.join(
cfg.dataset.dir,
cfg.dataset.name.replace("-", "_"),
),
batch_size=batch_size,
shuffle=shuffle,
num_workers=cfg.num_workers,
pin_memory=True,
persistent_workers=pw,
)

else:
raise NotImplementedError("%s sampler is not implemented!" % sampler)
raise NotImplementedError(f"'{sampler}' is not implemented")

return loader_train


Expand Down
8 changes: 3 additions & 5 deletions torch_geometric/graphgym/loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.graphgym.register as register
Expand All @@ -17,8 +16,8 @@ def compute_loss(pred, true):
Returns: Loss, normalized prediction score
"""
bce_loss = nn.BCEWithLogitsLoss(reduction=cfg.model.size_average)
mse_loss = nn.MSELoss(reduction=cfg.model.size_average)
bce_loss = torch.nn.BCEWithLogitsLoss(reduction=cfg.model.size_average)
mse_loss = torch.nn.MSELoss(reduction=cfg.model.size_average)

# default manipulation for pred and true
# can be skipped if special loss computation is needed
Expand All @@ -44,5 +43,4 @@ def compute_loss(pred, true):
true = true.float()
return mse_loss(pred, true), pred
else:
raise ValueError('Loss func {} not supported'.format(
cfg.model.loss_fun))
raise ValueError(f"Loss function '{cfg.model.loss_fun}' not supported")
16 changes: 8 additions & 8 deletions torch_geometric/graphgym/models/act.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
import torch.nn as nn
import torch

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act


def relu():
return nn.ReLU(inplace=cfg.mem.inplace)
return torch.nn.ReLU(inplace=cfg.mem.inplace)


def selu():
return nn.SELU(inplace=cfg.mem.inplace)
return torch.nn.SELU(inplace=cfg.mem.inplace)


def prelu():
return nn.PReLU()
return torch.nn.PReLU()


def elu():
return nn.ELU(inplace=cfg.mem.inplace)
return torch.nn.ELU(inplace=cfg.mem.inplace)


def lrelu_01():
return nn.LeakyReLU(0.1, inplace=cfg.mem.inplace)
return torch.nn.LeakyReLU(0.1, inplace=cfg.mem.inplace)


def lrelu_025():
return nn.LeakyReLU(0.25, inplace=cfg.mem.inplace)
return torch.nn.LeakyReLU(0.25, inplace=cfg.mem.inplace)


def lrelu_05():
return nn.LeakyReLU(0.5, inplace=cfg.mem.inplace)
return torch.nn.LeakyReLU(0.5, inplace=cfg.mem.inplace)


if cfg is not None:
Expand Down
Loading

0 comments on commit 7a6052c

Please sign in to comment.