Commit

Merge pull request Diego999#13 from sh0416/master
[Bug fix] No backpropagation in sparse GAT
Diego999 authored Oct 23, 2018
2 parents 808e0ff + 7b428ef commit 3301fd6
Showing 5 changed files with 103 additions and 9 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -25,6 +25,10 @@ For the branch **master**, the training of the transductive learning on Cora tas

A small note about the initial sparse matrix operations of https://github.com/tkipf/pygcn: they have been removed. Therefore, the current model takes ~7GB of GPU RAM.

# Sparse version GAT

We develop a sparse version of GAT using PyTorch. Because of the softmax function it can be numerically unstable, so careful initialization is required. To use the sparse version of GAT, add the flag `--sparse`. The performance of the sparse version is similar to the TensorFlow implementation; on a Titan Xp it takes 0.08~0.14 sec.
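Below is a hedged sketch (not verbatim from this repository) of how such a flag is typically consumed in a training script. `models.GAT` as the dense counterpart of `models.SpGAT` is an assumption; the constructor values are the example arguments used in visualize_graph.py further down.

```python
# Editorial sketch: selecting the sparse model via a --sparse flag.
# models.SpGAT appears in visualize_graph.py below; models.GAT as its dense
# counterpart is an assumption, not confirmed by this diff.
import argparse

import models

parser = argparse.ArgumentParser()
parser.add_argument('--sparse', action='store_true',
                    help='Use the sparse GAT implementation (SpGAT).')
args = parser.parse_args()

ModelClass = models.SpGAT if args.sparse else models.GAT
model = ModelClass(50, 8, 7, 0.5, 0.01, 3)  # example values from visualize_graph.py
```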

# Requirements

pyGAT relies on Python 3.5 and PyTorch 0.4.1 (due to torch.sparse_coo_tensor).
45 changes: 37 additions & 8 deletions layers.py
@@ -46,6 +46,34 @@ def __repr__(self):
return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class SpecialSpmmFunction(torch.autograd.Function):
"""Special function for only sparse region backpropataion layer."""
@staticmethod
def forward(ctx, indices, values, shape, b):
assert indices.requires_grad == False
a = torch.sparse_coo_tensor(indices, values, shape)
ctx.save_for_backward(a, b)
ctx.N = shape[0]
return torch.matmul(a, b)

@staticmethod
def backward(ctx, grad_output):
a, b = ctx.saved_tensors
grad_values = grad_b = None
if ctx.needs_input_grad[1]:
grad_a_dense = grad_output.matmul(b.t())
edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
grad_values = grad_a_dense.view(-1)[edge_idx]
if ctx.needs_input_grad[3]:
grad_b = a.t().matmul(grad_output)
return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
def forward(self, indices, values, shape, b):
return SpecialSpmmFunction.apply(indices, values, shape, b)
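
A minimal sketch (editorial, not part of the commit) of what the new function buys: gradients now reach the sparse values and the dense operand, while the indices stay non-differentiable. It assumes `layers.py` from this commit is importable and keeps everything on the CPU.

```python
# Editorial sketch: exercise SpecialSpmm on a tiny example and check where
# gradients flow. Assumes layers.py from this commit is on the import path.
import torch
from layers import SpecialSpmm

spmm = SpecialSpmm()

# A 3 x 3 sparse matrix with E = 4 nonzero entries times a dense 3 x 2 matrix.
indices = torch.tensor([[0, 0, 1, 2],
                        [0, 2, 1, 0]])        # 2 x E nonzero coordinates
values = torch.rand(4, requires_grad=True)    # E learnable nonzero values
b = torch.rand(3, 2, requires_grad=True)      # dense right-hand side

out = spmm(indices, values, torch.Size([3, 3]), b)  # 3 x 2 dense result
out.sum().backward()

print(values.grad.shape)  # torch.Size([4])   -- gradient only at the nonzeros
print(b.grad.shape)       # torch.Size([3, 2])
```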


class SpGraphAttentionLayer(nn.Module):
"""
Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
@@ -66,6 +94,7 @@ def __init__(self, in_features, out_features, dropout, alpha, concat=True):

self.dropout = nn.Dropout(dropout)
self.leakyrelu = nn.LeakyReLU(self.alpha)
self.special_spmm = SpecialSpmm()

def forward(self, input, adj):
N = input.size()[0]
@@ -78,20 +107,20 @@ def forward(self, input, adj):
# Self-attention on the nodes - Shared attention mechanism
edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
        # edge_h: 2*D x E

edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze()))
assert not torch.isnan(edge_e).any()
# edge_e: 1 x E
e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
# e: N x N
e_rowsum = torch.matmul(e, torch.ones(size=(N, 1)).cuda())
# edge_e: E

e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N,1)).cuda())
# e_rowsum: N x 1

edge_e = self.dropout(edge_e)
# edge_e: 1 x E
e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
# e: N x N
h_prime = torch.matmul(e, h)
# edge_e: E

h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h)
assert not torch.isnan(h_prime).any()
# h_prime: N x out

h_prime = h_prime.div(e_rowsum)
# h_prime: N x out
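
For context (editorial sketch, not part of the commit): the forward pass above computes a softmax over each node's incident edges without materializing a dense N x N attention matrix: it exponentiates the per-edge scores, aggregates neighbor features with a sparse matmul, and divides by the per-row sum of the exponentiated scores. The toy check below compares this against an explicit masked dense softmax; names and sizes are illustrative, and `torch.sparse.mm` assumes a reasonably recent PyTorch.

```python
# Editorial sketch: the sparse edge-softmax trick agrees with a dense masked
# softmax on a toy graph where every node has at least one outgoing edge.
import torch

N, F = 4, 3
edge = torch.tensor([[0, 0, 1, 2, 3],
                     [1, 2, 2, 3, 0]])    # 2 x E edge list
scores = torch.rand(edge.size(1))          # unnormalized per-edge attention scores
h = torch.rand(N, F)                       # node features

# Sparse path, mirroring forward(): exp(scores) -> sparse aggregation -> row-normalize.
edge_e = torch.exp(scores)
attn = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
e_rowsum = torch.sparse.mm(attn, torch.ones(N, 1))  # N x 1 normalizers
h_prime = torch.sparse.mm(attn, h) / e_rowsum        # N x F aggregated features

# Dense reference: softmax restricted to existing edges via -inf masking.
dense = torch.full((N, N), float('-inf'))
dense[edge[0], edge[1]] = scores
h_ref = torch.softmax(dense, dim=1) @ h

print(torch.allclose(h_prime, h_ref))      # True (up to floating-point error)
```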
Binary file added output/graph_visualize.pdf
1 change: 0 additions & 1 deletion train.py
@@ -111,7 +111,6 @@ def compute_test():
"loss= {:.4f}".format(loss_test.data[0]),
"accuracy= {:.4f}".format(acc_test.data[0]))


# Train model
t_total = time.time()
loss_values = []
62 changes: 62 additions & 0 deletions visualize_graph.py
@@ -0,0 +1,62 @@
from graphviz import Digraph

import torch
import models

def make_dot(var, params):
""" Produces Graphviz representation of PyTorch autograd graph
Blue nodes are the Variables that require grad, orange are Tensors
saved for backward in torch.autograd.Function
Args:
var: output Variable
params: dict of (name, Variable) to add names to node that
require grad (TODO: make optional)
"""
param_map = {id(v): k for k, v in params.items()}
print(param_map)

node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()

def size_to_str(size):
return '('+(', ').join(['%d'% v for v in size])+')'

def add_nodes(var):
if var not in seen:
if torch.is_tensor(var):
dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
elif hasattr(var, 'variable'):
u = var.variable
node_name = '%s\n %s' % (param_map.get(id(u)), size_to_str(u.size()))
dot.node(str(id(var)), node_name, fillcolor='lightblue')
else:
dot.node(str(id(var)), str(type(var).__name__))
seen.add(var)
if hasattr(var, 'next_functions'):
for u in var.next_functions:
if u[0] is not None:
dot.edge(str(id(u[0])), str(id(var)))
add_nodes(u[0])
if hasattr(var, 'saved_tensors'):
for t in var.saved_tensors:
dot.edge(str(id(t)), str(id(var)))
add_nodes(t)
add_nodes(var.grad_fn)
return dot

inputs = torch.randn(100, 50).cuda()
adj = torch.randn(100, 100).cuda()
model = models.SpGAT(50, 8, 7, 0.5, 0.01, 3)
model = model.cuda()
y = model(inputs, adj)

g = make_dot(y, model.state_dict())
g.view()
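
The committed output/graph_visualize.pdf was presumably produced by saving the render rather than (or in addition to) calling `g.view()`; with the standard graphviz API that would look something like the line below, appended to the script above.

```python
# Editorial note: save the Graphviz source plus a PDF render next to it
# (graphviz.Digraph.render; 'pdf' is the default format).
g.render('output/graph_visualize', view=False)
```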
