
Commit

add MultiHeadAttention
acproject committed Nov 27, 2020
1 parent f4d5d60 commit 4995a57
Showing 6 changed files with 371 additions and 0 deletions.
212 changes: 212 additions & 0 deletions modules/act.py
@@ -0,0 +1,212 @@
from .attention import *
from .layers import *
from .functions import *
from .embedding import *
import copy
import numpy as np
import torch as th
import dgl.function as fn
import torch.nn.init as INIT

class UEncoder(nn.Module):
def __init__(self, layer):
super(UEncoder, self).__init__()
self.layer = layer
self.norm = LayerNorm(layer.size)

def pre_func(self, fields='qkv'):
layer = self.layer
def func(nodes):
x = nodes.data['x']
norm_x = layer.sublayer[0].norm(x)
return layer.self_attn.get(norm_x, fields=fields)
return func

def post_func(self):
layer = self.layer
def func(nodes):
x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
o = layer.self_attn.get_o(wv / z)
x = x + layer.sublayer[0].dropout(o)
x = layer.sublayer[1](x, layer.feed_forward)
return {'x': x}
return func

class UDecoder(nn.Module):
def __init__(self, layer):
super(UDecoder, self).__init__()
self.layer = layer
self.norm = LayerNorm(layer.size)

def pre_func(self, fields='qkv', l=0):
layer = self.layer
def func(nodes):
x = nodes.data['x']
            if fields == 'kv':
                norm_x = x  # in enc-dec attention, x has already been normalized
            else:
                norm_x = layer.sublayer[l].norm(x)
return layer.self_attn.get(norm_x, fields)
return func

def post_func(self, l=0):
layer = self.layer
def func(nodes):
x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[l].dropout(o)
if l == 1:
x = layer.sublayer[2](x, layer.feed_forward)
return {'x':x}
return func

class HaltingUnit(nn.Module):
halting_bias_init = 1.0
def __init__(self, dim_model):
super(HaltingUnit, self).__init__()
self.linear = nn.Linear(dim_model, 1)
self.norm = LayerNorm(dim_model)
INIT.constant_(self.linear.bias, self.halting_bias_init)

def forward(self, x):
return th.sigmoid(self.linear(self.norm(x)))

class UTransformer(nn.Module):
"Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
MAX_DEPTH = 8
thres = 0.99
act_loss_weight = 0.01
    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc,
                 time_enc, generator, h, d_k):
super(UTransformer, self).__init__()
self.encoder, self.decoder = encoder, decoder
self.src_embed, self.tgt_embed = src_embed, tgt_embed
self.pos_enc, self.time_enc = pos_enc, time_enc
self.halt_enc = HaltingUnit(h * d_k)
self.halt_dec = HaltingUnit(h * d_k)
        self.generator = generator
self.h, self.d_k = h, d_k
        self.reset_stat()

    def reset_stat(self):
self.stat = [0] * (self.MAX_DEPTH + 1)

def step_forward(self, nodes):
x = nodes.data['x']
step = nodes.data['step']
pos = nodes.data['pos']
return {'x': self.pos_enc.dropout(x + self.pos_enc(pos.view(-1)) +
self.time_enc(step.view(-1))),
'step': step + 1}

def halt_and_accum(self, name, end=False):
"field: 'enc' or 'dec' "
halt = self.halt_enc if name == 'enc' else self.halt_dec
thres = self.thres
def func(nodes):
p = halt(nodes.data['x'])
sum_p = nodes.data['sum_p'] + p
            active = (sum_p < thres) & (1 - end)
_continue = active.float()
r = nodes.data['r'] * (1 - _continue) + (1 - sum_p) * _continue
s = nodes.data['s'] + ((1 - _continue) * r + _continue * p) * nodes.data['x']
return {'p': p, 'sum_p':sum_p, 'r':r, 's':s, 'active': active}
return func

def propagate_attention(self, g, eids):
# Compute attention score
g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
# send weighted values to target nodes
g.send_and_recv(eids,
[fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
[fn.sum('v', 'wv'), fn.sum('score', 'z')])

def update_graph(self, g, eids, pre_pairs, post_pairs):
"Update the node states and edge states of the graph."
# Pre-compute queries and key-value pairs.
for pre_func, nids in pre_pairs:
g.apply_nodes(pre_func, nids)
self.propagate_attention(g, eids)
# Further calculation after attention mechanism
for post_func, nids in post_pairs:
g.apply_nodes(post_func, nids)

def forward(self, graph):
g = graph.g
N, E = graph.n_nodes, graph.n_edges
nids, eids = graph.nids, graph.eids

# embed & pos
g.nodes[nids['enc']].data['x'] = self.src_embed(graph.src[0])
g.nodes[nids['dec']].data['x'] = self.tgt_embed(graph.tgt[0])
g.nodes[nids['enc']].data['pos'] = graph.src[1]
g.nodes[nids['dec']].data['pos'] = graph.tgt[1]

# init step
device = next(self.parameters()).device
g.ndata['s'] = th.zeros(N, self.h * self.d_k, dtype=th.float, device=device)
g.ndata['p'] = th.zeros(N, 1, dtype=th.float, device=device)
g.ndata['r'] = th.ones(N, 1, dtype=th.float, device=device)
g.ndata['sum_p'] = th.zeros(N, 1, dtype=th.float, device=device)
g.ndata['step'] = th.zeros(N, 1, dtype=th.long, device=device)
g.ndata['active'] = th.ones(N, 1, dtype=th.uint8, device=device)

for step in range(self.MAX_DEPTH):
            pre_func = self.encoder.pre_func('qkv')
post_func = self.encoder.post_func()
nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['enc'])
if len(nodes) == 0: break
edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ee'])
end = step == self.MAX_DEPTH - 1
self.update_graph(g, edges,
[(self.step_forward, nodes), (pre_func, nodes)],
[(post_func, nodes), (self.halt_and_accum('enc', end), nodes)])

g.nodes[nids['enc']].data['x'] = self.encoder.norm(g.nodes[nids['enc']].data['s'])

for step in range(self.MAX_DEPTH):
pre_func = self.decoder.pre_func('qkv')
post_func = self.decoder.post_func()
nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['dec'])
if len(nodes) == 0: break
edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['dd'])
            self.update_graph(g, edges,
                              [(self.step_forward, nodes), (pre_func, nodes)],
                              [(post_func, nodes)])
pre_q = self.decoder.pre_func('q', 1)
pre_kv = self.decoder.pre_func('kv', 1)
post_func = self.decoder.post_func(1)
nodes_e = nids['enc']
edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ed'])
            end = step == self.MAX_DEPTH - 1
            self.update_graph(g, edges,
                              [(pre_q, nodes), (pre_kv, nodes_e)],
                              [(post_func, nodes), (self.halt_and_accum('dec', end), nodes)])
g.nodes[nids['dec']].data['x'] = self.decoder.norm(g.nodes[nids['dec']].data['s'])
        act_loss = th.mean(g.ndata['r'])  # ACT loss

self.stat[0] += N
for step in range(1, self.MAX_DEPTH + 1):
self.stat[step] += th.sum(g.ndata['step'] >= step).item()

        return self.generator(g.ndata['x'][nids['dec']]), act_loss * self.act_loss_weight

def infer(self, *args, **kwargs):
raise NotImplementedError

def make_universal_model(src_vocab, tgt_vocab, dim_model=512, dim_ff=2048, h=8, dropout=0.1):
c = copy.deepcopy
attn = MultiHeadAttention(h, dim_model)
ff = PositionWiseFeedForWard(dim_model, dim_ff)
pos_enc = PositionalEncoding(dim_model, dropout)
time_enc = PositionalEncoding(dim_model, dropout)
    encoder = UEncoder(EncoderLayer(dim_model, c(attn), c(ff), dropout))
    decoder = UDecoder(DecoderLayer(dim_model, c(attn), c(attn), c(ff), dropout))
src_embed = Embeddings(src_vocab, dim_model)
tgt_embed = Embeddings(tgt_vocab, dim_model)
generator = Generator(dim_model, tgt_vocab)
model = UTransformer(
encoder, decoder, src_embed, tgt_embed, pos_enc, time_enc, generator, h, dim_model // h)
# xavier init
for p in model.parameters():
if p.dim() > 1:
INIT.xavier_uniform_(p)
return model
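
As a quick sanity check of the halting arithmetic in halt_and_accum above, the following standalone sketch reproduces one update step with plain tensors and no DGL graph; the shapes, threshold, and starting values are illustrative assumptions, not part of the commit.

import torch as th

# One ACT update step, mirroring halt_and_accum (illustrative shapes/values).
N, D, thres = 4, 8, 0.99
x = th.randn(N, D)                 # current node states
p = th.rand(N, 1) * 0.5            # halting probability from the HaltingUnit
sum_p = th.full((N, 1), 0.7)       # probability accumulated in earlier steps
r = th.ones(N, 1)                  # remainder, used at the step a node halts
s = th.zeros(N, D)                 # weighted-state accumulator
end = False                        # True only at step MAX_DEPTH - 1

sum_p = sum_p + p
active = (sum_p < thres) & (not end)
cont = active.float()
r = r * (1 - cont) + (1 - sum_p) * cont
s = s + ((1 - cont) * r + cont * p) * x   # halted nodes contribute r * x, active ones p * x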
35 changes: 35 additions & 0 deletions modules/attention.py
@@ -0,0 +1,35 @@
import torch as th
import torch.nn as nn
import numpy as np
from .layers import clones

class MultiHeadAttention(nn.Module):
'Multi-Head Attention'
def __init__(self, h, dim_model):
'''
:param h: number of heads
:param dim_model: hidden dimension
'''
super(MultiHeadAttention, self).__init__()
self.d_k = dim_model // h
self.h = h
# W_q, W_k, W_v, W_o
self.linears = clones(nn.Linear(dim_model, dim_model, bias=False), 4)

def get(self, x, fields='qkv'):
'Return a dict of queries / keys / values.'
batch_size = x.shape[0]
ret = {}
if 'q' in fields:
ret['q'] = self.linears[0](x).view(batch_size, self.h, self.d_k)
if 'k' in fields:
ret['k'] = self.linears[1](x).view(batch_size, self.h, self.d_k)
if 'v' in fields:
ret['v'] = self.linears[2](x).view(batch_size, self.h, self.d_k)
return ret

def get_o(self, x):
'Get output of the multi-head attention'
batch_size = x.shape[0]
return self.linears[3](x.view(batch_size, -1))
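
A shape-level usage sketch of get/get_o (the import path and sizes are assumptions for illustration): get splits each projection into h heads of width d_k, and get_o concatenates the heads back and applies the output projection.

import torch as th
from modules.attention import MultiHeadAttention  # assumed package layout

h, dim_model = 8, 512
attn = MultiHeadAttention(h, dim_model)

x = th.randn(10, dim_model)          # 10 nodes, one feature vector each
qkv = attn.get(x, fields='qkv')      # q/k/v each have shape (10, h, d_k) = (10, 8, 64)
out = attn.get_o(qkv['v'])           # stand-in for the attention-weighted values wv / z
assert out.shape == (10, dim_model)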
1 change: 1 addition & 0 deletions modules/config.py
@@ -0,0 +1 @@
VIZ_IDX = 3
31 changes: 31 additions & 0 deletions modules/embedding.py
@@ -0,0 +1,31 @@
import torch as th
import torch.nn as nn
import numpy as np

class PositionalEncoding(nn.Module):
'Position Encoding module'
def __init__(self, dim_model, dropout, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
# Compute the positional encodings once in log space.
pe = th.zeros(max_len, dim_model, dtype=th.float)
position = th.arange(0, max_len, dtype=th.float).unsqueeze(1)
div_term = th.exp(th.arange(0, dim_model, 2, dtype=th.float)*
-(np.log(10000.0) / dim_model))
pe[:, 0::2] = th.sin(position * div_term)
pe[:, 1::2] = th.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe) # Not a parameter but should be in state_dict

def forward(self, pos):
        return th.index_select(self.pe, 1, pos).squeeze(0)

class Embeddings(nn.Module):
'Word Embedding module'
def __init__(self, vocab_size, dim_model):
super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab_size, dim_model)
self.dim_model = dim_model

def forward(self, x):
return self.lut(x) * np.sqrt(self.dim_model)
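
A small usage sketch (import path and sizes are illustrative): PositionalEncoding is indexed by explicit position ids rather than by slicing, which is what lets the Universal Transformer reuse the same module as the per-step time encoding.

import torch as th
from modules.embedding import PositionalEncoding, Embeddings  # assumed package layout

emb = Embeddings(vocab_size=1000, dim_model=512)
pos_enc = PositionalEncoding(dim_model=512, dropout=0.1)

tokens = th.tensor([5, 42, 7, 99])               # token ids of one sentence
pos = th.arange(4)                               # their positions
x = pos_enc.dropout(emb(tokens) + pos_enc(pos))  # (4, 512), similar to step_forward minus the time encoding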
24 changes: 24 additions & 0 deletions modules/functions.py
@@ -0,0 +1,24 @@
import torch as th

def src_dot_dst(src_field, dst_field, out_field):
    '''
    Surrogate for a built-in apply_edges function: computes the dot product of a
    source-node field and a destination-node field and stores it on the edge.
    :param src_field: name of the source-node feature to read (e.g. 'k')
    :param dst_field: name of the destination-node feature to read (e.g. 'q')
    :param out_field: name of the edge feature to write (e.g. 'score')
    :return: an edge UDF for use with g.apply_edges
    '''
def func(edges):
return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}
return func

def scaled_exp(field, c):
    '''
    Applies exp(x / c) to an edge field, as required by the *Scaled Dot-Product
    Attention* of the Transformer paper; the input is clamped to [-10, 10] before
    exponentiation for numerical stability.
    :param field: name of the edge feature to transform (e.g. 'score')
    :param c: scaling constant, typically sqrt(d_k)
    :return: an edge UDF for use with g.apply_edges
    '''
def func(edges):
return {field: th.exp((edges.data[field] / c).clamp(-10, 10))}
return func
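
The two edge functions above produce unnormalized attention weights; normalization happens later when act.py divides the accumulated wv by z. A plain-tensor sketch (no DGL; shapes are illustrative) of that identity for a single destination node with five incoming edges:

import torch as th

d_k = 64
q = th.randn(d_k)                 # query of one destination node
K = th.randn(5, d_k)              # keys on its 5 incoming edges
V = th.randn(5, d_k)              # values on its 5 incoming edges

score = (K * q).sum(-1, keepdim=True)             # src_dot_dst
w = th.exp((score / d_k ** 0.5).clamp(-10, 10))   # scaled_exp
wv = (w * V).sum(0)                               # fn.sum('v', 'wv')
z = w.sum(0)                                      # fn.sum('score', 'z')

ref = th.softmax(score.squeeze(-1) / d_k ** 0.5, 0) @ V
assert th.allclose(wv / z, ref, atol=1e-4)        # wv / z is scaled dot-product attention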
68 changes: 68 additions & 0 deletions modules/layers.py
@@ -0,0 +1,68 @@
import copy

import torch as th
import torch.nn as nn
from torch.nn import LayerNorm

class Generator(nn.Module):
'''
    Generate the next token from the representation. This part is separated from the decoder, mostly for the convenience of sharing weights between the embedding and the generator.
log(softmax(Wx + b))
'''
def __init__(self, dim_model, vocab_size):
        super(Generator, self).__init__()
self.proj = nn.Linear(dim_model, vocab_size)


def forward(self, x):
return th.log_softmax(self.proj(x), dim=-1)


class SubLayerWrapper(nn.Module):
'''
    This module wraps layer normalization, dropout, and a residual connection into one equation:
    sublayerwrapper(sublayer)(x) = x + dropout(sublayer(norm(x)))
'''

def __init__(self, size, dropout):
super(SubLayerWrapper, self).__init__()
self.norm = LayerNorm(size)
self.dropout = nn.Dropout(dropout)

def forward(self, x, sublayer):
return x + self.dropout(sublayer(self.norm(x)))

class PositionWiseFeedForWard(nn.Module):
'''
    This module implements the position-wise feed-forward network (applied after multi-head attention):
FFN(x) = max(0, x @ W_1 + b_1) @ W_2 + b_2
'''
def __init__(self, dim_model, dim_ff, dropout=0.1):
super(PositionWiseFeedForWard, self).__init__()
self.w_1 = nn.Linear(dim_model, dim_ff)
self.w_2 = nn.Linear(dim_ff, dim_model)
self.dropout = nn.Dropout(dropout)

def forward(self, x):
return self.w_2(self.dropout(th.relu(self.w_1(x))))

def clones(module, k):
    return nn.ModuleList(
        copy.deepcopy(module) for _ in range(k))

class EncoderLayer(nn.Module):
def __init__(self, size, self_attn, feed_forward, dropout):
super(EncoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn # (key, query, value, mask)
self.feed_forward = feed_forward
self.sublayer = clones(SubLayerWrapper(size, dropout), 2)

class DecoderLayer(nn.Module):
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn # (key, query, value, mask)
self.src_attn = src_attn
self.feed_forward = feed_forward
self.sublayer = clones(SubLayerWrapper(size, dropout), 3)
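
A brief sketch of how SubLayerWrapper composes with a sub-module (the import path and sizes are assumptions): it is the pre-norm residual block x + dropout(sublayer(norm(x))) that EncoderLayer and DecoderLayer clone two and three times respectively.

import torch as th
from modules.layers import SubLayerWrapper, PositionWiseFeedForWard  # assumed package layout

dim_model = 512
block = SubLayerWrapper(dim_model, dropout=0.1)
ff = PositionWiseFeedForWard(dim_model, dim_ff=2048)

x = th.randn(10, dim_model)
y = block(x, ff)                 # y = x + dropout(ff(norm(x)))
assert y.shape == x.shape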
