Showing 6 changed files with 371 additions and 0 deletions.
@@ -0,0 +1,212 @@
from .attention import *
from .layers import *
from .functions import *
from .embedding import *
import copy
import numpy as np
import torch as th
import torch.nn as nn
import dgl.function as fn
import torch.nn.init as INIT

class UEncoder(nn.Module):
    def __init__(self, layer):
        super(UEncoder, self).__init__()
        self.layer = layer
        self.norm = LayerNorm(layer.size)

    def pre_func(self, fields='qkv'):
        layer = self.layer
        def func(nodes):
            x = nodes.data['x']
            norm_x = layer.sublayer[0].norm(x)
            return layer.self_attn.get(norm_x, fields=fields)
        return func

    def post_func(self):
        layer = self.layer
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[0].dropout(o)
            x = layer.sublayer[1](x, layer.feed_forward)
            return {'x': x}
        return func

class UDecoder(nn.Module):
    def __init__(self, layer):
        super(UDecoder, self).__init__()
        self.layer = layer
        self.norm = LayerNorm(layer.size)

    def pre_func(self, fields='qkv', l=0):
        layer = self.layer
        def func(nodes):
            x = nodes.data['x']
            if fields == 'kv':
                # In encoder-decoder attention, the encoder states are already normalized.
                norm_x = x
            else:
                norm_x = layer.sublayer[l].norm(x)
            return layer.self_attn.get(norm_x, fields)
        return func

    def post_func(self, l=0):
        layer = self.layer
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[l].dropout(o)
            if l == 1:
                x = layer.sublayer[2](x, layer.feed_forward)
            return {'x': x}
        return func

class HaltingUnit(nn.Module):
    halting_bias_init = 1.0

    def __init__(self, dim_model):
        super(HaltingUnit, self).__init__()
        self.linear = nn.Linear(dim_model, 1)
        self.norm = LayerNorm(dim_model)
        INIT.constant_(self.linear.bias, self.halting_bias_init)

    def forward(self, x):
        return th.sigmoid(self.linear(self.norm(x)))

class UTransformer(nn.Module):
    "Universal Transformer (https://arxiv.org/pdf/1807.03819.pdf) with ACT (https://arxiv.org/pdf/1603.08983.pdf)."
    MAX_DEPTH = 8
    thres = 0.99
    act_loss_weight = 0.01

    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc,
                 time_enc, generator, h, d_k):
        super(UTransformer, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc, self.time_enc = pos_enc, time_enc
        self.halt_enc = HaltingUnit(h * d_k)
        self.halt_dec = HaltingUnit(h * d_k)
        self.generator = generator
        self.h, self.d_k = h, d_k
        self.reset_stat()

    def reset_stat(self):
        self.stat = [0] * (self.MAX_DEPTH + 1)

    def step_forward(self, nodes):
        x = nodes.data['x']
        step = nodes.data['step']
        pos = nodes.data['pos']
        return {'x': self.pos_enc.dropout(x + self.pos_enc(pos.view(-1)) +
                                          self.time_enc(step.view(-1))),
                'step': step + 1}

    def halt_and_accum(self, name, end=False):
        "name: 'enc' or 'dec'"
        halt = self.halt_enc if name == 'enc' else self.halt_dec
        thres = self.thres
        def func(nodes):
            p = halt(nodes.data['x'])
            sum_p = nodes.data['sum_p'] + p
            active = (sum_p < thres) & (1 - end)
            _continue = active.float()
            r = nodes.data['r'] * (1 - _continue) + (1 - sum_p) * _continue
            s = nodes.data['s'] + ((1 - _continue) * r + _continue * p) * nodes.data['x']
            return {'p': p, 'sum_p': sum_p, 'r': r, 's': s, 'active': active}
        return func

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
        # Send weighted values to target nodes
        g.send_and_recv(eids,
                        [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."
        # Pre-compute queries and key-value pairs.
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # Further calculation after the attention mechanism
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)

    def forward(self, graph):
        g = graph.g
        N, E = graph.n_nodes, graph.n_edges
        nids, eids = graph.nids, graph.eids

        # embed & pos
        g.nodes[nids['enc']].data['x'] = self.src_embed(graph.src[0])
        g.nodes[nids['dec']].data['x'] = self.tgt_embed(graph.tgt[0])
        g.nodes[nids['enc']].data['pos'] = graph.src[1]
        g.nodes[nids['dec']].data['pos'] = graph.tgt[1]

        # init step
        device = next(self.parameters()).device
        g.ndata['s'] = th.zeros(N, self.h * self.d_k, dtype=th.float, device=device)  # accumulated state
        g.ndata['p'] = th.zeros(N, 1, dtype=th.float, device=device)                  # halting probability
        g.ndata['r'] = th.ones(N, 1, dtype=th.float, device=device)                   # remainder
        g.ndata['sum_p'] = th.zeros(N, 1, dtype=th.float, device=device)              # accumulated halting probability
        g.ndata['step'] = th.zeros(N, 1, dtype=th.long, device=device)                # step counter
        g.ndata['active'] = th.ones(N, 1, dtype=th.uint8, device=device)              # still updating?

        for step in range(self.MAX_DEPTH):
            pre_func = self.encoder.pre_func('qkv')
            post_func = self.encoder.post_func()
            nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['enc'])
            if len(nodes) == 0: break
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ee'])
            end = step == self.MAX_DEPTH - 1
            self.update_graph(g, edges,
                              [(self.step_forward, nodes), (pre_func, nodes)],
                              [(post_func, nodes), (self.halt_and_accum('enc', end), nodes)])

        g.nodes[nids['enc']].data['x'] = self.encoder.norm(g.nodes[nids['enc']].data['s'])

        for step in range(self.MAX_DEPTH):
            pre_func = self.decoder.pre_func('qkv')
            post_func = self.decoder.post_func()
            nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['dec'])
            if len(nodes) == 0: break
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['dd'])
            self.update_graph(g, edges,
                              [(self.step_forward, nodes), (pre_func, nodes)],
                              [(post_func, nodes)])
            pre_q = self.decoder.pre_func('q', 1)
            pre_kv = self.decoder.pre_func('kv', 1)
            post_func = self.decoder.post_func(1)
            nodes_e = nids['enc']
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ed'])
            end = step == self.MAX_DEPTH - 1
            self.update_graph(g, edges,
                              [(pre_q, nodes), (pre_kv, nodes_e)],
                              [(post_func, nodes), (self.halt_and_accum('dec', end), nodes)])

        g.nodes[nids['dec']].data['x'] = self.decoder.norm(g.nodes[nids['dec']].data['s'])
        act_loss = th.mean(g.ndata['r'])  # ACT loss

        self.stat[0] += N
        for step in range(1, self.MAX_DEPTH + 1):
            self.stat[step] += th.sum(g.ndata['step'] >= step).item()

        return self.generator(g.ndata['x'][nids['dec']]), act_loss * self.act_loss_weight

    def infer(self, *args, **kwargs):
        raise NotImplementedError

def make_universal_model(src_vocab, tgt_vocab, dim_model=512, dim_ff=2048, h=8, dropout=0.1):
    c = copy.deepcopy
    attn = MultiHeadAttention(h, dim_model)
    ff = PositionWiseFeedForWard(dim_model, dim_ff)
    pos_enc = PositionalEncoding(dim_model, dropout)
    time_enc = PositionalEncoding(dim_model, dropout)
    encoder = UEncoder(EncoderLayer(dim_model, c(attn), c(ff), dropout))
    decoder = UDecoder(DecoderLayer(dim_model, c(attn), c(attn), c(ff), dropout))
    src_embed = Embeddings(src_vocab, dim_model)
    tgt_embed = Embeddings(tgt_vocab, dim_model)
    generator = Generator(dim_model, tgt_vocab)
    model = UTransformer(
        encoder, decoder, src_embed, tgt_embed, pos_enc, time_enc, generator, h, dim_model // h)
    # Xavier init
    for p in model.parameters():
        if p.dim() > 1:
            INIT.xavier_uniform_(p)
    return model
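
For orientation, a minimal usage sketch follows. It is illustrative only, not part of the commit; the batched-graph object passed to forward is assumed to come from the dataset/graph-builder utilities elsewhere in this repository.

# Hypothetical usage sketch -- vocabulary sizes are placeholders.
model = make_universal_model(src_vocab=10000, tgt_vocab=10000,
                             dim_model=512, dim_ff=2048, h=8, dropout=0.1)
# model(graph) expects a batch object exposing .g, .n_nodes, .n_edges, .nids,
# .eids, .src and .tgt (as used in forward above); it returns per-token
# log-probabilities from the generator plus the weighted ACT loss term.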
@@ -0,0 +1,35 @@
import torch as th
import torch.nn as nn
import numpy as np
from .layers import clones

class MultiHeadAttention(nn.Module):
    'Multi-Head Attention'
    def __init__(self, h, dim_model):
        '''
        :param h: number of heads
        :param dim_model: hidden dimension
        '''
        super(MultiHeadAttention, self).__init__()
        self.d_k = dim_model // h
        self.h = h
        # W_q, W_k, W_v, W_o
        self.linears = clones(nn.Linear(dim_model, dim_model, bias=False), 4)

    def get(self, x, fields='qkv'):
        'Return a dict of queries / keys / values.'
        batch_size = x.shape[0]
        ret = {}
        if 'q' in fields:
            ret['q'] = self.linears[0](x).view(batch_size, self.h, self.d_k)
        if 'k' in fields:
            ret['k'] = self.linears[1](x).view(batch_size, self.h, self.d_k)
        if 'v' in fields:
            ret['v'] = self.linears[2](x).view(batch_size, self.h, self.d_k)
        return ret

    def get_o(self, x):
        'Get output of the multi-head attention.'
        batch_size = x.shape[0]
        return self.linears[3](x.view(batch_size, -1))
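
A quick shape sketch (illustrative only, not part of the commit), assuming h=8 and dim_model=512 so that d_k=64:

attn = MultiHeadAttention(8, 512)
x = th.randn(10, 512)                  # states of 10 graph nodes
qkv = attn.get(x, fields='qkv')        # each of q / k / v has shape (10, 8, 64)
out = attn.get_o(th.randn(10, 8, 64))  # per-head values projected back to (10, 512)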
@@ -0,0 +1 @@
VIZ_IDX = 3
@@ -0,0 +1,31 @@
import torch as th
import torch.nn as nn
import numpy as np

class PositionalEncoding(nn.Module):
    'Position Encoding module'
    def __init__(self, dim_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Compute the positional encodings once in log space.
        pe = th.zeros(max_len, dim_model, dtype=th.float)
        position = th.arange(0, max_len, dtype=th.float).unsqueeze(1)
        div_term = th.exp(th.arange(0, dim_model, 2, dtype=th.float) *
                          -(np.log(10000.0) / dim_model))
        pe[:, 0::2] = th.sin(position * div_term)
        pe[:, 1::2] = th.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)  # Not a parameter, but should be in state_dict

    def forward(self, pos):
        return th.index_select(self.pe, 1, pos).squeeze(0)

class Embeddings(nn.Module):
    'Word Embedding module'
    def __init__(self, vocab_size, dim_model):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab_size, dim_model)
        self.dim_model = dim_model

    def forward(self, x):
        return self.lut(x) * np.sqrt(self.dim_model)
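
A small shape sketch (illustrative, not part of the commit) of how these modules are combined in UTransformer.step_forward; the vocabulary size is a placeholder:

pos_enc = PositionalEncoding(dim_model=512, dropout=0.1)
pe = pos_enc(th.arange(10))              # (10, 512) sinusoidal encodings for positions 0..9
emb = Embeddings(vocab_size=1000, dim_model=512)
tok = emb(th.randint(0, 1000, (10,)))    # (10, 512), scaled by sqrt(dim_model)
x = pos_enc.dropout(tok + pe)            # initial node states for 10 nodes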
@@ -0,0 +1,24 @@
import torch as th

def src_dot_dst(src_field, dst_field, out_field):
    '''
    Edge UDF: compute the dot product between a source-node field and a
    destination-node field and store the result as an edge field.
    :param src_field: name of the source-node feature (e.g. 'k')
    :param dst_field: name of the destination-node feature (e.g. 'q')
    :param out_field: name of the edge feature to write (e.g. 'score')
    :return: a function usable with apply_edges
    '''
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}
    return func

def scaled_exp(field, c):
    '''
    Edge UDF: apply $exp(x / c)$ to an edge field, as required by the
    *Scaled Dot-Product Attention* of the Transformer paper.
    :param field: name of the edge feature to transform in place
    :param c: scaling constant (sqrt(d_k))
    :return: a function usable with apply_edges
    '''
    def func(edges):
        return {field: th.exp((edges.data[field] / c).clamp(-10, 10))}
    return func
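
For intuition, these two edge functions together with the sum reductions in propagate_attention reproduce ordinary scaled dot-product attention. A dense, DGL-free sketch (illustrative only, not part of the commit):

def dense_scaled_dot_attention(q, k, v, c):
    # Score on every (dst, src) pair: src_dot_dst + scaling + clamp.
    score = (q @ k.transpose(-2, -1) / c).clamp(-10, 10)
    weight = th.exp(score)                               # scaled_exp
    # fn.sum('v', 'wv') / fn.sum('score', 'z') == softmax-weighted average.
    return (weight @ v) / weight.sum(-1, keepdim=True)

q, k, v = th.randn(4, 64), th.randn(6, 64), th.randn(6, 64)
out = dense_scaled_dot_attention(q, k, v, c=64 ** 0.5)   # (4, 64)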
@@ -0,0 +1,68 @@
import copy

import torch as th
import torch.nn as nn
from torch.nn import LayerNorm

class Generator(nn.Module):
    '''
    Generate the next token from the representation. This part is separated from the
    decoder, mostly for the convenience of sharing weights between the embedding and
    the generator:
        log(softmax(Wx + b))
    '''
    def __init__(self, dim_model, vocab_size):
        super(Generator, self).__init__()
        self.proj = nn.Linear(dim_model, vocab_size)

    def forward(self, x):
        return th.log_softmax(self.proj(x), dim=-1)

class SubLayerWrapper(nn.Module):
    '''
    Wraps normalization, dropout and the residual connection into one equation:
        sublayerwrapper(sublayer)(x) = x + dropout(sublayer(norm(x)))
    '''
    def __init__(self, size, dropout):
        super(SubLayerWrapper, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

class PositionWiseFeedForWard(nn.Module):
    '''
    Position-wise feed-forward network (applied after multi-head attention):
        FFN(x) = max(0, x @ W_1 + b_1) @ W_2 + b_2
    '''
    def __init__(self, dim_model, dim_ff, dropout=0.1):
        super(PositionWiseFeedForWard, self).__init__()
        self.w_1 = nn.Linear(dim_model, dim_ff)
        self.w_2 = nn.Linear(dim_ff, dim_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(th.relu(self.w_1(x))))

def clones(module, k):
    'Return a ModuleList of k deep copies of module.'
    return nn.ModuleList(
        copy.deepcopy(module) for _ in range(k))

class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn  # (key, query, value, mask)
        self.feed_forward = feed_forward
        self.sublayer = clones(SubLayerWrapper(size, dropout), 2)

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn  # (key, query, value, mask)
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SubLayerWrapper(size, dropout), 3)
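
A minimal illustration of how these pieces compose (illustrative only, not part of the commit):

ff = PositionWiseFeedForWard(dim_model=512, dim_ff=2048)
wrapper = SubLayerWrapper(size=512, dropout=0.1)
x = th.randn(10, 512)
y = wrapper(x, ff)   # == x + dropout(ff(norm(x))), shape (10, 512)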