forked from dmlc/dgl
Commit ef18dab: Initial example file for ginipath
Ivan Brugere committed Jul 10, 2018
1 parent 9219349
Showing 2 changed files with 254 additions and 0 deletions.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 9 13:34:38 2018
@author: ivabruge
"""

"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
"""
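# Background sketch (added for reference, not part of the original commit): per
# the GAT paper above, each attention head k computes unnormalized scores
#     e_ij = LeakyReLU(a_k^T [W_k h_i || W_k h_j])
# normalizes them over the neighborhood, alpha_ij = softmax_j(e_ij), and
# aggregates h_i' = sum_j alpha_ij * W_k h_j. In this file the LeakyReLU/softmax
# pair is supplied through the `act` argument of NodeReduceModule below.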

import networkx as nx
from dgl.graph import DGLGraph
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from dataset import load_data, preprocess_features
import numpy as np

class NodeReduceModule(nn.Module):
    def __init__(self, input_dim, num_hidden, num_heads=3, input_dropout=None,
                 attention_dropout=None, act=lambda x: F.softmax(F.leaky_relu(x), dim=0)):
        super(NodeReduceModule, self).__init__()
        self.num_heads = num_heads
        self.input_dropout = input_dropout
        self.attention_dropout = attention_dropout
        self.act = act
        self.fc = nn.ModuleList(
            [nn.Linear(input_dim, num_hidden, bias=False)
             for _ in range(num_heads)])
        self.attention = nn.ModuleList(
            [nn.Linear(num_hidden * 2, 1, bias=False) for _ in range(num_heads)])

    def forward(self, msgs):
        src, dst = zip(*msgs)
        hu = torch.cat(src, dim=0)  # neighbor repr
        hv = torch.cat(dst, dim=0)

        msgs_repr = []

        # iterate for each head
        for i in range(self.num_heads):
            # calc W*hself and W*hneigh
            hvv = self.fc[i](hv)
            huu = self.fc[i](hu)
            # calculate W*hself||W*hneigh
            h = torch.cat((hvv, huu), dim=1)
            a = self.act(self.attention[i](h))
            if self.attention_dropout is not None:
                a = F.dropout(a, self.attention_dropout)
            if self.input_dropout is not None:
                hvv = F.dropout(hvv, self.input_dropout)
            h = torch.sum(a * hvv, 0, keepdim=True)
            msgs_repr.append(h)

        return msgs_repr

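# Illustrative sketch (added here, not in the original commit): a quick shape
# check for NodeReduceModule with made-up sizes, assuming each message is a
# (src['h'], dst['h']) pair of (1, input_dim) rows as produced in GiniPath below.
def _demo_node_reduce():
    reduce_fn = NodeReduceModule(input_dim=5, num_hidden=4, num_heads=2)
    msgs = [(torch.randn(1, 5), torch.randn(1, 5)) for _ in range(3)]
    out = reduce_fn(msgs)
    assert len(out) == 2               # one tensor per attention head
    assert out[0].shape == (1, 4)      # each head yields a (1, num_hidden) row
    return out
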
class NodeUpdateModule(nn.Module):
    def __init__(self, residual, fc, act, aggregator):
        super(NodeUpdateModule, self).__init__()

        self.residual = residual
        self.fc = fc
        self.act = act
        self.aggregator = aggregator

    def forward(self, node, msgs_repr):
        # apply residual connection and activation for each head
        for i in range(len(msgs_repr)):
            if self.residual:
                h = self.fc[i](node['h'])
                msgs_repr[i] = msgs_repr[i] + h
            if self.act is not None:
                msgs_repr[i] = self.act(msgs_repr[i])

        # aggregate multi-head results
        h = self.aggregator(msgs_repr)
        # initialize LSTM states with zeros if the node has no stored state yet
        if node['c'] is None:
            c0 = torch.zeros(h.shape)
        else:
            c0 = node['c']
        if node['h_i'] is None:
            h0 = torch.zeros(h.shape)
        else:
            h0 = node['h_i']
        # NOTE: a fresh LSTM is built on every forward call, so its weights are
        # re-initialized each time and are not part of the trained parameters
        lstm = nn.LSTM(input_size=h.shape[1], hidden_size=h.shape[1], num_layers=1)

        # add dimension to handle sequential (create sequence of length 1)
        h, (h_i, c) = lstm(h.unsqueeze(0), (h0.unsqueeze(0), c0.unsqueeze(0)))

        # remove sequential dim
        h = torch.squeeze(h, 0)
        h_i = torch.squeeze(h_i, 0)
        c = torch.squeeze(c, 0)

        return {'h': h, 'c': c, 'h_i': h_i}

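# Illustrative sketch (added, not in the original commit): exercising
# NodeUpdateModule on a dummy node state with two 4-dimensional head outputs;
# the heads are concatenated and passed through the one-step LSTM above.
def _demo_node_update():
    reduce_fn = NodeReduceModule(input_dim=5, num_hidden=4, num_heads=2)
    update_fn = NodeUpdateModule(residual=False, fc=reduce_fn.fc, act=F.elu,
                                 aggregator=lambda x: torch.cat(x, 1))
    node = {'h': torch.randn(1, 5), 'c': None, 'h_i': None}
    state = update_fn(node, [torch.randn(1, 4) for _ in range(2)])
    assert state['h'].shape == (1, 8)  # num_heads * num_hidden after concat
    return state
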
class GiniPath(nn.Module):
    def __init__(self, num_layers, in_dim, num_hidden, num_classes, num_heads,
                 activation, input_dropout, attention_dropout, use_residual=False):
        super(GiniPath, self).__init__()

        self.input_dropout = input_dropout
        self.reduce_layers = nn.ModuleList()
        self.update_layers = nn.ModuleList()
        # hidden layers
        for i in range(num_layers):
            if i == 0:
                last_dim = in_dim
                residual = False
            else:
                last_dim = num_hidden * num_heads  # because of concat heads
                residual = use_residual
            self.reduce_layers.append(
                NodeReduceModule(last_dim, num_hidden, num_heads, input_dropout,
                                 attention_dropout))
            self.update_layers.append(
                NodeUpdateModule(residual, self.reduce_layers[-1].fc, activation,
                                 lambda x: torch.cat(x, 1)))
        # projection
        self.reduce_layers.append(
            NodeReduceModule(num_hidden * num_heads, num_classes, 1, input_dropout,
                             attention_dropout))
        self.update_layers.append(
            NodeUpdateModule(False, self.reduce_layers[-1].fc, None, sum))

    def forward(self, g):
        g.register_message_func(lambda src, dst, edge: (src['h'], dst['h']))
        for reduce_func, update_func in zip(self.reduce_layers, self.update_layers):
            # apply dropout
            if self.input_dropout is not None:
                # TODO (lingfan): use batched dropout once we have better api
                # for global manipulation
                for n in g.nodes():
                    g.node[n]['h'] = F.dropout(g.node[n]['h'], p=self.input_dropout)
                    g.node[n]['c'] = None
                    g.node[n]['h_i'] = None
            g.register_reduce_func(reduce_func)
            g.register_update_func(update_func)
            g.update_all()
        logits = [g.node[n]['h'] for n in g.nodes()]
        logits = torch.cat(logits, dim=0)
        return logits
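
    # Note on the layer loop above (comment added for clarity): this follows the
    # register/update_all pattern of the 2018-era DGL graph API used here. Each
    # layer sends (src['h'], dst['h']) messages, NodeReduceModule collapses a
    # node's inbox into one vector per head, and NodeUpdateModule concatenates
    # the heads and refreshes the node state ('h', 'h_i', 'c') via its LSTM step.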

    # NOTE: this overrides nn.Module.train(); here it runs the example's training loop
    def train(self, g, features, labels, epochs, loss_f=torch.nn.NLLLoss, loss_params={},
              optimizer=torch.optim.Adam, optimizer_parameters=None, lr=0.001,
              ignore=[0], quiet=False):

        # convert one-hot label rows to a class-index tensor
        labels = torch.LongTensor(labels)
        print(labels)
        _, labels = torch.max(labels, dim=1)

        if optimizer_parameters is None:
            optimizer_parameters = self.parameters()
        optimizer_f = optimizer(optimizer_parameters, lr)

        for epoch in range(epochs):
            # reset grad
            optimizer_f.zero_grad()

            # reset graph states
            for n in g.nodes():
                g.node[n]['h'] = torch.FloatTensor(features[n].toarray())

            # forward
            logits = self.forward(g)

            loss = loss_f(**loss_params)
            # mask out nodes whose label class is listed in `ignore`
            idx = [i for i, a in enumerate(labels) if a not in ignore]
            out = loss(logits[idx, :], labels[idx])

            if not quiet:
                print("epoch {} loss: {}".format(epoch, out))

            out.backward()
            optimizer_f.step()

def main(args):
    # dropout parameters
    input_dropout = 0.2
    attention_dropout = 0.2

    # load and preprocess dataset
    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(args.dataset)
    features = preprocess_features(features)

    # initialize graph
    g = DGLGraph(adj)

    # create model
    model = GiniPath(args.num_layers,
                     features.shape[1],
                     args.num_hidden,
                     y_train.shape[1],
                     args.num_heads,
                     F.elu,
                     input_dropout,
                     attention_dropout,
                     args.residual)
    # pass the command-line epoch count and learning rate through to train()
    model.train(g, features, y_train, epochs=args.epochs, lr=args.lr)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GAT')
    parser.add_argument("--dataset", type=str, required=True,
                        help="dataset name")
    parser.add_argument("--epochs", type=int, default=10,
                        help="number of training epochs")
    parser.add_argument("--num-heads", type=int, default=3,
                        help="number of attentional heads to use")
    parser.add_argument("--num-layers", type=int, default=1,
                        help="number of hidden layers")
    parser.add_argument("--num-hidden", type=int, default=8,
                        help="size of hidden units")
    parser.add_argument("--residual", action="store_true",
                        help="use residual connection")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    args = parser.parse_args()
    print(args)

    main(args)
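
# Example invocation (a sketch; the dataset name "cora" is an assumption about
# what the local `dataset.load_data` helper accepts, and the script file name is
# illustrative -- use whatever this file is saved as):
#
#     python gat_ginipath.py --dataset cora --epochs 10 --num-heads 3 --residual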