Skip to content

Commit fdc8c3f

Browse files
committed
Added ability to change sampling settings during a chat
1 parent ca2b46e commit fdc8c3f

File tree

2 files changed

+67
-23
lines changed

2 files changed

+67
-23
lines changed

README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ A toy chatbot powered by deep learning and trained on data from Reddit.
44
Here is a sample chat transcript (not cherry-picked). Lines that start with ">" are me; the lines without ">" are the chatbot's responses.
55

66
```
7+
$ python chatbot.py
8+
Creating model...
9+
Restoring weights...
10+
711
> Hi, how are you?
812
It's all good, I'm going to bed now but I'll see if I can get home tomorrow. I'll get back to you.
913
> Don't go to bed just yet
@@ -46,6 +50,29 @@ Try playing around with the arguments to `chatbot.py` to obtain better samples:
4650

4751
- **relevance**: Two models are run in parallel: the primary model and the mask model. The mask model is scaled by the relevance value, and then the probabilities of the primary model are multiplied by the complement of the mask model before sampling. The state of the mask model is reset upon each newline character. The net effect is that the model is encouraged to choose a line of dialogue that is most relevant to the prior line of dialogue, even if a more generic response (e.g. "I don't know anything about that") may be more absolutely probable. Lower relevance values put more pressure on the model to produce relevant responses, at the cost of the coherence of the responses. Going much below 1.5 compromises the quality of the responses; 2-3 is the recommended range. Setting it to a negative value disables relevance, and this is the default, because I'm not confident that it qualitatively improves the outputs and it halves the speed of sampling.
4852

53+
These values can also be changed during a chat, and the model state can be reset, without restarting the chatbot:
54+
55+
```
56+
$ python chatbot.py
57+
Creating model...
58+
Restoring weights...
59+
60+
> --temperature 1.3
61+
[Temperature set to 1.3]
62+
63+
> --relevance 2
64+
[Relevance set to 2.0]
65+
66+
> --relevance -1
67+
[Relevance disabled]
68+
69+
> --beam_width 5
70+
[Beam width set to 5]
71+
72+
> --reset
73+
[Model state reset]
74+
```
75+
4976
### Get training data
5077

5178
If you'd like to train your own model, you'll need training data. There are a few options here.

chatbot.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -120,41 +120,58 @@ def beam_sample(net, sess, chars, vocab, max_length=200, prime='The ',
120120
def sanitize_text(vocab, text):
121121
return ''.join(i for i in text if i in vocab)
122122

123+
def initial_state_with_relevance_masking(net, sess, relevance):
    """Build the starting model state.

    When relevance masking is active (relevance > 0) two parallel model
    states are kept — primary and mask — so a two-element list is returned;
    otherwise a single plain state is returned.
    """
    if relevance > 0.:
        return [initial_state(net, sess), initial_state(net, sess)]
    return initial_state(net, sess)
126+
123127
def chatbot(net, sess, chars, vocab, max_length, beam_width, relevance, temperature):
    """Run the interactive chat loop: read user lines, feed them through the
    model, and stream the beam-searched response one character at a time.

    Sampling settings (temperature, relevance, beam width) can be changed
    mid-chat via "--" commands handled by process_user_command; "--reset"
    rebuilds the model state from scratch.
    """
    # States is either one state or a [primary, mask] pair when relevance > 0.
    states = initial_state_with_relevance_masking(net, sess, relevance)
    while True:
        # Drop any typed characters the vocabulary cannot represent.
        # NOTE(review): raw_input implies Python 2 — confirm target interpreter.
        user_input = sanitize_text(vocab, raw_input('\n> '))
        # Commands may rewrap/unwrap states, so states is threaded through.
        user_command_entered, reset, states, relevance, temperature, beam_width = process_user_command(
            user_input, states, relevance, temperature, beam_width)
        # Rebuild state AFTER the command so the new relevance setting is used.
        if reset: states = initial_state_with_relevance_masking(net, sess, relevance)
        if user_command_entered: continue
        # Prime the model with the user's line, framed as dialogue turns.
        states = forward_text(net, sess, states, vocab, '> ' + user_input + "\n>")
        # Deep-copy states: beam search mutates its copy while we keep ours.
        computer_response_generator = beam_search_generator(sess=sess, net=net,
            initial_state=copy.deepcopy(states), initial_sample=vocab[' '],
            early_term_token=vocab['\n'], beam_width=beam_width, forward_model_fn=forward_with_mask,
            forward_args=(relevance, vocab['\n']), temperature=temperature)
        for i, char_token in enumerate(computer_response_generator):
            # Echo each sampled character immediately and advance our own state.
            print(chars[char_token], end='')
            states = forward_text(net, sess, states, vocab, chars[char_token])
            sys.stdout.flush()
            # Hard cap on response length.
            if i >= max_length: break
        # Prepare the model for the next user turn.
        states = forward_text(net, sess, states, vocab, '\n> ')
143146

144-
def process_user_command(user_input, states, relevance, temperature, beam_width):
    """Handle in-chat "--" commands that adjust sampling settings.

    Recognized commands: --temperature <f>, --relevance <f>, --beam_width <i>,
    --reset. Any other input is not a command.

    Args:
        user_input: sanitized line typed by the user.
        states: current model state (a single state, or a [primary, mask]
            pair when relevance masking is active).
        relevance, temperature, beam_width: current sampling settings.

    Returns:
        (user_command_entered, reset, states, relevance, temperature, beam_width)
        where user_command_entered is True if the input was consumed as a
        command and reset is True if the caller should rebuild model state.
    """
    user_command_entered = False
    reset = False
    try:
        if user_input.startswith('--temperature '):
            user_command_entered = True
            # Clamp to a small positive value; temperature 0 would divide by zero.
            temperature = max(0.001, float(user_input[len('--temperature '):]))
            print("[Temperature set to {}]".format(temperature))
        elif user_input.startswith('--relevance '):
            user_command_entered = True
            new_relevance = float(user_input[len('--relevance '):])
            # Relevance masking is active iff relevance > 0 (matching
            # initial_state_with_relevance_masking), so wrap/unwrap the state
            # on that exact boundary. BUG FIX: previously the unwrap tested
            # new_relevance < 0., so "--relevance 0" disabled masking while
            # leaving the dual [primary, mask] state in place.
            if relevance <= 0. and new_relevance > 0.:
                # Enabling masking: duplicate the single state into a pair.
                states = [states, copy.deepcopy(states)]
            elif relevance > 0. and new_relevance <= 0.:
                # Disabling masking: keep only the primary model's state.
                states = states[0]
            relevance = new_relevance
            print("[Relevance disabled]" if relevance <= 0. else "[Relevance set to {}]".format(relevance))
        elif user_input.startswith('--beam_width '):
            user_command_entered = True
            # Beam width below 1 is meaningless; clamp.
            beam_width = max(1, int(user_input[len('--beam_width '):]))
            print("[Beam width set to {}]".format(beam_width))
        elif user_input.startswith('--reset'):
            user_command_entered = True
            reset = True
            print("[Model state reset]")
    except ValueError:
        # Malformed numeric argument: report it but keep all settings unchanged.
        print("[Value error with provided argument.]")
    return user_command_entered, reset, states, relevance, temperature, beam_width
158175

159176
def consensus_length(beam_outputs, early_term_token):
160177
for l in xrange(len(beam_outputs[0])):

0 commit comments

Comments
 (0)