Merge pull request graphdeeplearning#8 from JakeStevens/citation_graphs
Citation graphs
vijaydwivedi75 authored Mar 26, 2020
2 parents a17fdf1 + ea17ced commit c74e4e0
Showing 22 changed files with 1,356 additions and 60 deletions.
37 changes: 37 additions & 0 deletions configs/CitationGraphs_node_classification_GAT.json
@@ -0,0 +1,37 @@
{
    "gpu": {
        "use": true,
        "id": 0
    },

    "model": "GAT",

    "out_dir": "out/CitationGraphs_node_classification/",

    "params": {
        "seed": 41,
        "epochs": 300,
        "batch_size": 128,
        "init_lr": 0.005,
        "lr_reduce_factor": 0.5,
        "lr_schedule_patience": 5,
        "min_lr": 1e-5,
        "weight_decay": 5e-4,
        "print_epoch_interval": 5,
        "max_time": 48
    },

    "net_params": {
        "builtin": true,
        "L": 1,
        "n_heads": 8,
        "hidden_dim": 8,
        "out_dim": 8,
        "residual": false,
        "in_feat_dropout": 0.6,
        "dropout": 0.6,
        "graph_norm": false,
        "batch_norm": false,
        "self_loop": true
    }
}
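
These config files are plain JSON dictionaries read by the benchmark's training scripts. As a quick illustration (not part of this PR; it only assumes the file is read with Python's standard json module from the repository root), the GAT settings above imply a 64-dimensional node representation per attention layer, since the 8 heads of width 8 are concatenated:

import json

with open("configs/CitationGraphs_node_classification_GAT.json") as f:
    cfg = json.load(f)

net = cfg["net_params"]
# 8 heads x 8 hidden units, concatenated -> 64 features per node
print(net["n_heads"] * net["hidden_dim"])                        # 64
print(cfg["params"]["init_lr"], cfg["params"]["weight_decay"])   # 0.005 0.0005
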
36 changes: 36 additions & 0 deletions configs/CitationGraphs_node_classification_GCN.json
@@ -0,0 +1,36 @@
{
    "gpu": {
        "use": true,
        "id": 0
    },

    "model": "GCN",

    "out_dir": "out/CitationGraphs_node_classification/",

    "params": {
        "seed": 41,
        "epochs": 300,
        "batch_size": 128,
        "init_lr": 1e-2,
        "lr_reduce_factor": 0.5,
        "lr_schedule_patience": 5,
        "min_lr": 1e-5,
        "weight_decay": 5e-4,
        "print_epoch_interval": 5,
        "max_time": 48
    },

    "net_params": {
        "builtin": true,
        "L": 1,
        "hidden_dim": 16,
        "out_dim": 16,
        "residual": false,
        "in_feat_dropout": 0.5,
        "dropout": 0.5,
        "graph_norm": false,
        "batch_norm": false,
        "self_loop": true
    }
}
36 changes: 36 additions & 0 deletions configs/CitationGraphs_node_classification_GraphSage.json
@@ -0,0 +1,36 @@
{
    "gpu": {
        "use": true,
        "id": 0
    },

    "model": "GraphSage",

    "out_dir": "out/CitationGraphs_node_classification/",

    "params": {
        "seed": 41,
        "epochs": 300,
        "batch_size": 20,
        "init_lr": 1e-2,
        "lr_reduce_factor": 0.5,
        "lr_schedule_patience": 25,
        "min_lr": 1e-6,
        "weight_decay": 5e-4,
        "print_epoch_interval": 5,
        "max_time": 48
    },

    "net_params": {
        "builtin": true,
        "L": 1,
        "hidden_dim": 16,
        "out_dim": 16,
        "residual": false,
        "in_feat_dropout": 0.5,
        "dropout": 0.5,
        "graph_norm": false,
        "batch_norm": false,
        "sage_aggregator": "mean"
    }
}
33 changes: 33 additions & 0 deletions configs/CitationGraphs_node_classification_MLP.json
@@ -0,0 +1,33 @@
{
    "gpu": {
        "use": true,
        "id": 0
    },

    "model": "MLP",

    "out_dir": "out/CitationGraphs_node_classification/",

    "params": {
        "seed": 41,
        "epochs": 300,
        "batch_size": 20,
        "init_lr": 0.005,
        "lr_reduce_factor": 0.5,
        "lr_schedule_patience": 25,
        "min_lr": 1e-5,
        "weight_decay": 5e-4,
        "print_epoch_interval": 5,
        "max_time": 48
    },

    "net_params": {
        "L": 4,
        "hidden_dim": 16,
        "out_dim": 16,
        "readout": "mean",
        "gated": false,
        "in_feat_dropout": 0.6,
        "dropout": 0.6
    }
}
33 changes: 33 additions & 0 deletions configs/CitationGraphs_node_classification_MLP_GATED.json
@@ -0,0 +1,33 @@
{
    "gpu": {
        "use": true,
        "id": 0
    },

    "model": "MLP",

    "out_dir": "out/CitationGraphs_node_classification/",

    "params": {
        "seed": 41,
        "epochs": 300,
        "batch_size": 20,
        "init_lr": 0.005,
        "lr_reduce_factor": 0.5,
        "lr_schedule_patience": 25,
        "min_lr": 1e-5,
        "weight_decay": 5e-4,
        "print_epoch_interval": 5,
        "max_time": 48
    },

    "net_params": {
        "L": 4,
        "hidden_dim": 16,
        "out_dim": 16,
        "readout": "mean",
        "gated": true,
        "in_feat_dropout": 0.6,
        "dropout": 0.6
    }
}
91 changes: 91 additions & 0 deletions data/CitationGraphs.py
@@ -0,0 +1,91 @@
import torch
import pickle
import torch.utils.data
import time
import os
import numpy as np

import csv

import dgl
from dgl.data import CoraDataset
from dgl.data import CitationGraphDataset
import networkx as nx

import random
random.seed(42)


def self_loop(g):
    """
    Utility function only, to be used only when necessary as per user self_loop flag
    : Overwriting the function dgl.transform.add_self_loop() to not miss ndata['feat'] and edata['feat']
    This function is called inside a function in CitationGraphsDataset class.
    """
    new_g = dgl.DGLGraph()
    new_g.add_nodes(g.number_of_nodes())
    new_g.ndata['feat'] = g.ndata['feat']

    src, dst = g.all_edges(order="eid")
    src = dgl.backend.zerocopy_to_numpy(src)
    dst = dgl.backend.zerocopy_to_numpy(dst)
    non_self_edges_idx = src != dst
    nodes = np.arange(g.number_of_nodes())
    new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx])
    new_g.add_edges(nodes, nodes)

    # This new edata is not used since this function gets called only for GCN, GAT
    # However, we need this for the generic requirement of ndata and edata
    new_g.edata['feat'] = torch.zeros(new_g.number_of_edges())
    return new_g


class CitationGraphsDataset(torch.utils.data.Dataset):
    def __init__(self, name):
        t0 = time.time()
        self.name = name.lower()

        if self.name == 'cora':
            dataset = CoraDataset()
        else:
            dataset = CitationGraphDataset(self.name)
        dataset.graph.remove_edges_from(nx.selfloop_edges(dataset.graph))
        graph = dgl.DGLGraph(dataset.graph)
        E = graph.number_of_edges()
        N = graph.number_of_nodes()
        D = dataset.features.shape[1]
        graph.ndata['feat'] = torch.Tensor(dataset.features)
        graph.edata['feat'] = torch.zeros((E, D))
        graph.batch_num_nodes = [N]

        self.norm_n = torch.FloatTensor(N, 1).fill_(1./float(N)).sqrt()
        self.norm_e = torch.FloatTensor(E, 1).fill_(1./float(E)).sqrt()
        self.graph = graph
        self.train_mask = torch.BoolTensor(dataset.train_mask)
        self.val_mask = torch.BoolTensor(dataset.val_mask)
        self.test_mask = torch.BoolTensor(dataset.test_mask)
        self.labels = torch.LongTensor(dataset.labels)
        self.num_classes = dataset.num_labels
        self.num_dims = D

        print("[!] Dataset: ", self.name)
        print("Time taken: {:.4f}s".format(time.time()-t0))

    def _add_self_loops(self):
        # function for adding self loops
        # this function will be called only if self_loop flag is True
        self.graph = self_loop(self.graph)
        norm = torch.pow(self.graph.in_degrees().float().clamp(min=1), -0.5)
        shp = norm.shape + (1,) * (self.graph.ndata['feat'].dim() - 1)
        self.norm_n = torch.reshape(norm, shp)

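For orientation, a rough usage sketch of the wrapper above (illustrative only, not part of the diff; it assumes a DGL release from this period, in which dgl.data.CoraDataset and the mutable DGLGraph API used in this file are still available):

from data.CitationGraphs import CitationGraphsDataset

dataset = CitationGraphsDataset('CORA')   # or 'CITESEER' / 'PUBMED'
g = dataset.graph                         # single DGLGraph with ndata['feat'] / edata['feat']
print(g.number_of_nodes(), dataset.num_dims, dataset.num_classes)
print(int(dataset.train_mask.sum()), "train nodes,",
      int(dataset.test_mask.sum()), "test nodes")
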
6 changes: 6 additions & 0 deletions data/data.py
@@ -6,6 +6,7 @@
from data.TUs import TUsDataset
from data.SBMs import SBMsDataset
from data.TSP import TSPDataset
from data.CitationGraphs import CitationGraphsDataset

def LoadData(DATASET_NAME):
    """
@@ -34,3 +35,8 @@ def LoadData(DATASET_NAME):
    # handling for TSP dataset
    if DATASET_NAME == 'TSP':
        return TSPDataset(DATASET_NAME)

    # handling for the CITATIONGRAPHS Datasets
    CITATIONGRAPHS_DATASETS = ['CORA', 'CITESEER', 'PUBMED']
    if DATASET_NAME in CITATIONGRAPHS_DATASETS:
        return CitationGraphsDataset(DATASET_NAME)
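
With this hook, the citation benchmarks go through the same entry point as the existing datasets; a minimal sketch (illustrative, not part of the diff):

from data.data import LoadData

dataset = LoadData('PUBMED')     # dispatches to CitationGraphsDataset
print(type(dataset).__name__)    # CitationGraphsDataset
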
91 changes: 67 additions & 24 deletions layers/gat_layer.py
@@ -2,6 +2,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from dgl.nn.pytorch import GATConv
+
 """
     GAT: Graph Attention Network
     Graph Attention Networks (Veličković et al., ICLR 2018)
@@ -51,35 +53,76 @@ class GATLayer(nn.Module):
     """
     Param: [in_dim, out_dim, n_heads]
     """
-    def __init__(self, in_dim, out_dim, num_heads, dropout, graph_norm, batch_norm, residual=False):
+    def __init__(self, in_dim, out_dim, num_heads, dropout, graph_norm, batch_norm, residual=False, activation=None, dgl_builtin=False):
 
         super().__init__()
-        self.in_channels = in_dim
-        self.out_channels = out_dim
-        self.num_heads = num_heads
-        self.residual = residual
-
-        if in_dim != (out_dim*num_heads):
-            self.residual = False
-
-        self.heads = nn.ModuleList()
-        for i in range(num_heads):
-            self.heads.append(GATHeadLayer(in_dim, out_dim, dropout, graph_norm, batch_norm))
-        self.merge = 'cat'
+        self.dgl_builtin = dgl_builtin
+
+        if dgl_builtin == False:
+            self.in_channels = in_dim
+            self.out_channels = out_dim
+            self.num_heads = num_heads
+            self.residual = residual
+
+            if in_dim != (out_dim*num_heads):
+                self.residual = False
+
+            self.heads = nn.ModuleList()
+            for i in range(num_heads):
+                self.heads.append(GATHeadLayer(in_dim, out_dim, dropout, graph_norm, batch_norm))
+            self.merge = 'cat'
+
+        else:
+            self.in_channels = in_dim
+            self.out_channels = out_dim
+            self.num_heads = num_heads
+            self.residual = residual
+            self.activation = activation
+            self.graph_norm = graph_norm
+            self.batch_norm = batch_norm
+
+            if in_dim != (out_dim*num_heads):
+                self.residual = False
+
+            # Both feat and weighting dropout tied together here
+            self.conv = GATConv(in_dim, out_dim, num_heads, dropout, dropout)
+            self.batchnorm_h = nn.BatchNorm1d(out_dim)
 
     def forward(self, g, h, snorm_n):
-        h_in = h # for residual connection
-        head_outs = [attn_head(g, h, snorm_n) for attn_head in self.heads]
-
-        if self.merge == 'cat':
-            h = torch.cat(head_outs, dim=1)
+        if self.dgl_builtin == False:
+            h_in = h # for residual connection
+            head_outs = [attn_head(g, h, snorm_n) for attn_head in self.heads]
+
+            if self.merge == 'cat':
+                h = torch.cat(head_outs, dim=1)
+            else:
+                h = torch.mean(torch.stack(head_outs))
+
+            if self.residual:
+                h = h_in + h # residual connection
+            return h
         else:
-            h = torch.mean(torch.stack(head_outs))
-
-        if self.residual:
-            h = h_in + h # residual connection
-        return h
+            h_in = h # for residual connection
+
+            h = self.conv(g, h).flatten(1)
+
+            if self.graph_norm:
+                h = h * snorm_n
+            if self.batch_norm:
+                h = self.batchnorm_h(h)
+
+            if self.residual:
+                h = h_in + h # residual connection
+
+            if self.activation:
+                h = self.activation(h)
+            return h
 
     def __repr__(self):
         return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__,
                                              self.in_channels,
-                                             self.out_channels, self.num_heads, self.residual)
+                                             self.out_channels, self.num_heads, self.residual)
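
The dgl_builtin branch delegates to DGL's GATConv, with feature and attention dropout tied to the single dropout value, as the inline comment notes. A toy sketch of how the new path might be exercised (illustrative only; it assumes the DGL 0.4-era graph API already used elsewhere in this PR, and the graph and tensors are placeholders):

import dgl
import torch
import torch.nn.functional as F
from layers.gat_layer import GATLayer

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])   # small directed cycle
h = torch.randn(4, 16)                    # 16-dim input node features
snorm_n = torch.ones(4, 1)                # node norm factor (unused here: graph_norm=False)

layer = GATLayer(in_dim=16, out_dim=8, num_heads=8, dropout=0.6,
                 graph_norm=False, batch_norm=False, residual=False,
                 activation=F.elu, dgl_builtin=True)
out = layer(g, h, snorm_n)                # heads concatenated -> shape [4, 64]
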

