
Commit

Support sparse version GAT
sh0416 committed Sep 14, 2018
1 parent 7a99b19 commit 5f1c561
Showing 5 changed files with 64 additions and 25 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -27,7 +27,7 @@ A small note about initial sparse matrix operations of https://github.com/tkipf/

# Requirements

pyGAT relies on Python 3.5 and PyTorch 0.4 (due to torch.where).
pyGAT relies on Python 3.5 and PyTorch 0.4.1 (due to torch.sparse_coo_tensor).

# Issues/Pull Requests/Feedbacks

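For context on the updated requirement above: torch.sparse_coo_tensor, introduced in PyTorch 0.4.1, builds a sparse matrix from a 2 x E tensor of indices and a matching vector of values, which is how the new layer assembles its attention matrix. A minimal sketch on made-up toy data (not part of this commit):

import torch

# Toy graph with 3 nodes and directed edges 0->1, 1->2, 2->0 (hypothetical data).
edge = torch.tensor([[0, 1, 2],
                     [1, 2, 0]])           # 2 x E index tensor (row, col per edge)
edge_e = torch.tensor([0.5, 0.2, 0.3])     # one value per edge
N = 3

e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
print(e.to_dense())
# tensor([[0.0000, 0.5000, 0.0000],
#         [0.0000, 0.0000, 0.2000],
#         [0.3000, 0.0000, 0.0000]])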
40 changes: 30 additions & 10 deletions layers.py
@@ -17,27 +17,47 @@ def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(in_features, out_features).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True)
        self.a = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(2*out_features, 1).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True)
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.normal_(self.W.data, std=0.01)

        self.a = nn.Parameter(torch.zeros(size=(1, 2*out_features)))
        nn.init.normal_(self.a.data, std=0.01)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        N = input.size()[0]
        edge = adj.nonzero().t()

        h = torch.mm(input, self.W)
        N = h.size()[0]
        # h: N x out
        assert not torch.any(torch.isnan(h))

        # Self-attention on the nodes - Shared attention mechanism
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
        # edge: 2*D x E
        edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze()))
        assert not torch.any(torch.isnan(edge_e))
        # edge_e: 1 x E
        e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
        # e: N x N

        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
        e_rowsum = torch.matmul(e, torch.ones(size=(N, 1)).cuda())
        #e_rowsum = torch.where(e_rowsum<1e-6, torch.full_like(e_rowsum, 1e-6), e_rowsum)
        # e_rowsum: N x 1

        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)
        h_prime = torch.matmul(e, h)
        assert not torch.any(torch.isnan(h_prime))

        h_prime = h_prime.div(e_rowsum)
        # h_prime: N x out
        assert not torch.any(torch.isnan(h_prime))

        if self.concat:
            # if this layer is not last layer,
            return F.elu(h_prime)
        else:
            # if this layer is last layer,
            return h_prime

    def __repr__(self):
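To make the new forward easier to follow: instead of masking a dense N x N score matrix and calling softmax, the sparse version keeps one exp(-leakyrelu(score)) value per edge, aggregates neighbour features through a sparse matrix, and divides by each node's summed edge weights. A self-contained sketch of that normalization trick on toy data, using torch.sparse.mm in place of the commit's torch.matmul call:

import torch
import torch.nn.functional as F

# Toy data: 4 nodes, 6 directed edges chosen so every source node has a neighbour.
edge = torch.tensor([[0, 0, 1, 2, 3, 3],
                     [1, 2, 0, 3, 0, 2]])    # 2 x E
logits = torch.randn(edge.size(1))           # per-edge attention scores
h = torch.randn(4, 3)                        # transformed node features (N x out)
N = 4

edge_e = torch.exp(-F.leaky_relu(logits))                       # unnormalized edge weights, as in the commit
e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))   # sparse N x N attention matrix
e_rowsum = torch.sparse.mm(e, torch.ones(N, 1))                 # summed edge weights per node (N x 1)
h_prime = torch.sparse.mm(e, h) / e_rowsum                      # weighted average of neighbour features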
16 changes: 11 additions & 5 deletions models.py
@@ -1,5 +1,5 @@
import torch.nn as nn
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import GraphAttentionLayer

@@ -9,11 +9,19 @@ def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super(GAT, self).__init__()
        self.dropout = dropout

        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        self.attentions = [GraphAttentionLayer(nfeat,
                                               nhid,
                                               dropout=dropout,
                                               alpha=alpha,
                                               concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
        self.out_att = GraphAttentionLayer(nhid * nheads,
                                           nclass,
                                           dropout=dropout,
                                           alpha=alpha,
                                           concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
@@ -22,5 +30,3 @@ def forward(self, x, adj):
        x = F.elu(self.out_att(x, adj))
        return F.log_softmax(x, dim=1)



11 changes: 8 additions & 3 deletions requirements.txt
@@ -1,3 +1,8 @@
torch==0.4.0a0+1fdb392
scipy==0.19.1
numpy==1.14.0
certifi==2018.8.24
cffi==1.11.5
mkl-fft==1.0.4
mkl-random==1.0.1
numpy==1.15.1
pycparser==2.18
scipy==1.1.0
torch==0.4.1.post2
20 changes: 14 additions & 6 deletions train.py
@@ -1,15 +1,16 @@
from __future__ import division
from __future__ import print_function

import os
import glob
import time
import random
import argparse
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import glob
from torch.autograd import Variable

from utils import load_data, accuracy
@@ -19,7 +20,7 @@
parser = argparse.ArgumentParser()
parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False, help='Validate during training pass.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--seed', type=int, default=72, help='Random seed.')
parser.add_argument('--epochs', type=int, default=10000, help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.005, help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay (L2 loss on parameters).')
@@ -42,8 +43,15 @@
adj, features, labels, idx_train, idx_val, idx_test = load_data()

# Model and optimizer
model = GAT(nfeat=features.shape[1], nhid=args.hidden, nclass=int(labels.max()) + 1, dropout=args.dropout, nheads=args.nb_heads, alpha=args.alpha)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
model = GAT(nfeat=features.shape[1],
            nhid=args.hidden,
            nclass=int(labels.max()) + 1,
            dropout=args.dropout,
            nheads=args.nb_heads,
            alpha=args.alpha)
optimizer = optim.Adam(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.weight_decay)

if args.cuda:
    model.cuda()
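The training loop itself sits below the shown hunks; for orientation, a hedged sketch of a single step using the model and optimizer constructed above (the helper name and exact loop structure are assumptions, not the repository's code):

import torch.nn.functional as F

def train_step(model, optimizer, features, adj, labels, idx_train):
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)                        # N x nclass log-probabilities
    loss = F.nll_loss(output[idx_train], labels[idx_train])
    loss.backward()
    optimizer.step()
    return loss.item()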
