Commit 1296361

Add optional top-n filtering, fix temperature in chatbot.py
1 parent: 2d9e6f0

2 files changed: +70 −46 lines

README.md

Lines changed: 8 additions & 0 deletions

@@ -54,6 +54,8 @@ Try playing around with the arguments to `chatbot.py` to obtain better samples:
 
 - **temperature**: At each step, the model ascribes a certain probability to each character. Temperature adjusts this probability distribution: 1.0 is neutral (and the default); lower values boost high-probability characters and suppress low-probability ones, making the choices more conservative; higher values do the reverse. Values outside the range 0.5-1.5 are unlikely to give coherent results.
 
+- **top-n**: At each step, zero out the probability of all possible characters except the *n* most likely. Disabled by default.
+
 - **relevance**: Two models are run in parallel: the primary model and the mask model. The mask model's predictions are scaled by the relevance value and combined with the primary model's probabilities according to equation 9 in [Li, Jiwei, et al. "A diversity-promoting objective function for neural conversation models." arXiv preprint arXiv:1510.03055 (2015)](https://arxiv.org/abs/1510.03055). The state of the mask model is reset upon each newline character. The net effect is that the model is encouraged to choose the line of dialogue most relevant to the prior line, even when a more generic response (e.g. "I don't know anything about that") would be more probable in absolute terms. Higher relevance values put more pressure on the model to produce relevant responses, at the cost of coherence; going much above 0.4 compromises the quality of the responses. Setting it to a negative value disables relevance, and this is the default, because I'm not confident that it qualitatively improves the outputs and it halves the speed of sampling.
 
 These values can also be manipulated during a chat, and the model state can be reset, without restarting the chatbot:
@@ -72,6 +74,12 @@ Restoring weights...
 > --relevance -1
 [Relevance disabled]
 
+> --topn 2
+[Top-n filtering set to 2]
+
+> --topn -1
+[Top-n filtering disabled]
+
 > --beam_width 5
 [Beam width set to 5]
 
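For intuition, here is a minimal, self-contained sketch (an illustration, not part of this commit) of how the temperature, top-n, and relevance settings documented above combine. It mirrors the `scale_prediction` and `forward_with_mask` logic in the chatbot.py diff below, with made-up numbers for a four-character vocabulary:

```python
import numpy as np

def scale_prediction(prob, temperature):
    # T < 1 sharpens the distribution (more conservative choices);
    # T > 1 flattens it; T == 1 leaves it unchanged.
    if temperature == 1.0:
        return prob
    scaled = np.log(prob) / temperature
    scaled -= np.logaddexp.reduce(scaled)  # renormalize in log space
    return np.exp(scaled)

primary = np.array([0.5, 0.3, 0.15, 0.05])  # primary model's prediction
mask = np.array([0.4, 0.4, 0.1, 0.1])       # mask model's prediction

# Relevance masking per equation 9 of Li et al.: penalize characters that
# the mask model finds likely regardless of context.
relevance = 0.3
prob = np.exp(np.log(primary) - relevance * np.log(mask))
prob = prob / sum(prob)

# Apply temperature, then keep only the 2 most likely characters.
prob = scale_prediction(prob, temperature=0.8)
topn = 2
prob[np.argsort(prob)[:-topn]] = 0
prob = prob / sum(prob)
print(prob)  # all probability mass now sits on the top two characters
```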

chatbot.py

Lines changed: 62 additions & 46 deletions

@@ -26,6 +26,9 @@ def main():
     parser.add_argument('--temperature', type=float, default=1.0,
                         help='sampling temperature'
                         '(lower is more conservative, default is 1.0, which is neutral)')
+    parser.add_argument('--topn', type=int, default=-1,
+                        help='at each step, choose from only this many most likely characters; '
+                        'set to <0 to disable top-n filtering.')
     parser.add_argument('--relevance', type=float, default=-1.,
                         help='amount of "relevance masking/MMI (disabled by default):"'
                         'higher is more pressure, 0.4 is probably as high as it can go without'
@@ -75,7 +78,8 @@ def sample_main(args):
         # Restore the saved variables, replacing the initialized values.
         print("Restoring weights...")
         saver.restore(sess, model_path)
-        chatbot(net, sess, chars, vocab, args.n, args.beam_width, args.relevance, args.temperature)
+        chatbot(net, sess, chars, vocab, args.n, args.beam_width,
+                args.relevance, args.temperature, args.topn)
 
 def initial_state(net, sess):
     # Return freshly initialized model states.
@@ -96,15 +100,6 @@ def forward_text(net, sess, states, relevance, vocab, prime_text=None):
             _, states = net.forward_model(sess, states, vocab[char])
     return states
 
-def scale_prediction(prediction, temperature):
-    if (temperature == 1.0): return prediction # Temperature 1.0 makes no change
-    np.seterr(divide='ignore')
-    scaled_prediction = np.log(prediction) / temperature
-    scaled_prediction = scaled_prediction - np.logaddexp.reduce(scaled_prediction)
-    scaled_prediction = np.exp(scaled_prediction)
-    np.seterr(divide='warn')
-    return scaled_prediction
-
 def sanitize_text(vocab, text): # Strip out characters that are not part of the net's vocab.
     return ''.join(i for i in text if i in vocab)
 
@@ -125,29 +120,29 @@ def possibly_escaped_char(raw_chars):
         return backspace_seq + new_seq + "".join([' '] * diff_length) + "".join(['\b'] * diff_length)
     return raw_chars[-1]
 
-def chatbot(net, sess, chars, vocab, max_length, beam_width, relevance, temperature):
+def chatbot(net, sess, chars, vocab, max_length, beam_width, relevance, temperature, topn):
     states = initial_state_with_relevance_masking(net, sess, relevance)
     while True:
         user_input = input('\n> ')
-        user_command_entered, reset, states, relevance, temperature, beam_width = process_user_command(
-            user_input, states, relevance, temperature, beam_width)
+        user_command_entered, reset, states, relevance, temperature, topn, beam_width = process_user_command(
+            user_input, states, relevance, temperature, topn, beam_width)
         if reset: states = initial_state_with_relevance_masking(net, sess, relevance)
-        if user_command_entered: continue
-        states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "> " + user_input + "\n>"))
-        computer_response_generator = beam_search_generator(sess=sess, net=net,
-            initial_state=copy.deepcopy(states), initial_sample=vocab[' '],
-            early_term_token=vocab['\n'], beam_width=beam_width, forward_model_fn=forward_with_mask,
-            forward_args={'relevance':relevance, 'mask_reset_token':vocab['\n'], 'forbidden_token':vocab['>']},
-            temperature=temperature)
-        out_chars = []
-        for i, char_token in enumerate(computer_response_generator):
-            out_chars.append(chars[char_token])
-            print(possibly_escaped_char(out_chars), end='', flush=True)
-            states = forward_text(net, sess, states, relevance, vocab, chars[char_token])
-            if i >= max_length: break
-        states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "\n> "))
+        if not user_command_entered:
+            states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "> " + user_input + "\n>"))
+            computer_response_generator = beam_search_generator(sess=sess, net=net,
+                initial_state=copy.deepcopy(states), initial_sample=vocab[' '],
+                early_term_token=vocab['\n'], beam_width=beam_width, forward_model_fn=forward_with_mask,
+                forward_args={'relevance':relevance, 'mask_reset_token':vocab['\n'], 'forbidden_token':vocab['>'],
+                    'temperature':temperature, 'topn':topn})
+            out_chars = []
+            for i, char_token in enumerate(computer_response_generator):
+                out_chars.append(chars[char_token])
+                print(possibly_escaped_char(out_chars), end='', flush=True)
+                states = forward_text(net, sess, states, relevance, vocab, chars[char_token])
+                if i >= max_length: break
+            states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "\n> "))
 
-def process_user_command(user_input, states, relevance, temperature, beam_width):
+def process_user_command(user_input, states, relevance, temperature, topn, beam_width):
     user_command_entered = False
     reset = False
     try:
@@ -164,6 +159,10 @@ def process_user_command(user_input, states, relevance, temperature, beam_width)
                 states = states[0]
             relevance = new_relevance
             print("[Relevance disabled]" if relevance <= 0. else "[Relevance set to {}]".format(relevance))
+        elif user_input.startswith('--topn '):
+            user_command_entered = True
+            topn = int(user_input[len('--topn '):])
+            print("[Top-n filtering disabled]" if topn <= 0 else "[Top-n filtering set to {}]".format(topn))
         elif user_input.startswith('--beam_width '):
             user_command_entered = True
             beam_width = max(1, int(user_input[len('--beam_width '):]))
@@ -174,7 +173,7 @@ def process_user_command(user_input, states, relevance, temperature, beam_width)
             print("[Model state reset]")
     except ValueError:
         print("[Value error with provided argument.]")
-    return user_command_entered, reset, states, relevance, temperature, beam_width
+    return user_command_entered, reset, states, relevance, temperature, topn, beam_width
 
 def consensus_length(beam_outputs, early_term_token):
     for l in range(len(beam_outputs[0])):
@@ -184,30 +183,50 @@ def consensus_length(beam_outputs, early_term_token):
             if beam_outputs[0][l] != b[l]: return l, False
     return l, False
 
+def scale_prediction(prediction, temperature):
+    if (temperature == 1.0): return prediction # Temperature 1.0 makes no change
+    np.seterr(divide='ignore')
+    scaled_prediction = np.log(prediction) / temperature
+    scaled_prediction = scaled_prediction - np.logaddexp.reduce(scaled_prediction)
+    scaled_prediction = np.exp(scaled_prediction)
+    np.seterr(divide='warn')
+    return scaled_prediction
+
 def forward_with_mask(sess, net, states, input_sample, forward_args):
-    # forward_args is a dictionary containing relevance, mask_reset_token, forbidden_token.
+    # forward_args is a dictionary containing arguments for generating probabilities.
     relevance = forward_args['relevance']
     mask_reset_token = forward_args['mask_reset_token']
     forbidden_token = forward_args['forbidden_token']
+    temperature = forward_args['temperature']
+    topn = forward_args['topn']
+
     if relevance <= 0.:
         # No relevance masking.
         prob, states = net.forward_model(sess, states, input_sample)
-        prob[forbidden_token] = 0
-        return prob / sum(prob), states
-    # states should be a 2-length list: [primary net state, mask net state].
-    if input_sample == mask_reset_token:
-        # Reset the mask probs when reaching mask_reset_token (newline).
-        states[1] = initial_state(net, sess)
-    primary_prob, states[0] = net.forward_model(sess, states[0], input_sample)
-    primary_prob /= sum(primary_prob)
-    mask_prob, states[1] = net.forward_model(sess, states[1], input_sample)
-    mask_prob /= sum(mask_prob)
-    combined_prob = np.exp(np.log(primary_prob) - relevance * np.log(mask_prob))
+    else:
+        # states should be a 2-length list: [primary net state, mask net state].
+        if input_sample == mask_reset_token:
+            # Reset the mask probs when reaching mask_reset_token (newline).
+            states[1] = initial_state(net, sess)
+        primary_prob, states[0] = net.forward_model(sess, states[0], input_sample)
+        primary_prob /= sum(primary_prob)
+        mask_prob, states[1] = net.forward_model(sess, states[1], input_sample)
+        mask_prob /= sum(mask_prob)
+        prob = np.exp(np.log(primary_prob) - relevance * np.log(mask_prob))
+    # Mask out the forbidden token (">") to prevent the bot from deciding the chat is over.
+    prob[forbidden_token] = 0
     # Normalize probabilities so they sum to 1.
-    return combined_prob / sum(combined_prob), states
+    prob = prob / sum(prob)
+    # Apply temperature.
+    prob = scale_prediction(prob, temperature)
+    # Apply top-n filtering if enabled.
+    if topn > 0:
+        prob[np.argsort(prob)[:-topn]] = 0
+        prob = prob / sum(prob)
+    return prob, states
 
 def beam_search_generator(sess, net, initial_state, initial_sample,
-        early_term_token, beam_width, forward_model_fn, forward_args, temperature):
+        early_term_token, beam_width, forward_model_fn, forward_args):
     '''Run beam search! Yield consensus tokens sequentially, as a generator;
     return when reaching early_term_token (newline).
 
@@ -224,8 +243,6 @@ def beam_search_generator(sess, net, initial_state, initial_sample,
             probability_output, beam_state =
                 forward_model_fn(sess, net, beam_state, beam_sample, forward_args)
             (Note: probability_output has to be a valid probability distribution!)
-    temperature: how conservatively to sample tokens from each distribution
-        (1.0 = neutral, lower means more conservative)
     tot_steps: how many tokens to generate before stopping,
         unless already stopped via early_term_token.
     Returns: a generator to yield a sequence of beam-sampled tokens.'''
@@ -253,7 +270,6 @@ def beam_search_generator(sess, net, initial_state, initial_sample,
             # Forward the model.
             prediction, beam_states[beam_index] = forward_model_fn(
                 sess, net, beam_state, beam_sample, forward_args)
-            prediction = scale_prediction(prediction, temperature)
 
             # Sample best_tokens from the probability distribution.
             # Sample from the scaled probability distribution beam_width choices
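
The top-n filter itself is a single line, `prob[np.argsort(prob)[:-topn]] = 0`. A toy trace (an illustration, not from the commit) shows how it zeroes everything except the n most likely entries:

```python
import numpy as np

prob = np.array([0.10, 0.50, 0.15, 0.25])
topn = 2
# np.argsort orders indices from least to most likely: [0, 2, 3, 1].
# Slicing off the last topn indices leaves the losers, which get zeroed.
prob[np.argsort(prob)[:-topn]] = 0  # -> [0.0, 0.5, 0.0, 0.25]
prob = prob / sum(prob)             # -> [0.0, 0.667, 0.0, 0.333]
print(prob)
```

The `if topn > 0:` guard in `forward_with_mask` is what makes negative values mean "disabled": without it, `--topn -1` would slice as `[:1]` and silently zero out only the least likely character.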
