Skip to content

Commit

Permalink
update metalstm
Browse files Browse the repository at this point in the history
  • Loading branch information
LittleChenCc committed May 6, 2018
1 parent a881a08 commit 2b6dbfd
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions MetaLSTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def state_size(self):
def output_size(self):
return self._num_units

def getMetaResults(self, hyper_output, input, dimensions, scope="meta"):
def getMetaResults(self, meta_output, input, dimensions, scope="meta"):
"""calculate the gates results of basic lstm with meta-lstm network"""
# with tf.variable_scope('z_trans'):
# hyper_output = rnn_cell._linear(hyper_output, self._meta_num_units, False)
# meta_output = rnn_cell._linear(meta_output, self._meta_num_units, False)

with tf.variable_scope(scope):
W_matrix_list = []
Expand All @@ -48,13 +48,13 @@ def getMetaResults(self, hyper_output, input, dimensions, scope="meta"):
Q = tf.get_variable('Q{}'.format(i), shape=[self._meta_num_units, input_shape],
initializer=tf.uniform_unit_scaling_initializer(),dtype=tf.float32)

_W_matrix = tf.matmul(tf.reshape(tf.matrix_diag(hyper_output),[-1, self._meta_num_units]), P)
_W_matrix = tf.matmul(tf.reshape(tf.matrix_diag(meta_output),[-1, self._meta_num_units]), P)
_W_matrix = tf.reshape(_W_matrix, [-1, self._meta_num_units, dimensions])
_W_matrix = tf.matmul(tf.reshape(tf.transpose(_W_matrix, [0,2,1]), [-1, self._meta_num_units]), Q)
_W_matrix = tf.reshape(_W_matrix, [-1, dimensions, input_shape])
W_matrix_list.append(_W_matrix)
W_matrix = tf.concat(values=W_matrix_list, axis=1)
Bias = rnn_cell._linear(hyper_output, 4*dimensions, False)
Bias = rnn_cell._linear(meta_output, 4*dimensions, False)

result = tf.matmul(W_matrix, tf.expand_dims(input, -1))
result = tf.add(tf.reshape(result, [-1, 4*dimensions]), Bias)
Expand All @@ -68,23 +68,23 @@ def __call__(self, inputs, state, scope=None):
total_h, total_c = tf.split(axis=1, num_or_size_splits=2, value=state)
h = total_h[:, 0:self._num_units]
c = total_c[:, 0:self._num_units]
hyper_state = tf.concat(values=[total_h[:, self._num_units:], total_c[:, self._num_units:]], axis=1)
hyper_input = tf.concat(values=[inputs, h], axis=1)
meta_state = tf.concat(values=[total_h[:, self._num_units:], total_c[:, self._num_units:]], axis=1)
meta_input = tf.concat(values=[inputs, h], axis=1)

#get outputs from meta-lstm
hyper_output, hyper_new_state = self._meta_cell(hyper_input, hyper_state)
meta_output, meta_new_state = self._meta_cell(meta_input, meta_state)

#calculate gates of basic lstm
input_concat = tf.concat(values=[inputs, h], axis=1)
lstm_gates= self.getMetaResults(hyper_output, input_concat, self._num_units, scope = 'hyper_result')
lstm_gates= self.getMetaResults(meta_output, input_concat, self._num_units, scope = 'meta_result')
i, j, f, o = tf.split(axis=1, num_or_size_splits=4, value=lstm_gates)
new_c = (c * sigmoid(f) + sigmoid(i) * self._activation(j))
new_h = self._activation(new_c) * sigmoid(o)

#update new states
hyper_h, hyper_c = tf.split(axis=1, num_or_size_splits=2, value=hyper_new_state)
new_total_h = tf.concat(values=[new_h, hyper_h], axis=1)
new_total_c = tf.concat(values=[new_c, hyper_c], axis=1)
meta_h, meta_c = tf.split(axis=1, num_or_size_splits=2, value=meta_new_state)
new_total_h = tf.concat(values=[new_h, meta_h], axis=1)
new_total_c = tf.concat(values=[new_c, meta_c], axis=1)
new_total_state = tf.concat(values=[new_total_h, new_total_c], axis=1)

return new_h, new_total_state

0 comments on commit 2b6dbfd

Please sign in to comment.