
[Bug fix] No backpropagation in sparse GAT #13

Merged: 2 commits, Oct 23, 2018
4 changes: 4 additions & 0 deletions README.md
@@ -25,6 +25,10 @@ For the branch **master**, the training of the transductive learning on Cora tas

A small note about the initial sparse matrix operations from https://github.com/tkipf/pygcn: they have been removed. Therefore, the current model takes ~7GB of GPU RAM.

# Sparse version GAT

We develop a sparse version of GAT in PyTorch. It can be numerically unstable because of the softmax function, so initialize the weights carefully. To use the sparse version, add the flag `--sparse`. Its performance is similar to the TensorFlow implementation, and it takes 0.08~0.14 sec on a Titan Xp.
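For intuition, here is a minimal sketch (not taken from the repository) of the edge-wise softmax the sparse layer computes; `edge_index` and `scores` are made-up toy inputs. Because the per-edge weights go through `exp`, a large attention score overflows, which is the numerical instability mentioned above.

```python
import torch
import torch.nn.functional as F

# Toy graph: 4 nodes, 5 directed edges (made-up example data).
N = 4
edge_index = torch.tensor([[0, 0, 1, 2, 3],
                           [1, 2, 2, 3, 0]])
scores = torch.randn(5)  # raw attention logits, one per edge

edge_e = torch.exp(-F.leaky_relu(scores))                 # per-edge numerators
e = torch.sparse_coo_tensor(edge_index, edge_e, (N, N))   # sparse N x N attention matrix
rowsum = torch.sparse.mm(e, torch.ones(N, 1))             # per-row denominators
# Normalized attention for edge (i, j) is edge_e / rowsum[i]; if any score is very
# large, exp overflows to inf, hence the advice to initialize carefully.
```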

# Requirements

pyGAT relies on Python 3.5 and PyTorch 0.4.1 (due to torch.sparse_coo_tensor).
45 changes: 37 additions & 8 deletions layers.py
@@ -46,6 +46,34 @@ def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class SpecialSpmmFunction(torch.autograd.Function):
    """Special function for sparse-region-only backpropagation (sparse @ dense matmul)."""
    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert indices.requires_grad == False
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            # Gradient w.r.t. the sparse values: form the dense gradient and
            # gather only the entries at the stored (row, col) positions.
            grad_a_dense = grad_output.matmul(b.t())
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
    def forward(self, indices, values, shape, b):
        return SpecialSpmmFunction.apply(indices, values, shape, b)


class SpGraphAttentionLayer(nn.Module):
"""
Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
@@ -66,6 +94,7 @@ def __init__(self, in_features, out_features, dropout, alpha, concat=True):

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.special_spmm = SpecialSpmm()

    def forward(self, input, adj):
        N = input.size()[0]
@@ -78,20 +107,20 @@ def forward(self, input, adj):
        # Self-attention on the nodes - Shared attention mechanism
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
        # edge: 2*D x E

        edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze()))
        assert not torch.isnan(edge_e).any()
-       # edge_e: 1 x E
-       e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
-       # e: N x N
-       e_rowsum = torch.matmul(e, torch.ones(size=(N, 1)).cuda())
+       # edge_e: E
+
+       e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N,1)).cuda())
        # e_rowsum: N x 1

        edge_e = self.dropout(edge_e)
-       # edge_e: 1 x E
-       e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
-       # e: N x N
-       h_prime = torch.matmul(e, h)
+       # edge_e: E
+
+       h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h)
        assert not torch.isnan(h_prime).any()
        # h_prime: N x out

        h_prime = h_prime.div(e_rowsum)
        # h_prime: N x out
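As a quick sanity check of what the new `SpecialSpmm` module buys you, the sketch below (assuming the `layers.py` from this PR is importable) shows gradients reaching the per-edge values, which, going by the PR title, is exactly the backpropagation path that was missing before:

```python
import torch
from layers import SpecialSpmm  # the module added in this PR

spmm = SpecialSpmm()
indices = torch.tensor([[0, 1, 1],
                        [1, 0, 2]])                # positions of 3 nonzeros in a 3x3 matrix
values = torch.rand(3, requires_grad=True)         # per-edge values (e.g. attention weights)
dense = torch.rand(3, 2, requires_grad=True)

out = spmm(indices, values, torch.Size([3, 3]), dense)  # sparse @ dense -> 3 x 2
out.sum().backward()

print(values.grad)  # non-None: gradients flow back into the edge values
print(dense.grad)
```

Only the stored entries receive a gradient, so the backward pass stays proportional to the number of edges rather than N x N.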
Binary file added output/graph_visualize.pdf
1 change: 0 additions & 1 deletion train.py
@@ -111,7 +111,6 @@ def compute_test():
"loss= {:.4f}".format(loss_test.data[0]),
"accuracy= {:.4f}".format(acc_test.data[0]))


# Train model
t_total = time.time()
loss_values = []
62 changes: 62 additions & 0 deletions visualize_graph.py
@@ -0,0 +1,62 @@
from graphviz import Digraph

import torch
import models

def make_dot(var, params):
""" Produces Graphviz representation of PyTorch autograd graph

Blue nodes are the Variables that require grad, orange are Tensors
saved for backward in torch.autograd.Function

Args:
var: output Variable
params: dict of (name, Variable) to add names to node that
require grad (TODO: make optional)
"""
param_map = {id(v): k for k, v in params.items()}
print(param_map)

node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()

def size_to_str(size):
return '('+(', ').join(['%d'% v for v in size])+')'

def add_nodes(var):
if var not in seen:
if torch.is_tensor(var):
dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
elif hasattr(var, 'variable'):
u = var.variable
node_name = '%s\n %s' % (param_map.get(id(u)), size_to_str(u.size()))
dot.node(str(id(var)), node_name, fillcolor='lightblue')
else:
dot.node(str(id(var)), str(type(var).__name__))
seen.add(var)
if hasattr(var, 'next_functions'):
for u in var.next_functions:
if u[0] is not None:
dot.edge(str(id(u[0])), str(id(var)))
add_nodes(u[0])
if hasattr(var, 'saved_tensors'):
for t in var.saved_tensors:
dot.edge(str(id(t)), str(id(var)))
add_nodes(t)
add_nodes(var.grad_fn)
return dot

# Trace one forward pass of the sparse GAT on random data so the autograd graph can be drawn.
inputs = torch.randn(100, 50).cuda()
adj = torch.randn(100, 100).cuda()
model = models.SpGAT(50, 8, 7, 0.5, 0.01, 3)
model = model.cuda()
y = model(inputs, adj)

g = make_dot(y, model.state_dict())
g.view()
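If you prefer writing the figure to disk rather than opening a viewer, graphviz can render it directly; the path below is only a guess at how `output/graph_visualize.pdf` in this pull request was produced.

```python
# Hypothetical alternative to g.view(): write the graph to output/graph_visualize.pdf.
g.render('output/graph_visualize', cleanup=True)  # graphviz renders to PDF by default
```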