Skip to content

Commit 7a2b130

Browse files
committed
Added separate file for talking with the Chatbot
1 parent b16357f commit 7a2b130

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

talk.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,12 @@
3636
hidden_size = 512
3737
projection_size = 300
3838
embedding_size = 300
39-
num_layers = 3
39+
num_layers = 1
4040

4141
# ouput_size for softmax layer
4242
output_size = projection_size
4343

44+
keep_prob = 0.95
4445
beam_size = 10
4546
top_k = 10
4647
max_sequence_len = 20
@@ -61,7 +62,7 @@
6162
dec_inputs = tf.placeholder(tf.int32, shape=(None, batch_size), name="dec_inputs")
6263

6364
#input embedding layers
64-
emb_weights = tf.Variable(tf.truncated_normal([vocab_size, embedding_size], stddev=truncated_std), name="emb_weights")
65+
emb_weights = tf.Variable(tf.truncated_normal([vocab_size, embedding_size]), name="emb_weights")
6566
enc_inputs_emb = tf.nn.embedding_lookup(emb_weights, enc_inputs, name="enc_inputs_emb")
6667
dec_inputs_emb = tf.nn.embedding_lookup(emb_weights, dec_inputs, name="dec_inputs_emb")
6768

@@ -113,9 +114,9 @@
113114
scope="decoder")
114115

115116
#output layers
116-
project_w = tf.Variable(tf.truncated_normal(shape=[output_size, embedding_size], stddev=truncated_std), name="project_w")
117+
project_w = tf.Variable(tf.truncated_normal(shape=[output_size, embedding_size]), name="project_w")
117118
project_b = tf.Variable(tf.constant(shape=[embedding_size], value = 0.1), name="project_b")
118-
softmax_w = tf.Variable(tf.truncated_normal(shape=[embedding_size, vocab_size], stddev=truncated_std), name="softmax_w")
119+
softmax_w = tf.Variable(tf.truncated_normal(shape=[embedding_size, vocab_size]), name="softmax_w")
119120
softmax_b = tf.Variable(tf.constant(shape=[vocab_size], value = 0.1), name="softmax_b")
120121

121122
dec_outputs = tf.reshape(dec_outputs, [-1, output_size], name="dec_ouputs")
@@ -199,8 +200,6 @@ def predict(enc_inp):
199200
if len(candidates) == 0:
200201
break
201202

202-
if signal:
203-
best_sequence = [signal] + best_sequence
204203

205204
return best_sequence[:-1]
206205

0 commit comments

Comments
 (0)