#coding:utf-8
import tensorflow as tf
import numpy as np
import cPickle
import Config

test_word = [u'FDA', u'menu']  # unused placeholder; main() builds its own keyword list

config_tf = tf.ConfigProto()
config_tf.gpu_options.allow_growth = True

# Pre-trained word vectors and vocabulary (Python 2 pickles).
word_vec = cPickle.load(open('word_vec.pkl', 'rb'))
vocab = cPickle.load(open('word_voc.pkl', 'rb'))

word_to_idx = {ch: i for i, ch in enumerate(vocab)}
idx_to_word = {i: ch for i, ch in enumerate(vocab)}

gen_config = Config.Config()
gen_config.vocab_size = len(vocab)

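# The Model below appears to implement an MTA-LSTM-style keyword-attention
# decoder: a multi-layer LSTM language model that attends over a fixed set of
# topic keywords at every step, with a coverage gate that spends each
# keyword's attention budget as it is used. All hyperparameters
# (hidden_size, num_keywords, word_embedding_size, keep_prob, num_layers,
# BeamSize, init_scale, model_path, save_time, len_of_generation, is_sample)
# are read from Config.Config.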
class Model(object):
    def __init__(self, is_training, config):
        self.batch_size = batch_size = config.batch_size
        self.num_steps = num_steps = config.num_steps
        self.size = size = config.hidden_size
        vocab_size = config.vocab_size

        # Placeholders for one batch: input word ids, targets, keyword ids,
        # the previous step's output state, a loss mask, and sequence length.
        self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
        self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])
        self._input_word = tf.placeholder(tf.int32, [batch_size, config.num_keywords])
        self._init_output = tf.placeholder(tf.float32, [batch_size, size])
        self._mask = tf.placeholder(tf.float32, [batch_size, num_steps])
        self.seq_length = tf.placeholder(tf.float32, [batch_size, 1])

        # One LSTM cell per layer; reusing a single cell object across layers
        # would tie their weights in newer TF 1.x releases.
        def make_cell():
            cell = tf.nn.rnn_cell.LSTMCell(size, forget_bias=0.0, state_is_tuple=False)
            if is_training and config.keep_prob < 1:
                cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=config.keep_prob)
            return cell
        cell = tf.nn.rnn_cell.MultiRNNCell([make_cell() for _ in range(config.num_layers)],
                                           state_is_tuple=False)

        self._initial_state = cell.zero_state(batch_size, tf.float32)

        with tf.device("/cpu:0"):
            embedding = tf.get_variable(
                'word_embedding', [vocab_size, config.word_embedding_size],
                trainable=True, initializer=tf.constant_initializer(word_vec))
            inputs = tf.nn.embedding_lookup(embedding, self._input_data)
            keyword_inputs = tf.nn.embedding_lookup(embedding, self._input_word)

        if is_training and config.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, config.keep_prob)

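        # Coverage gate: every keyword starts with an attention budget of 1.0
        # that is consumed as the decoder attends to it; `phi_res` below
        # rescales that budget by the expected output length.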
        self.initial_gate = tf.ones([batch_size, config.num_keywords])
        gate = self.initial_gate

        atten_sum = tf.zeros([batch_size, config.num_keywords])

        with tf.variable_scope("coverage"):
            u_f = tf.get_variable(
                "u_f", [config.num_keywords * config.word_embedding_size, config.num_keywords])
            res1 = tf.sigmoid(tf.matmul(tf.reshape(keyword_inputs, [batch_size, -1]), u_f))
            # phi: predicted per-keyword attention budget, scaled by sequence length.
            phi_res = self.seq_length * res1

        self.output1 = phi_res

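        # Unrolled decoding loop: at each step, score every keyword against
        # the previous output state (additive attention), weight the scores
        # by the remaining coverage gate, and feed the attention-weighted
        # keyword mixture into the LSTM together with the word embedding.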
        outputs = []
        output_state = self._init_output
        state = self._initial_state
        with tf.variable_scope("RNN"):
            for time_step in range(num_steps):
                vs = []
                for s2 in range(config.num_keywords):
                    with tf.variable_scope("RNN_attention"):
                        if time_step > 0 or s2 > 0:
                            tf.get_variable_scope().reuse_variables()
                        u = tf.get_variable("u", [size, 1])
                        w1 = tf.get_variable("w1", [size, size])
                        w2 = tf.get_variable("w2", [config.word_embedding_size, size])
                        b = tf.get_variable("b1", [size])

                        # Additive attention score for keyword s2, scaled by
                        # its remaining coverage budget.
                        vi = tf.matmul(tf.tanh(tf.add(tf.add(
                            tf.matmul(output_state, w1),
                            tf.matmul(keyword_inputs[:, s2, :], w2)), b)), u)
                        vs.append(vi * gate[:, s2:s2 + 1])

                self.attention_vs = tf.concat(vs, axis=1)
                prob_p = tf.nn.softmax(self.attention_vs)

                self.attention_weight = prob_p

                # Spend each keyword's budget in proportion to the attention
                # it just received.
                gate = gate - (prob_p / phi_res)
                self.output_gate = gate

                atten_sum += prob_p * self._mask[:, time_step:time_step + 1]

                # mt: attention-weighted mixture of the keyword embeddings.
                mt = tf.add_n([prob_p[:, i:i + 1] * keyword_inputs[:, i, :]
                               for i in range(config.num_keywords)])

                with tf.variable_scope("RNN_sentence"):
                    if time_step > 0:
                        tf.get_variable_scope().reuse_variables()
                    (cell_output, state) = cell(tf.concat([inputs[:, time_step, :], mt], axis=1), state)
                    outputs.append(cell_output)
                    output_state = cell_output

            self._end_output = cell_output

        self.output2 = atten_sum
        output = tf.reshape(tf.concat(outputs, axis=1), [-1, size])

        softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
        softmax_b = tf.get_variable("softmax_b", [vocab_size])
        logits = tf.matmul(output, softmax_w) + softmax_b

        self._final_state = state
        self._prob = tf.nn.softmax(logits)

    @property
    def input_data(self):
        return self._input_data

    @property
    def end_output(self):
        return self._end_output

    @property
    def targets(self):
        return self._targets

    @property
    def initial_state(self):
        return self._initial_state

    @property
    def cost(self):
        # Note: self._cost is only set in the training graph, not here.
        return self._cost

    @property
    def final_state(self):
        return self._final_state


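# run_epoch performs a single decoding step. flag == 0 is the first step
# (fresh coverage gate, zero initial output) and also fetches phi; flag == 1
# continues from a previous step's state, output, and gate, as beam search
# requires.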
def run_epoch(session, m, data, eval_op, state=None, is_test=False, input_words=None,
              verbose=False, flag=1, last_output=None, last_gate=None, lens=None):
    """Runs one decoding step of the model on the given data."""
    x = data.reshape((1, 1))
    initial_output = np.zeros((m.batch_size, m.size))
    if flag == 0:
        prob, _state, _last_output, _last_gate, weight, _phi, _ = session.run(
            [m._prob, m.final_state, m.end_output, m.output_gate,
             m.attention_weight, m.output1, eval_op],
            {m.input_data: x,
             m._input_word: input_words,
             m.initial_state: state,
             m._init_output: initial_output,
             m.seq_length: [[lens]]})
        return prob, _state, _last_output, _last_gate, weight, _phi
    else:
        prob, _state, _last_output, _last_gate, weight, _ = session.run(
            [m._prob, m.final_state, m.end_output, m.output_gate,
             m.attention_weight, eval_op],
            {m.input_data: x,
             m._input_word: input_words,
             m.initial_state: state,
             m._init_output: last_output,
             m.seq_length: [[lens]],
             m.initial_gate: last_gate})
        return prob, _state, _last_output, _last_gate, weight

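# Generation entry point: restores a trained checkpoint and decodes
# len_of_generation words with beam search (or sampling, when
# gen_config.is_sample is set), conditioned on five topic keywords.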
def main(_):
    with tf.Graph().as_default(), tf.Session(config=config_tf) as session:

        gen_config.batch_size = 1
        gen_config.num_steps = 1

        beam_size = gen_config.BeamSize

        initializer = tf.random_uniform_initializer(-gen_config.init_scale,
                                                    gen_config.init_scale)
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            mtest = Model(is_training=False, config=gen_config)

        tf.global_variables_initializer().run()

        model_saver = tf.train.Saver(tf.global_variables())
        print 'model loading ...'
        model_saver.restore(session, gen_config.model_path + '--%d' % gen_config.save_time)
        print 'Done!'

        # Topic keywords (Chinese): 'faith', 'life', 'loss', 'soul', 'unyielding'.
        test_word = [u'信念', u'人生', u'失落', u'心灵', u'不屈']
        len_of_sample = gen_config.len_of_generation

        _state = mtest.initial_state.eval()
        tmp = []
        # Each beam is (log-prob, generated words, last word index); index 1
        # is assumed to be the start-of-sequence token.
        beams = [(0.0, [idx_to_word[1]], 1)]
        for wd in test_word:
            tmp.append(word_to_idx[wd])
        _input_words = np.array([tmp], dtype=np.int32)
        test_data = np.int32([1])
        prob, _state, _last_output, _last_gate, weight, _phi = run_epoch(
            session, mtest, test_data, tf.no_op(), _state, True,
            input_words=_input_words, flag=0, lens=len_of_sample)
        y1 = np.log(1e-20 + prob.reshape(-1))
        if gen_config.is_sample:
            try:
                top_indices = np.random.choice(gen_config.vocab_size, beam_size,
                                               replace=False, p=prob.reshape(-1))
            except ValueError:
                # Fewer than beam_size words have non-zero probability.
                top_indices = np.random.choice(gen_config.vocab_size, beam_size,
                                               replace=True, p=prob.reshape(-1))
        else:
            top_indices = np.argsort(-y1)
        b = beams[0]
        beam_candidates = []
        for i in xrange(beam_size):
            wordix = top_indices[i]
            beam_candidates.append((b[0] + y1[wordix], b[1] + [idx_to_word[wordix]],
                                    wordix, _state, _last_output, _last_gate))
        beam_candidates.sort(key=lambda x: x[0], reverse=True)  # decreasing log-prob
        beams = beam_candidates[:beam_size]  # truncate to get new beams
        for xy in range(len_of_sample - 1):
            beam_candidates = []
            for b in beams:
                test_data = np.int32(b[2])
                prob, _state, _last_output, _last_gate, weight = run_epoch(
                    session, mtest, test_data, tf.no_op(), b[3], True,
                    input_words=_input_words, flag=1,
                    last_output=b[4], last_gate=b[5], lens=len_of_sample)
                y1 = np.log(1e-20 + prob.reshape(-1))
                if gen_config.is_sample:
                    try:
                        top_indices = np.random.choice(gen_config.vocab_size, beam_size,
                                                       replace=False, p=prob.reshape(-1))
                    except ValueError:
                        top_indices = np.random.choice(gen_config.vocab_size, beam_size,
                                                       replace=True, p=prob.reshape(-1))
                else:
                    top_indices = np.argsort(-y1)
                for i in xrange(beam_size):
                    wordix = top_indices[i]
                    beam_candidates.append((b[0] + y1[wordix], b[1] + [idx_to_word[wordix]],
                                            wordix, _state, _last_output, _last_gate))
            beam_candidates.sort(key=lambda x: x[0], reverse=True)  # decreasing log-prob
            beams = beam_candidates[:beam_size]  # truncate to get new beams

    print ' '.join(beams[0][1][1:]).encode('utf-8')
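
# Note: this script assumes Python 2 (cPickle, print statements, xrange),
# a TF 1.x-era graph API, the pickled 'word_vec.pkl' / 'word_voc.pkl'
# files, and a trained checkpoint at gen_config.model_path.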

if __name__ == "__main__":
    tf.app.run()