-
Notifications
You must be signed in to change notification settings - Fork 148
/
model_components.py
90 lines (75 loc) · 3.81 KB
/
model_components.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import tensorflow as tf
import tensorflow.contrib.layers as layers
try:
from tensorflow.contrib.rnn import LSTMStateTuple
except ImportError:
LSTMStateTuple = tf.nn.rnn_cell.LSTMStateTuple
def bidirectional_rnn(cell_fw, cell_bw, inputs_embedded, input_lengths,
                      scope=None):
    """Run a bidirectional dynamic RNN and concatenate both directions.

    Args:
        cell_fw: RNN cell used for the forward pass.
        cell_bw: RNN cell used for the backward pass.
        inputs_embedded: Input tensor handed to
            `tf.nn.bidirectional_dynamic_rnn` as `inputs` (`time_major` is
            left at its default).
        input_lengths: Passed through as `sequence_length`.
        scope: Optional variable scope (name or object); "birnn" if falsy.

    Returns:
        (outputs, state): `outputs` is the forward and backward outputs
        concatenated along axis 2. `state` mirrors the structure of one
        direction's final state (an LSTMStateTuple, a Tensor, or a tuple of
        those, one per layer) with every forward/backward pair concatenated
        along axis 1.

    Raises:
        ValueError: if a final state is not an LSTMStateTuple, a Tensor, or a
            pair of equal-length tuples of those.
    """

    def _merge_states(forward, backward):
        # LSTMStateTuple subclasses tuple, so it has to be recognised before
        # the generic multi-layer tuple case at the bottom.
        if isinstance(forward, LSTMStateTuple):
            merged_c = tf.concat(
                (forward.c, backward.c), 1, name='bidirectional_concat_c')
            merged_h = tf.concat(
                (forward.h, backward.h), 1, name='bidirectional_concat_h')
            return LSTMStateTuple(c=merged_c, h=merged_h)
        if isinstance(forward, tf.Tensor):
            return tf.concat((forward, backward), 1,
                             name='bidirectional_concat')
        is_layer_stack = (isinstance(forward, tuple)
                          and isinstance(backward, tuple)
                          and len(forward) == len(backward))
        if not is_layer_stack:
            raise ValueError(
                'unknown state type: {}'.format((forward, backward)))
        # Multi-layer cells: merge the per-layer states pairwise.
        return tuple(_merge_states(layer_fw, layer_bw)
                     for layer_fw, layer_bw in zip(forward, backward))

    with tf.variable_scope(scope or "birnn") as birnn_scope:
        pair_outputs, pair_states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell_fw,
            cell_bw=cell_bw,
            inputs=inputs_embedded,
            sequence_length=input_lengths,
            dtype=tf.float32,
            swap_memory=True,
            scope=birnn_scope)
        fw_outputs, bw_outputs = pair_outputs
        fw_state, bw_state = pair_states
        outputs = tf.concat((fw_outputs, bw_outputs), 2)
        state = _merge_states(fw_state, bw_state)
    return outputs, state
def task_specific_attention(inputs, output_size,
                            initializer=layers.xavier_initializer(),
                            activation_fn=tf.tanh, scope=None,
                            sequence_lengths=None):
    """
    Performs task-specific attention reduction, using a learned attention
    context vector (constant within the task of interest).

    Args:
        inputs: Tensor of shape [batch_size, units, input_size].
            `input_size` must be static (known).
            `units` axis will be attended over (reduced from output).
            `batch_size` will be preserved.
        output_size: Size of the output's inner (feature) dimension.
        initializer: Initializer for the attention context vector.
        activation_fn: Activation applied to the projection of `inputs`.
        scope: Optional variable scope (name or object); 'attention' if falsy.
        sequence_lengths: Optional integer Tensor of shape [batch_size] giving
            the number of valid positions along `units` for each example.
            When given, positions at or beyond that length get (effectively)
            zero attention weight, so padding does not dilute the result.
            When None (the default), every position is attended over.

    Returns:
        outputs: Tensor of shape [batch_size, output_size].

    Raises:
        ValueError: if `inputs` is not rank 3 or its last dimension is not
            statically known.
    """
    input_shape = inputs.get_shape()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if input_shape.ndims != 3 or input_shape[-1].value is None:
        raise ValueError(
            'inputs must be rank 3 with a static last dimension, got shape '
            '{}'.format(input_shape))
    with tf.variable_scope(scope or 'attention') as scope:
        attention_context_vector = tf.get_variable(
            name='attention_context_vector',
            shape=[output_size],
            initializer=initializer,
            dtype=tf.float32)
        # Passing the same `scope` places the projection's variables in the
        # attention scope alongside the context vector.
        input_projection = layers.fully_connected(
            inputs, output_size, activation_fn=activation_fn, scope=scope)
        # Unnormalized score per position: [batch_size, units, 1].
        vector_attn = tf.reduce_sum(
            tf.multiply(input_projection, attention_context_vector),
            axis=2, keep_dims=True)
        if sequence_lengths is not None:
            # 1.0 for valid positions, 0.0 for padding: [batch_size, units].
            valid = tf.sequence_mask(sequence_lengths,
                                     maxlen=tf.shape(inputs)[1],
                                     dtype=vector_attn.dtype)
            # A large finite negative (not -inf) keeps an all-padding row
            # from producing NaNs: it degrades to uniform weights instead.
            vector_attn = vector_attn + (
                1.0 - tf.expand_dims(valid, -1)) * -1e30
        attention_weights = tf.nn.softmax(vector_attn, dim=1)
        weighted_projection = tf.multiply(input_projection, attention_weights)
        outputs = tf.reduce_sum(weighted_projection, axis=1)
        return outputs