forked from dmlc/dgl
[Model] Training GraphSAGE with PyTorch Lightning (dmlc#2878)
* pytorch lightning initial examples
* revert most changes in dataloader to favor dmlc#2886.
* address comments
Showing 9 changed files with 547 additions and 261 deletions.
@@ -0,0 +1,99 @@
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn
import sklearn.linear_model as lm
import sklearn.metrics as skm
import tqdm


class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout):
        super().__init__()
        self.init(in_feats, n_hidden, n_classes, n_layers, activation, dropout)

    def init(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout):
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, device, batch_size, num_workers):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input features of the entire node set.
        The inference code is written so that it can handle any number of nodes and layers.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split into batches.
        # TODO: can we standardize this?
        for l, layer in enumerate(self.layers):
            y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            dataloader = dgl.dataloading.NodeDataLoader(
                g,
                th.arange(g.num_nodes()).to(g.device),
                sampler,
                batch_size=batch_size,
                shuffle=True,
                drop_last=False,
                num_workers=num_workers)

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0]

                block = block.int().to(device)
                h = x[input_nodes].to(device)
                h = layer(block, h)
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[output_nodes] = h.cpu()

            x = y
        return y


def compute_acc_unsupervised(emb, labels, train_nids, val_nids, test_nids):
    """
    Compute the accuracy of prediction given the labels.
    """
    emb = emb.cpu().numpy()
    labels = labels.cpu().numpy()
    train_nids = train_nids.cpu().numpy()
    train_labels = labels[train_nids]
    val_nids = val_nids.cpu().numpy()
    val_labels = labels[val_nids]
    test_nids = test_nids.cpu().numpy()
    test_labels = labels[test_nids]

    emb = (emb - emb.mean(0, keepdims=True)) / emb.std(0, keepdims=True)

    lr = lm.LogisticRegression(multi_class='multinomial', max_iter=10000)
    lr.fit(emb[train_nids], train_labels)

    pred = lr.predict(emb)
    f1_micro_eval = skm.f1_score(val_labels, pred[val_nids], average='micro')
    f1_micro_test = skm.f1_score(test_labels, pred[test_nids], average='micro')
    return f1_micro_eval, f1_micro_test
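
A minimal sketch (not part of the commit) of exercising the SAGE module above on a toy graph. The graph, feature size, and hyperparameters are made up for illustration; the import assumes this file is the `model` module that the training script below imports:

import torch as th
import torch.nn.functional as F
import dgl
from model import SAGE  # the file above, imported as `model` by the training script

# Toy graph and random features, just to run a mini-batch forward pass.
g = dgl.rand_graph(1000, 5000)
feat = th.randn(g.num_nodes(), 16)

net = SAGE(in_feats=16, n_hidden=32, n_classes=7, n_layers=2,
           activation=F.relu, dropout=0.5)

# Two fan-outs for the two SAGEConv layers, as in the training script.
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
dataloader = dgl.dataloading.NodeDataLoader(
    g, th.arange(g.num_nodes()), sampler,
    batch_size=256, shuffle=True, drop_last=False, num_workers=0)

input_nodes, output_nodes, blocks = next(iter(dataloader))
logits = net(blocks, feat[input_nodes])  # shape: (len(output_nodes), n_classes)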
@@ -0,0 +1,19 @@
import torch as th
import dgl


class NegativeSampler(object):
    def __init__(self, g, k, neg_share=False):
        self.weights = g.in_degrees().float() ** 0.75
        self.k = k
        self.neg_share = neg_share

    def __call__(self, g, eids):
        src, _ = g.find_edges(eids)
        n = len(src)
        if self.neg_share and n % self.k == 0:
            dst = self.weights.multinomial(n, replacement=True)
            dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()
        else:
            dst = self.weights.multinomial(n*self.k, replacement=True)
        src = src.repeat_interleave(self.k)
        return src, dst
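
A minimal sketch (not part of the commit) of calling this sampler directly; the toy graph and edge IDs are invented for illustration. In DGL's unsupervised GraphSAGE examples a class like this is typically supplied to an edge data loader as its negative sampler:

import torch as th
import dgl

g = dgl.rand_graph(100, 500)            # toy graph
neg_sampler = NegativeSampler(g, k=5)    # 5 negatives per positive edge

eids = th.arange(10)                     # IDs of 10 positive edges
src, dst = neg_sampler(g, eids)
# `src` repeats each positive source node k times; `dst` holds k negative
# destinations per edge, drawn with probability proportional to in_degree**0.75
# (the word2vec-style unigram distribution).
assert len(src) == len(dst) == len(eids) * 5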
@@ -0,0 +1,193 @@
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm
import glob
import os

from load_graph import load_reddit, inductive_split, load_ogb

from torchmetrics import Accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from model import SAGE


class SAGELightning(LightningModule):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 lr):
        super().__init__()
        self.save_hyperparameters()
        self.module = SAGE(in_feats, n_hidden, n_classes, n_layers, activation, dropout)
        self.lr = lr
        # The usage of `train_acc` and `val_acc` is the recommended practice from now on as per
        # https://torchmetrics.readthedocs.io/en/latest/pages/lightning.html
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def training_step(self, batch, batch_idx):
        input_nodes, output_nodes, mfgs = batch
        mfgs = [mfg.int().to(device) for mfg in mfgs]
        batch_inputs = mfgs[0].srcdata['features']
        batch_labels = mfgs[-1].dstdata['labels']
        batch_pred = self.module(mfgs, batch_inputs)
        loss = F.cross_entropy(batch_pred, batch_labels)
        self.train_acc(th.softmax(batch_pred, 1), batch_labels)
        self.log('train_acc', self.train_acc, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        input_nodes, output_nodes, mfgs = batch
        mfgs = [mfg.int().to(device) for mfg in mfgs]
        batch_inputs = mfgs[0].srcdata['features']
        batch_labels = mfgs[-1].dstdata['labels']
        batch_pred = self.module(mfgs, batch_inputs)
        self.val_acc(th.softmax(batch_pred, 1), batch_labels)
        self.log('val_acc', self.val_acc, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = th.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


class DataModule(LightningDataModule):
    def __init__(self, dataset_name, data_cpu=False, fan_out=[10, 25],
                 device=th.device('cpu'), batch_size=1000, num_workers=4):
        super().__init__()
        if dataset_name == 'reddit':
            g, n_classes = load_reddit()
        elif dataset_name == 'ogbn-products':
            g, n_classes = load_ogb('ogbn-products')
        else:
            raise ValueError('unknown dataset')

        train_nid = th.nonzero(g.ndata['train_mask'], as_tuple=True)[0]
        val_nid = th.nonzero(g.ndata['val_mask'], as_tuple=True)[0]
        test_nid = th.nonzero(~(g.ndata['train_mask'] | g.ndata['val_mask']), as_tuple=True)[0]

        sampler = dgl.dataloading.MultiLayerNeighborSampler([int(_) for _ in fan_out])

        dataloader_device = th.device('cpu')
        if not data_cpu:
            train_nid = train_nid.to(device)
            val_nid = val_nid.to(device)
            test_nid = test_nid.to(device)
            g = g.formats(['csc'])
            g = g.to(device)
            dataloader_device = device

        self.g = g
        self.train_nid, self.val_nid, self.test_nid = train_nid, val_nid, test_nid
        self.sampler = sampler
        self.device = dataloader_device
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.in_feats = g.ndata['features'].shape[1]
        self.n_classes = n_classes

    def train_dataloader(self):
        return dgl.dataloading.NodeDataLoader(
            self.g,
            self.train_nid,
            self.sampler,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=self.num_workers)

    def val_dataloader(self):
        return dgl.dataloading.NodeDataLoader(
            self.g,
            self.val_nid,
            self.sampler,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=self.num_workers)


def evaluate(model, g, val_nid, device):
    """
    Evaluate the model on the validation set specified by ``val_nid``.
    g : The entire graph.
    val_nid : the node IDs for validation.
    device : The GPU device to evaluate on.
    """
    model.eval()
    nfeat = g.ndata['features']
    labels = g.ndata['labels']
    with th.no_grad():
        pred = model.module.inference(g, nfeat, device, args.batch_size, args.num_workers)
    model.train()
    test_acc = Accuracy()
    return test_acc(th.softmax(pred[val_nid], -1), labels[val_nid].to(pred.device))


if __name__ == '__main__':
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--gpu', type=int, default=0,
                           help="GPU device ID. Use -1 for CPU training")
    argparser.add_argument('--dataset', type=str, default='reddit')
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=1000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
                           help="Number of sampling processes. Use 0 for no extra process.")
    argparser.add_argument('--inductive', action='store_true',
                           help="Inductive learning setting")
    argparser.add_argument('--data-cpu', action='store_true',
                           help="By default the script puts the graph, node features and labels "
                                "on GPU when using it to save time for data copy. This may "
                                "be undesired if they cannot fit in GPU memory at once. "
                                "This flag disables that.")
    args = argparser.parse_args()

    if args.gpu >= 0:
        device = th.device('cuda:%d' % args.gpu)
    else:
        device = th.device('cpu')

    datamodule = DataModule(
        args.dataset, args.data_cpu, [int(_) for _ in args.fan_out.split(',')],
        device, args.batch_size, args.num_workers)
    model = SAGELightning(
        datamodule.in_feats, args.num_hidden, datamodule.n_classes, args.num_layers,
        F.relu, args.dropout, args.lr)

    # Train
    checkpoint_callback = ModelCheckpoint(monitor='val_acc', save_top_k=1)
    trainer = Trainer(gpus=[args.gpu] if args.gpu != -1 else None,
                      max_epochs=args.num_epochs,
                      callbacks=[checkpoint_callback])
    trainer.fit(model, datamodule=datamodule)

    # Test
    dirs = glob.glob('./lightning_logs/*')
    version = max([int(os.path.split(x)[-1].split('_')[-1]) for x in dirs])
    logdir = './lightning_logs/version_%d' % version
    print('Evaluating model in', logdir)
    ckpt = glob.glob(os.path.join(logdir, 'checkpoints', '*'))[0]

    model = SAGELightning.load_from_checkpoint(
        checkpoint_path=ckpt, hparams_file=os.path.join(logdir, 'hparams.yaml')).to(device)
    test_acc = evaluate(model, datamodule.g, datamodule.test_nid, device)
    print('Test accuracy:', test_acc)
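
The file names are not shown in this view; assuming the Lightning training script above is saved as train_lightning.py (a hypothetical name) next to model.py and load_graph.py, a typical invocation might look like:

# Train on Reddit with GPU 0 and evaluate the best checkpoint on the test split.
python train_lightning.py --dataset reddit --gpu 0 --num-epochs 20 --batch-size 1000

# CPU-only run on ogbn-products, keeping the graph and features in host memory.
python train_lightning.py --dataset ogbn-products --gpu -1 --data-cpu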