Commit 953d175 ("update")
acproject committed Nov 30, 2020
1 parent 2117730
Showing 18 changed files with 967 additions and 43 deletions.
257 changes: 244 additions & 13 deletions GNN/Model/MultiHeadAttention.py
@@ -1,21 +1,12 @@
import torch
import torch.nn as nn
import torch as th
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from modules.layers import *
from modules.functions import *
from modules.embedding import *
from modules.viz import att_animation, get_attention_map
from optims import NoamOpt
from loss import LabelSmoothing, SimpleLossCompute
from dataset import get_dataset, GraphPool
from modules.DGL.transformer.layers import *
from modules.DGL.transformer.functions import *
from modules.DGL.transformer.embedding import *
from modules.DGL.transformer.optims import *

import dgl.function as fn
import torch.nn.init as INIT
from torch.nn import LayerNorm


class MultiHeadAttention(nn.Module):
@@ -30,3 +21,243 @@ def __init__(self, h, dim_model):
        self.h = h
        # W_q, W_k, W_v, W_o
        self.linears = clones(nn.Linear(dim_model, dim_model), 4)


    def get(self, x, fields='qkv'):
        'Return a dict of queries / keys / values.'
        batch_size = x.shape[0]
        ret = {}
        if 'q' in fields:
            ret['q'] = self.linears[0](x).view(batch_size, self.h, self.d_k)
        if 'k' in fields:
            ret['k'] = self.linears[1](x).view(batch_size, self.h, self.d_k)
        if 'v' in fields:
            ret['v'] = self.linears[2](x).view(batch_size, self.h, self.d_k)
        return ret

    def get_o(self, x):
        'get output of the multi-head attention'
        batch_size = x.shape[0]
        return self.linears[3](x.view(batch_size, -1))
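As a side note (illustrative only, not part of the diff), a minimal sketch of what get and get_o compute on a toy tensor, assuming the usual relation d_k = dim_model // h from the elided __init__:

import torch

batch_size, h, dim_model = 2, 4, 32
d_k = dim_model // h                        # assumed head size, standard multi-head attention
W_q = torch.nn.Linear(dim_model, dim_model)
W_o = torch.nn.Linear(dim_model, dim_model)

x = torch.randn(batch_size, dim_model)      # one feature vector per graph node
q = W_q(x).view(batch_size, h, d_k)         # split the projection into h heads: (batch, h, d_k)
wv = torch.randn(batch_size, h, d_k)        # stand-in for the attention-weighted values per head
o = W_o(wv.view(batch_size, -1))            # concatenate heads back to dim_model and project
print(q.shape, o.shape)                     # torch.Size([2, 4, 8]) torch.Size([2, 32])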

def message_func(edges):
    return {'score': (edges.src['k'] * edges.dst['q']).sum(-1, keepdim=True),
            'v': edges.src['v']}

import torch as th
import torch.nn.functional as F

def reduce_func(nodes, d_k=64):
    v = nodes.mailbox['v']
    # np.sqrt keeps the scaling factor a plain float (th.sqrt would require a tensor argument)
    att = F.softmax(nodes.mailbox['score'] / np.sqrt(d_k), 1)
    return {'dx': (att * v).sum(1)}

from functools import partial

def naive_propagate_attention(self, g, eids):
    g.send_and_recv(eids, message_func, partial(reduce_func, d_k=self.d_k))
def src_dot_dst(src_field, dst_field, out_field):
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}

    return func

def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: th.exp((edges.data[field] / scale_constant).clamp(-5, 5))}

    return func


def propagate_attention(self, g, eids):
    # Compute attention score
    g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
    g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
    # Update node state
    g.send_and_recv(eids,
                    [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                    [fn.sum('v', 'wv'), fn.sum('score', 'z')])
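For intuition (a torch-only sketch, not taken from the commit): scaled_exp stores exp(score / sqrt(d_k)) on every edge, and the two sums 'wv' and 'z' reconstruct a softmax-weighted average per destination node, since softmax(s)_i = exp(s_i) / sum_j exp(s_j). The equivalence for a single destination node with five incoming edges:

import torch
import torch.nn.functional as F

d_k = 8
scores = torch.randn(5)                                   # raw dot-product scores of 5 incoming edges
v = torch.randn(5, d_k)                                   # values carried by those edges

exp_s = torch.exp((scores / d_k ** 0.5).clamp(-5, 5))     # what scaled_exp writes on the edges
wv = (exp_s.unsqueeze(-1) * v).sum(0)                     # fn.sum('v', 'wv')
z = exp_s.sum(0)                                          # fn.sum('score', 'z')

direct = (F.softmax(scores / d_k ** 0.5, dim=0).unsqueeze(-1) * v).sum(0)
print(torch.allclose(wv / z, direct))                     # True (up to the clamp)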

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields='qkv'):
        layer = self.layers[i]
        def func(nodes):
            x = nodes.data['x']
            norm_x = layer.sublayer[0].norm(x)
            return layer.self_attn.get(norm_x, fields=fields)
        return func

    def post_func(self, i):
        layer = self.layers[i]
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[0].dropout(o)
            x = layer.sublayer[1](x, layer.feed_forward)
            return {'x': x if i < self.N - 1 else self.norm(x)}
        return func
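The pre_func/post_func pair implements the pre-LayerNorm residual update x + dropout(sublayer(norm(x))): pre_func normalizes and projects q/k/v, the graph propagation computes wv/z, and post_func adds the projected output back onto the raw x. A small self-contained sketch of that pattern (illustrative only, with made-up dimensions):

import torch
import torch.nn as nn

dim, p_drop = 32, 0.1
norm, dropout = nn.LayerNorm(dim), nn.Dropout(p_drop)
ffn = nn.Sequential(nn.Linear(dim, 64), nn.ReLU(), nn.Linear(64, dim))

def sublayer(x, fn):
    # pre-LN residual: normalize, apply the sublayer, then add back to the raw input
    return x + dropout(fn(norm(x)))

x = torch.randn(2, dim)
x = sublayer(x, lambda h: h)   # stand-in for the attention output o = W_o(wv / z)
x = sublayer(x, ffn)           # position-wise feed-forward sublayer
print(x.shape)                 # torch.Size([2, 32])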

class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields='qkv', l=0):
        layer = self.layers[i]
        def func(nodes):
            x = nodes.data['x']
            if fields == 'kv':
                norm_x = x  # In enc-dec attention, x has already been normalized.
            else:
                norm_x = layer.sublayer[l].norm(x)
            return layer.self_attn.get(norm_x, fields)
        return func

    def post_func(self, i, l=0):
        layer = self.layers[i]
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[l].dropout(o)
            if l == 1:
                x = layer.sublayer[2](x, layer.feed_forward)
            return {'x': x if i < self.N - 1 else self.norm(x)}
        return func

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k):
        super(Transformer, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc = pos_enc
        self.generator = generator
        self.h, self.d_k = h, d_k

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
        # Send weighted values to target nodes
        g.send_and_recv(eids,
                        [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."

        # Pre-compute queries and key-value pairs.
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # Further calculation after the attention mechanism
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)

    def forward(self, graph):
        g = graph.g
        nids, eids = graph.nids, graph.eids

        # Word embedding and position embedding
        src_embed, src_pos = self.src_embed(graph.src[0]), self.pos_enc(graph.src[1])
        tgt_embed, tgt_pos = self.tgt_embed(graph.tgt[0]), self.pos_enc(graph.tgt[1])
        g.nodes[nids['enc']].data['x'] = self.pos_enc.dropout(src_embed + src_pos)
        g.nodes[nids['dec']].data['x'] = self.pos_enc.dropout(tgt_embed + tgt_pos)

        for i in range(self.encoder.N):
            # Step 1: Encoder self-attention
            pre_func = self.encoder.pre_func(i, 'qkv')
            post_func = self.encoder.post_func(i)
            nodes, edges = nids['enc'], eids['ee']
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])

        for i in range(self.decoder.N):
            # Step 2: Decoder self-attention
            pre_func = self.decoder.pre_func(i, 'qkv')
            post_func = self.decoder.post_func(i)
            nodes, edges = nids['dec'], eids['dd']
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
            # Step 3: Encoder-decoder attention
            pre_q = self.decoder.pre_func(i, 'q', 1)
            pre_kv = self.decoder.pre_func(i, 'kv', 1)
            post_func = self.decoder.post_func(i, 1)
            nodes_e, nodes_d, edges = nids['enc'], nids['dec'], eids['ed']
            self.update_graph(g, edges, [(pre_q, nodes_d), (pre_kv, nodes_e)], [(post_func, nodes_d)])

        return self.generator(g.ndata['x'][nids['dec']])

graph_pool = GraphPool()

data_iter = dataset(graph_pool, mode='train', batch_size=1, devices=devices)
for graph in data_iter:
    print(graph.nids['enc'])  # encoder node ids
    print(graph.nids['dec'])  # decoder node ids
    print(graph.eids['ee'])   # encoder-encoder edge ids
    print(graph.eids['ed'])   # encoder-decoder edge ids
    print(graph.eids['dd'])   # decoder-decoder edge ids
    print(graph.src[0])       # input word index list
    print(graph.src[1])       # input positions
    print(graph.tgt[0])       # output word index list
    print(graph.tgt[1])       # output positions
    break
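For orientation, a plain-Python sketch of how these edge groups are typically wired in the DGL transformer example (an assumption for illustration, not taken from this commit): 'ee' edges fully connect the source positions, 'dd' edges connect each target position to itself and to earlier positions (the causal mask), and 'ed' edges connect every source position to every target position.

# toy connectivity for a source of length 3 and a target of length 2
src_len, tgt_len = 3, 2
enc_nodes = list(range(src_len))                      # encoder node ids 0..2
dec_nodes = list(range(src_len, src_len + tgt_len))   # decoder node ids 3..4

ee = [(u, v) for u in enc_nodes for v in enc_nodes]   # full self-attention among source positions
dd = [(u, v) for i, u in enumerate(dec_nodes)
             for v in dec_nodes[i:]]                  # edge (u, v): v attends to an earlier-or-equal u
ed = [(u, v) for u in enc_nodes for v in dec_nodes]   # every decoder node attends to every encoder node
print(len(ee), len(dd), len(ed))                      # 9 3 6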

from tqdm import tqdm
import torch as th
import numpy as np


from modules import make_model
from optims import NoamOpt
from dgl.contrib.transformer import get_dataset, GraphPool

def run_epoch(data_iter, model, loss_compute, is_train=True):
    for i, g in tqdm(enumerate(data_iter)):
        with th.set_grad_enabled(is_train):
            output = model(g)
            loss = loss_compute(output, g.tgt_y, g.n_tokens)
    # Report the epoch-level statistics accumulated by the loss computer.
    print('average loss: {}'.format(loss_compute.avg_loss))
    print('accuracy: {}'.format(loss_compute.accuracy))

N = 1
batch_size = 128
devices = ['cuda' if th.cuda.is_available() else 'cpu']

dataset = get_dataset("copy")
V = dataset.vocab_size
criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
dim_model = 128

# Create model
model = make_model(V, V, N=N, dim_model=128, dim_ff=128, h=1)

# Share embedding weights between encoder, decoder and generator
model.src_embed.lut.weight = model.tgt_embed.lut.weight
model.generator.proj.weight = model.tgt_embed.lut.weight

model, criterion = model.to(devices[0]), criterion.to(devices[0])
model_opt = NoamOpt(dim_model, 1, 400,
                    th.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9))
loss_compute = SimpleLossCompute

att_maps = []
for epoch in range(4):
    train_iter = dataset(graph_pool, mode='train', batch_size=batch_size, devices=devices)
    valid_iter = dataset(graph_pool, mode='valid', batch_size=batch_size, devices=devices)
    print('Epoch: {} Training...'.format(epoch))
    model.train(True)
    run_epoch(train_iter, model,
              loss_compute(criterion, model_opt), is_train=True)
    print('Epoch: {} Evaluating...'.format(epoch))
    model.att_weight_map = None
    model.eval()
    run_epoch(valid_iter, model,
              loss_compute(criterion, None), is_train=False)
    att_maps.append(model.att_weight_map)
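For reference, NoamOpt implements the warmup-then-decay learning-rate schedule from "Attention Is All You Need". A standalone sketch of that rate, assuming the usual argument order NoamOpt(model_size, factor, warmup, optimizer) so that the values 128, 1 and 400 match the call above:

def noam_rate(step, model_size=128, factor=1, warmup=400):
    # lr = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)
    return factor * (model_size ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

for step in (1, 100, 400, 4000):
    print(step, round(noam_rate(step, model_size=128), 6))   # rises until step 400, then decays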
Empty file added modules/DGL/__init__.py
Empty file.
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file.
62 changes: 62 additions & 0 deletions modules/DGL/transformer/dataset/fields.py
@@ -0,0 +1,62 @@
class Vocab:
    def __init__(self, init_token=None, eos_token=None, pad_token=None, unk_token=None):
        self.init_token = init_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.vocab_lst = []
        self.vocab_dict = None

    def load(self, path):
        if self.init_token is not None:
            self.vocab_lst.append(self.init_token)
        if self.eos_token is not None:
            self.vocab_lst.append(self.eos_token)
        if self.pad_token is not None:
            self.vocab_lst.append(self.pad_token)
        if self.unk_token is not None:
            self.vocab_lst.append(self.unk_token)
        with open(path, 'r', encoding='utf-8') as f:
            for token in f.readlines():
                token = token.strip()
                self.vocab_lst.append(token)

        self.vocab_dict = {
            v: k for k, v in enumerate(self.vocab_lst)
        }

    def __len__(self):
        return len(self.vocab_lst)

    def __getitem__(self, key):
        if isinstance(key, str):
            if key in self.vocab_dict:
                return self.vocab_dict[key]
            else:
                return self.vocab_dict[self.unk_token]
        else:
            return self.vocab_lst[key]

class Field:
    def __init__(self, vocab, preprocessing=None, postprocessing=None):
        self.vocab = vocab
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing

    def preprocess(self, x):
        if self.preprocessing is not None:
            return self.preprocessing(x)
        return x

    def postprocess(self, x):
        if self.postprocessing is not None:
            return self.postprocessing(x)
        return x

    def numericalize(self, x):
        return [self.vocab[token] for token in x]

    def __call__(self, x):
        return self.postprocess(
            self.numericalize(self.preprocess(x)))
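A small usage sketch for the new Vocab/Field classes (the vocabulary file path and the token list are made up for illustration):

vocab = Vocab(init_token='<s>', eos_token='</s>', pad_token='<pad>', unk_token='<unk>')
vocab.load('vocab.txt')                      # hypothetical path: one token per line

field = Field(vocab, preprocessing=lambda toks: ['<s>'] + toks + ['</s>'])
ids = field(['hello', 'world'])              # preprocess -> numericalize -> postprocess
print(len(vocab), ids)                       # tokens missing from the file map to the <unk> index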
