Skip to content

Commit

Permalink
[Model] Training GraphSAGE with PyTorch Lightning (dmlc#2878)
Browse files Browse the repository at this point in the history
* pytorch lightning initial examples

* revert most changes in dataloader to favor dmlc#2886.

* address comments
  • Loading branch information
BarclayII authored May 11, 2021
1 parent c18f957 commit 70695ff
Show file tree
Hide file tree
Showing 9 changed files with 547 additions and 261 deletions.
11 changes: 11 additions & 0 deletions examples/pytorch/graphsage/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,14 @@ Notably,
in the paper.

Micro F1 score reaches 0.9212 on test set.

### Training with PyTorch Lightning

We also provide minibatch training scripts with PyTorch Lightning in `train_lightning.py` and `train_lightning_unsupervised.py`.

Requires `pytorch_lightning` and `torchmetrics`.

```bash
python3 train_lightning.py
python3 train_lightning_unsupervised.py
```
99 changes: 99 additions & 0 deletions examples/pytorch/graphsage/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch as th
import torch.nn as nn
import torch.functional as F
import dgl
import dgl.nn as dglnn
import sklearn.linear_model as lm
import sklearn.metrics as skm
import tqdm

class SAGE(nn.Module):
def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout):
super().__init__()
self.init(in_feats, n_hidden, n_classes, n_layers, activation, dropout)

def init(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout):
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
self.dropout = nn.Dropout(dropout)
self.activation = activation

def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
h = layer(block, h)
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
return h

def inference(self, g, x, device, batch_size, num_workers):
"""
Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
g : the entire graph.
x : the input of entire node set.
The inference code is written in a fashion that it could handle any number of nodes and
layers.
"""
# During inference with sampling, multi-layer blocks are very inefficient because
# lots of computations in the first few layers are repeated.
# Therefore, we compute the representation of all nodes layer by layer. The nodes
# on each layer are of course splitted in batches.
# TODO: can we standardize this?
for l, layer in enumerate(self.layers):
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
g,
th.arange(g.num_nodes()).to(g.device),
sampler,
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=num_workers)

for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0]

block = block.int().to(device)
h = x[input_nodes].to(device)
h = layer(block, h)
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)

y[output_nodes] = h.cpu()

x = y
return y

def compute_acc_unsupervised(emb, labels, train_nids, val_nids, test_nids):
"""
Compute the accuracy of prediction given the labels.
"""
emb = emb.cpu().numpy()
labels = labels.cpu().numpy()
train_nids = train_nids.cpu().numpy()
train_labels = labels[train_nids]
val_nids = val_nids.cpu().numpy()
val_labels = labels[val_nids]
test_nids = test_nids.cpu().numpy()
test_labels = labels[test_nids]

emb = (emb - emb.mean(0, keepdims=True)) / emb.std(0, keepdims=True)

lr = lm.LogisticRegression(multi_class='multinomial', max_iter=10000)
lr.fit(emb[train_nids], train_labels)

pred = lr.predict(emb)
f1_micro_eval = skm.f1_score(val_labels, pred[val_nids], average='micro')
f1_micro_test = skm.f1_score(test_labels, pred[test_nids], average='micro')
return f1_micro_eval, f1_micro_test
19 changes: 19 additions & 0 deletions examples/pytorch/graphsage/negative_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch as th
import dgl

class NegativeSampler(object):
def __init__(self, g, k, neg_share=False):
self.weights = g.in_degrees().float() ** 0.75
self.k = k
self.neg_share = neg_share

def __call__(self, g, eids):
src, _ = g.find_edges(eids)
n = len(src)
if self.neg_share and n % self.k == 0:
dst = self.weights.multinomial(n, replacement=True)
dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()
else:
dst = self.weights.multinomial(n*self.k, replacement=True)
src = src.repeat_interleave(self.k)
return src, dst
193 changes: 193 additions & 0 deletions examples/pytorch/graphsage/train_lightning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm
import glob
import os

from load_graph import load_reddit, inductive_split, load_ogb

from torchmetrics import Accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from model import SAGE

class SAGELightning(LightningModule):
def __init__(self,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout,
lr):
super().__init__()
self.save_hyperparameters()
self.module = SAGE(in_feats, n_hidden, n_classes, n_layers, activation, dropout)
self.lr = lr
# The usage of `train_acc` and `val_acc` is the recommended practice from now on as per
# https://torchmetrics.readthedocs.io/en/latest/pages/lightning.html
self.train_acc = Accuracy()
self.val_acc = Accuracy()

def training_step(self, batch, batch_idx):
input_nodes, output_nodes, mfgs = batch
mfgs = [mfg.int().to(device) for mfg in mfgs]
batch_inputs = mfgs[0].srcdata['features']
batch_labels = mfgs[-1].dstdata['labels']
batch_pred = self.module(mfgs, batch_inputs)
loss = F.cross_entropy(batch_pred, batch_labels)
self.train_acc(th.softmax(batch_pred, 1), batch_labels)
self.log('train_acc', self.train_acc, prog_bar=True, on_step=True, on_epoch=False)
return loss

def validation_step(self, batch, batch_idx):
input_nodes, output_nodes, mfgs = batch
mfgs = [mfg.int().to(device) for mfg in mfgs]
batch_inputs = mfgs[0].srcdata['features']
batch_labels = mfgs[-1].dstdata['labels']
batch_pred = self.module(mfgs, batch_inputs)
self.val_acc(th.softmax(batch_pred, 1), batch_labels)
self.log('val_acc', self.val_acc, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)

def configure_optimizers(self):
optimizer = th.optim.Adam(self.parameters(), lr=self.lr)
return optimizer


class DataModule(LightningDataModule):
def __init__(self, dataset_name, data_cpu=False, fan_out=[10, 25],
device=th.device('cpu'), batch_size=1000, num_workers=4):
super().__init__()
if dataset_name == 'reddit':
g, n_classes = load_reddit()
elif dataset_name == 'ogbn-products':
g, n_classes = load_ogb('ogbn-products')
else:
raise ValueError('unknown dataset')

train_nid = th.nonzero(g.ndata['train_mask'], as_tuple=True)[0]
val_nid = th.nonzero(g.ndata['val_mask'], as_tuple=True)[0]
test_nid = th.nonzero(~(g.ndata['train_mask'] | g.ndata['val_mask']), as_tuple=True)[0]

sampler = dgl.dataloading.MultiLayerNeighborSampler([int(_) for _ in fan_out])

dataloader_device = th.device('cpu')
if not data_cpu:
train_nid = train_nid.to(device)
val_nid = val_nid.to(device)
test_nid = test_nid.to(device)
g = g.formats(['csc'])
g = g.to(device)
dataloader_device = device

self.g = g
self.train_nid, self.val_nid, self.test_nid = train_nid, val_nid, test_nid
self.sampler = sampler
self.device = dataloader_device
self.batch_size = batch_size
self.num_workers = num_workers
self.in_feats = g.ndata['features'].shape[1]
self.n_classes = n_classes

def train_dataloader(self):
return dgl.dataloading.NodeDataLoader(
self.g,
self.train_nid,
self.sampler,
device=self.device,
batch_size=self.batch_size,
shuffle=True,
drop_last=False,
num_workers=self.num_workers)

def val_dataloader(self):
return dgl.dataloading.NodeDataLoader(
self.g,
self.val_nid,
self.sampler,
device=self.device,
batch_size=self.batch_size,
shuffle=True,
drop_last=False,
num_workers=self.num_workers)


def evaluate(model, g, val_nid, device):
"""
Evaluate the model on the validation set specified by ``val_nid``.
g : The entire graph.
val_nid : the node Ids for validation.
device : The GPU device to evaluate on.
"""
model.eval()
nfeat = g.ndata['features']
labels = g.ndata['labels']
with th.no_grad():
pred = model.module.inference(g, nfeat, device, args.batch_size, args.num_workers)
model.train()
test_acc = Accuracy()
return test_acc(th.softmax(pred[val_nid], -1), labels[val_nid].to(pred.device))


if __name__ == '__main__':
argparser = argparse.ArgumentParser()
argparser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training")
argparser.add_argument('--dataset', type=str, default='reddit')
argparser.add_argument('--num-epochs', type=int, default=20)
argparser.add_argument('--num-hidden', type=int, default=16)
argparser.add_argument('--num-layers', type=int, default=2)
argparser.add_argument('--fan-out', type=str, default='10,25')
argparser.add_argument('--batch-size', type=int, default=1000)
argparser.add_argument('--log-every', type=int, default=20)
argparser.add_argument('--eval-every', type=int, default=5)
argparser.add_argument('--lr', type=float, default=0.003)
argparser.add_argument('--dropout', type=float, default=0.5)
argparser.add_argument('--num-workers', type=int, default=0,
help="Number of sampling processes. Use 0 for no extra process.")
argparser.add_argument('--inductive', action='store_true',
help="Inductive learning setting")
argparser.add_argument('--data-cpu', action='store_true',
help="By default the script puts the graph, node features and labels "
"on GPU when using it to save time for data copy. This may "
"be undesired if they cannot fit in GPU memory at once. "
"This flag disables that.")
args = argparser.parse_args()

if args.gpu >= 0:
device = th.device('cuda:%d' % args.gpu)
else:
device = th.device('cpu')

datamodule = DataModule(
args.dataset, args.data_cpu, [int(_) for _ in args.fan_out.split(',')],
device, args.batch_size, args.num_workers)
model = SAGELightning(
datamodule.in_feats, args.num_hidden, datamodule.n_classes, args.num_layers,
F.relu, args.dropout, args.lr)

# Train
checkpoint_callback = ModelCheckpoint(monitor='val_acc', save_top_k=1)
trainer = Trainer(gpus=[args.gpu] if args.gpu != -1 else None,
max_epochs=args.num_epochs,
callbacks=[checkpoint_callback])
trainer.fit(model, datamodule=datamodule)

# Test
dirs = glob.glob('./lightning_logs/*')
version = max([int(os.path.split(x)[-1].split('_')[-1]) for x in dirs])
logdir = './lightning_logs/version_%d' % version
print('Evaluating model in', logdir)
ckpt = glob.glob(os.path.join(logdir, 'checkpoints', '*'))[0]

model = SAGELightning.load_from_checkpoint(
checkpoint_path=ckpt, hparams_file=os.path.join(logdir, 'hparams.yaml')).to(device)
test_acc = evaluate(model, datamodule.g, datamodule.test_nid, device)
print('Test accuracy:', test_acc)
Loading

0 comments on commit 70695ff

Please sign in to comment.