@@ -26,6 +26,9 @@ def main():
     parser.add_argument('--temperature', type=float, default=1.0,
                        help='sampling temperature'
                        '(lower is more conservative, default is 1.0, which is neutral)')
+    parser.add_argument('--topn', type=int, default=-1,
+                        help='at each step, choose from only this many most likely characters;'
+                        'set to <0 to disable top-n filtering.')
     parser.add_argument('--relevance', type=float, default=-1.,
                        help='amount of "relevance masking/MMI (disabled by default):"'
                        'higher is more pressure, 0.4 is probably as high as it can go without'
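[Note] The new --topn flag restricts sampling to the n most likely characters at each step. A minimal standalone sketch of the filtering it enables (NumPy only; topn_filter is a hypothetical helper for illustration, the real change lives in forward_with_mask below):

    import numpy as np

    def topn_filter(prob, topn):
        # Zero all but the topn most likely entries, then renormalize.
        if topn <= 0:
            return prob  # disabled, matching the default of -1
        filtered = prob.copy()
        filtered[np.argsort(filtered)[:-topn]] = 0
        return filtered / filtered.sum()

    prob = np.array([0.5, 0.3, 0.1, 0.07, 0.03])
    print(topn_filter(prob, 2))  # -> [0.625 0.375 0. 0. 0.]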
@@ -75,7 +78,8 @@ def sample_main(args):
         # Restore the saved variables, replacing the initialized values.
         print("Restoring weights...")
         saver.restore(sess, model_path)
-        chatbot(net, sess, chars, vocab, args.n, args.beam_width, args.relevance, args.temperature)
+        chatbot(net, sess, chars, vocab, args.n, args.beam_width,
+                args.relevance, args.temperature, args.topn)
 
 def initial_state(net, sess):
     # Return freshly initialized model states.
@@ -96,15 +100,6 @@ def forward_text(net, sess, states, relevance, vocab, prime_text=None):
                 _, states = net.forward_model(sess, states, vocab[char])
     return states
 
-def scale_prediction(prediction, temperature):
-    if (temperature == 1.0): return prediction # Temperature 1.0 makes no change
-    np.seterr(divide='ignore')
-    scaled_prediction = np.log(prediction) / temperature
-    scaled_prediction = scaled_prediction - np.logaddexp.reduce(scaled_prediction)
-    scaled_prediction = np.exp(scaled_prediction)
-    np.seterr(divide='warn')
-    return scaled_prediction
-
 def sanitize_text(vocab, text): # Strip out characters that are not part of the net's vocab.
     return ''.join(i for i in text if i in vocab)
 
@@ -125,29 +120,29 @@ def possibly_escaped_char(raw_chars):
         return backspace_seq + new_seq + "".join([' '] * diff_length) + "".join(['\b'] * diff_length)
     return raw_chars[-1]
 
-def chatbot(net, sess, chars, vocab, max_length, beam_width, relevance, temperature):
+def chatbot(net, sess, chars, vocab, max_length, beam_width, relevance, temperature, topn):
     states = initial_state_with_relevance_masking(net, sess, relevance)
     while True:
         user_input = input('\n> ')
-        user_command_entered, reset, states, relevance, temperature, beam_width = process_user_command(
-            user_input, states, relevance, temperature, beam_width)
+        user_command_entered, reset, states, relevance, temperature, topn, beam_width = process_user_command(
+            user_input, states, relevance, temperature, topn, beam_width)
         if reset: states = initial_state_with_relevance_masking(net, sess, relevance)
-        if user_command_entered: continue
-        states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "> " + user_input + "\n>"))
-        computer_response_generator = beam_search_generator(sess=sess, net=net,
-            initial_state=copy.deepcopy(states), initial_sample=vocab[' '],
-            early_term_token=vocab['\n'], beam_width=beam_width, forward_model_fn=forward_with_mask,
-            forward_args={'relevance':relevance, 'mask_reset_token':vocab['\n'], 'forbidden_token':vocab['>']},
-            temperature=temperature)
-        out_chars = []
-        for i, char_token in enumerate(computer_response_generator):
-            out_chars.append(chars[char_token])
-            print(possibly_escaped_char(out_chars), end='', flush=True)
-            states = forward_text(net, sess, states, relevance, vocab, chars[char_token])
-            if i >= max_length: break
-        states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "\n> "))
+        if not user_command_entered:
+            states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "> " + user_input + "\n>"))
+            computer_response_generator = beam_search_generator(sess=sess, net=net,
+                initial_state=copy.deepcopy(states), initial_sample=vocab[' '],
+                early_term_token=vocab['\n'], beam_width=beam_width, forward_model_fn=forward_with_mask,
+                forward_args={'relevance':relevance, 'mask_reset_token':vocab['\n'], 'forbidden_token':vocab['>'],
+                              'temperature':temperature, 'topn':topn})
+            out_chars = []
+            for i, char_token in enumerate(computer_response_generator):
+                out_chars.append(chars[char_token])
+                print(possibly_escaped_char(out_chars), end='', flush=True)
+                states = forward_text(net, sess, states, relevance, vocab, chars[char_token])
+                if i >= max_length: break
+            states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "\n> "))
 
-def process_user_command(user_input, states, relevance, temperature, beam_width):
+def process_user_command(user_input, states, relevance, temperature, topn, beam_width):
     user_command_entered = False
     reset = False
     try:
@@ -164,6 +159,10 @@ def process_user_command(user_input, states, relevance, temperature, beam_width)
                 states = states[0]
             relevance = new_relevance
             print("[Relevance disabled]" if relevance <= 0. else "[Relevance set to {}]".format(relevance))
+        elif user_input.startswith('--topn '):
+            user_command_entered = True
+            topn = int(user_input[len('--topn '):])
+            print("[Top-n filtering disabled]" if topn <= 0 else "[Top-n filtering set to {}]".format(topn))
         elif user_input.startswith('--beam_width '):
             user_command_entered = True
             beam_width = max(1, int(user_input[len('--beam_width '):]))
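[Note] As with --temperature and --relevance, the new setting can also be changed mid-conversation. Per the print calls above, a session would look like this (inputs hypothetical):

    > --topn 10
    [Top-n filtering set to 10]
    > --topn 0
    [Top-n filtering disabled]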
@@ -174,7 +173,7 @@ def process_user_command(user_input, states, relevance, temperature, beam_width)
         print("[Model state reset]")
     except ValueError:
         print("[Value error with provided argument.]")
-    return user_command_entered, reset, states, relevance, temperature, beam_width
+    return user_command_entered, reset, states, relevance, temperature, topn, beam_width
 
 def consensus_length(beam_outputs, early_term_token):
     for l in range(len(beam_outputs[0])):
@@ -184,30 +183,50 @@ def consensus_length(beam_outputs, early_term_token):
             if beam_outputs[0][l] != b[l]: return l, False
     return l, False
 
+def scale_prediction(prediction, temperature):
+    if (temperature == 1.0): return prediction # Temperature 1.0 makes no change
+    np.seterr(divide='ignore')
+    scaled_prediction = np.log(prediction) / temperature
+    scaled_prediction = scaled_prediction - np.logaddexp.reduce(scaled_prediction)
+    scaled_prediction = np.exp(scaled_prediction)
+    np.seterr(divide='warn')
+    return scaled_prediction
+
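[Note] scale_prediction is unchanged, just moved next to its new call site. It computes softmax(log(p) / temperature); a worked example of the effect (standalone, NumPy only):

    import numpy as np

    p = np.array([0.7, 0.2, 0.1])
    T = 0.5  # temperature < 1.0 sharpens the distribution (more conservative)
    logp = np.log(p) / T
    scaled = np.exp(logp - np.logaddexp.reduce(logp))
    print(scaled)  # ~[0.907 0.074 0.019], peakier than p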
 def forward_with_mask(sess, net, states, input_sample, forward_args):
-    # forward_args is a dictionary containing relevance, mask_reset_token, forbidden_token.
+    # forward_args is a dictionary containing arguments for generating probabilities.
     relevance = forward_args['relevance']
     mask_reset_token = forward_args['mask_reset_token']
     forbidden_token = forward_args['forbidden_token']
+    temperature = forward_args['temperature']
+    topn = forward_args['topn']
+
     if relevance <= 0.:
         # No relevance masking.
         prob, states = net.forward_model(sess, states, input_sample)
-        prob[forbidden_token] = 0
-        return prob / sum(prob), states
-    # states should be a 2-length list: [primary net state, mask net state].
-    if input_sample == mask_reset_token:
-        # Reset the mask probs when reaching mask_reset_token (newline).
-        states[1] = initial_state(net, sess)
-    primary_prob, states[0] = net.forward_model(sess, states[0], input_sample)
-    primary_prob /= sum(primary_prob)
-    mask_prob, states[1] = net.forward_model(sess, states[1], input_sample)
-    mask_prob /= sum(mask_prob)
-    combined_prob = np.exp(np.log(primary_prob) - relevance * np.log(mask_prob))
+    else:
+        # states should be a 2-length list: [primary net state, mask net state].
+        if input_sample == mask_reset_token:
+            # Reset the mask probs when reaching mask_reset_token (newline).
+            states[1] = initial_state(net, sess)
+        primary_prob, states[0] = net.forward_model(sess, states[0], input_sample)
+        primary_prob /= sum(primary_prob)
+        mask_prob, states[1] = net.forward_model(sess, states[1], input_sample)
+        mask_prob /= sum(mask_prob)
+        prob = np.exp(np.log(primary_prob) - relevance * np.log(mask_prob))
+    # Mask out the forbidden token (">") to prevent the bot from deciding the chat is over.
+    prob[forbidden_token] = 0
     # Normalize probabilities so they sum to 1.
-    return combined_prob / sum(combined_prob), states
+    prob = prob / sum(prob)
+    # Apply temperature scaling.
+    prob = scale_prediction(prob, temperature)
+    # Apply top-n filtering if enabled.
+    if topn > 0:
+        prob[np.argsort(prob)[:-topn]] = 0
+        prob = prob / sum(prob)
+    return prob, states
 
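[Note] The relevance branch is an MMI-style reweighting, prob ∝ primary_prob / mask_prob**relevance, which penalizes characters that the unconditioned mask net already finds likely (i.e., generic continuations). Temperature and top-n are then applied to the combined, renormalized distribution. A toy illustration with made-up numbers:

    import numpy as np

    primary = np.array([0.5, 0.4, 0.1])  # conversation-conditioned model
    mask    = np.array([0.6, 0.2, 0.2])  # generic / unconditioned model
    combined = np.exp(np.log(primary) - 0.4 * np.log(mask))
    combined /= combined.sum()
    print(combined)  # ~[0.392 0.487 0.122]: the generic favorite is demoted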
 def beam_search_generator(sess, net, initial_state, initial_sample,
-        early_term_token, beam_width, forward_model_fn, forward_args, temperature):
+        early_term_token, beam_width, forward_model_fn, forward_args):
     '''Run beam search! Yield consensus tokens sequentially, as a generator;
     return when reaching early_term_token (newline).
 
@@ -224,8 +243,6 @@ def beam_search_generator(sess, net, initial_state, initial_sample,
         probability_output, beam_state =
             forward_model_fn(sess, net, beam_state, beam_sample, forward_args)
         (Note: probability_output has to be a valid probability distribution!)
-    temperature: how conservatively to sample tokens from each distribution
-        (1.0 = neutral, lower means more conservative)
     tot_steps: how many tokens to generate before stopping,
         unless already stopped via early_term_token.
     Returns: a generator to yield a sequence of beam-sampled tokens.'''
@@ -253,7 +270,6 @@ def beam_search_generator(sess, net, initial_state, initial_sample,
             # Forward the model.
             prediction, beam_states[beam_index] = forward_model_fn(
                 sess, net, beam_state, beam_sample, forward_args)
-            prediction = scale_prediction(prediction, temperature)
 
             # Sample best_tokens from the probability distribution.
             # Sample from the scaled probability distribution beam_width choices
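[Note] This deletion is the other half of the move above: temperature scaling (and now top-n filtering) happens inside forward_with_mask via forward_args, so the distribution returned to beam_search_generator is already scaled and filtered, and the generator no longer needs a temperature parameter of its own.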