# -*- coding: utf-8 -*-
"""
encoder for the transformer:
6 layers.each layers has two sub-layers.
the first is multi-head self-attention mechanism;
the second is position-wise fully connected feed-forward network.
for each sublayer. use LayerNorm(x+Sublayer(x)). all dimension=512.
"""
#TODO LAYER NORMALIZATION
import tensorflow as tf
from a2_base_model import BaseClass
import time
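# ---------------------------------------------------------------------------
# Illustrative sketch only (not the implementation the model uses): the actual
# LayerNorm(x + Sublayer(x)) wrapper is provided by BaseClass (imported from
# a2_base_model) as sub_layer_layer_norm_residual_connection. The helper below
# just spells out the computation described in the docstring above, on a tensor
# whose last dimension is d_model.
def _layer_norm_residual_sketch(x, sublayer_x, epsilon=1e-6):
    residual = x + sublayer_x                                            # residual connection
    mean, variance = tf.nn.moments(residual, axes=[-1], keep_dims=True)  # statistics over d_model
    return (residual - mean) / tf.sqrt(variance + epsilon)               # normalized; a full LayerNorm adds a learned scale and bias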
class Encoder(BaseClass):
    def __init__(self,d_model,d_k,d_v,sequence_length,h,batch_size,num_layer,Q,K_s,type='encoder',mask=None,dropout_keep_prob=None,use_residual_conn=True):
        """
        :param d_model: dimension of the model (e.g. 512)
        :param d_k: dimension of keys
        :param d_v: dimension of values
        :param sequence_length: length of the input sequence
        :param h: number of attention heads
        :param batch_size: batch size
        :param num_layer: number of encoder layers
        :param Q: queries, shape: [batch_size,sequence_length,embed_size]
        :param K_s: keys, shape: [batch_size,sequence_length,embed_size]
        """
        super(Encoder, self).__init__(d_model,d_k,d_v,sequence_length,h,batch_size,num_layer=num_layer)
        self.Q=Q
        self.K_s=K_s
        self.type=type
        self.mask=mask
        self.initializer = tf.random_normal_initializer(stddev=0.1)
        self.dropout_keep_prob=dropout_keep_prob
        self.use_residual_conn=use_residual_conn
    def encoder_fn(self):
        start = time.time()
        print("encoder_fn.started.")
        Q=self.Q
        K_s=self.K_s
        for layer_index in range(self.num_layer):
            Q, K_s=self.encoder_single_layer(Q,K_s,layer_index)
            print("encoder_fn.",layer_index,".Q:",Q,";K_s:",K_s)
        end = time.time()
        print("encoder_fn.ended.Q:",Q,";K_s:",K_s,";time spent:",(end-start))
        return Q,K_s
    def encoder_single_layer(self,Q,K_s,layer_index):
        """
        single layer of the encoder. each layer has two sub-layers:
        the first is a multi-head self-attention mechanism; the second is a position-wise fully connected feed-forward network.
        each sub-layer is wrapped as LayerNorm(x + Sublayer(x)). input and output of last dimension: d_model
        :param Q: shape should be: [batch_size,sequence_length,d_model]
        :param K_s: shape should be: [batch_size,sequence_length,d_model]
        :return: output, shape should be: [batch_size,sequence_length,d_model]
        """
        #1.1 the first sub-layer is a multi-head self-attention mechanism
        multi_head_attention_output=self.sub_layer_multi_head_attention(layer_index,Q,K_s,self.type,mask=self.mask,dropout_keep_prob=self.dropout_keep_prob) #[batch_size,sequence_length,d_model]
        #1.2 use LayerNorm(x+Sublayer(x)). all dimensions=512.
        multi_head_attention_output=self.sub_layer_layer_norm_residual_connection(K_s,multi_head_attention_output,layer_index,'encoder_multi_head_attention',dropout_keep_prob=self.dropout_keep_prob,use_residual_conn=self.use_residual_conn)
        #2.1 the second sub-layer is a position-wise fully connected feed-forward network.
        postion_wise_feed_forward_output=self.sub_layer_postion_wise_feed_forward(multi_head_attention_output,layer_index,self.type)
        #2.2 use LayerNorm(x+Sublayer(x)). all dimensions=512.
        postion_wise_feed_forward_output= self.sub_layer_layer_norm_residual_connection(multi_head_attention_output,postion_wise_feed_forward_output,layer_index,'encoder_postion_wise_ff',dropout_keep_prob=self.dropout_keep_prob)
        #the output of the feed-forward sub-layer serves as both Q and K_s for the next layer
        return postion_wise_feed_forward_output,postion_wise_feed_forward_output
def init():
    #1. assign values to fields
    vocab_size=1000
    d_model = 512
    d_k = 64
    d_v = 64
    sequence_length = 5*10
    h = 8
    batch_size=4*32
    initializer = tf.random_normal_initializer(stddev=0.1)
    # 2. set values for Q,K,V
    embed_size=d_model
    Embedding = tf.get_variable("Embedding_E", shape=[vocab_size, embed_size],initializer=initializer)
    input_x = tf.placeholder(tf.int32, [batch_size,sequence_length], name="input_x") #[128,50]
    print("input_x:",input_x)
    embedded_words = tf.nn.embedding_lookup(Embedding, input_x) #[batch_size,sequence_length,embed_size]
    Q = embedded_words # [batch_size,sequence_length,embed_size]
    K_s = embedded_words # [batch_size,sequence_length,embed_size]
    num_layer=6
    mask = get_mask(batch_size, sequence_length)
    #3. get class object
    encoder_class=Encoder(d_model,d_k,d_v,sequence_length,h,batch_size,num_layer,Q,K_s,mask=mask) #Q,K_s,embedded_words
    return encoder_class,Q,K_s
def get_mask(batch_size,sequence_length):
    lower_triangle=tf.matrix_band_part(tf.ones([sequence_length,sequence_length]),-1,0)
    result=-1e9*(1.0-lower_triangle)
    print("get_mask==>result:",result)
    return result
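# Hedged illustration (not executed): for sequence_length=3 the mask above evaluates to
#   [[ 0.0, -1e9, -1e9],
#    [ 0.0,  0.0, -1e9],
#    [ 0.0,  0.0,  0.0]]
# i.e. 0 on and below the diagonal, -1e9 above it; added to the attention logits
# before the softmax, it suppresses attention to future positions.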
def test_sub_layer_multi_head_attention(encoder_class,index_layer,Q,K_s):
    sub_layer_multi_head_attention_output=encoder_class.sub_layer_multi_head_attention(index_layer,Q,K_s)
    return sub_layer_multi_head_attention_output
def test_postion_wise_feed_forward(encoder_class,x,layer_index):
    sub_layer_postion_wise_feed_forward_output=encoder_class.sub_layer_postion_wise_feed_forward(x, layer_index)
    return sub_layer_postion_wise_feed_forward_output
encoder_class,Q,K_s=init()
#index_layer=0
#below are 4 callable blocks for testing, from small functions up to the whole encoder.
#1. test 1: sub-layer of multi-head attention
#sub_layer_multi_head_attention_output=test_sub_layer_multi_head_attention(encoder_class,index_layer,Q,K_s)
#print("sub_layer_multi_head_attention_output1:",sub_layer_multi_head_attention_output)
#2. test 2: sub-layer of multi-head attention followed by position-wise feed forward
#d1,d2,d3=sub_layer_multi_head_attention_output.get_shape().as_list()
#postion_wise_ff_input=tf.reshape(sub_layer_multi_head_attention_output,shape=[-1,d3])
#sub_layer_postion_wise_feed_forward_output=test_postion_wise_feed_forward(encoder_class,postion_wise_ff_input,index_layer)
#sub_layer_postion_wise_feed_forward_output=tf.reshape(sub_layer_postion_wise_feed_forward_output,shape=(d1,d2,d3))
#print("sub_layer_postion_wise_feed_forward_output2:",sub_layer_postion_wise_feed_forward_output)
#3. test 3: single layer of the encoder
#encoder_class.encoder_single_layer(Q,K_s,index_layer)
#4. test 4: full encoder with N layers
#Q,K_s = encoder_class.encoder_fn()
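# --- Hedged usage sketch (kept commented out, like the tests above) ----------
# Builds the full encoder graph and runs it once on random token ids as a
# shape/smoke check. The input placeholder is recovered by its name ("input_x")
# because init() does not return it; numpy is an extra import for this sketch only.
#import numpy as np
#Q_out,K_out = encoder_class.encoder_fn()
#input_x = tf.get_default_graph().get_tensor_by_name("input_x:0")
#with tf.Session() as sess:
#    sess.run(tf.global_variables_initializer())
#    fake_ids = np.random.randint(0, 1000, size=(4*32, 5*10))  # [batch_size, sequence_length]
#    print("encoder output shape:", sess.run(Q_out, feed_dict={input_x: fake_ids}).shape)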