@@ -22,7 +22,21 @@ def _top_k():
22
22
)
23
23
24
24
25
- def sample_sequence (* , hparams , length , start_token = None , batch_size = None , context = None , temperature = 1 , top_k = 0 ):
25
def top_p_logits(logits, p):
    """Nucleus (top-p) filtering over a batch of logit rows.

    Keeps, per row, the smallest set of highest-probability tokens whose
    cumulative probability reaches ``p`` and pushes every other logit to a
    large negative value so downstream softmax/multinomial never samples it.

    Args:
        logits: ``[batch, vocab]`` float tensor of unnormalized log-probs.
        p: cumulative-probability threshold; expected in ``(0, 1]``.

    Returns:
        ``[batch, vocab]`` tensor equal to ``logits`` inside the nucleus and
        ``-1e10`` (effectively -inf) outside it.
    """
    with tf.variable_scope('top_p_logits'):
        logits_sort = tf.sort(logits, direction='DESCENDING')
        probs_sort = tf.nn.softmax(logits_sort)
        # Exclusive cumsum: entry i holds the mass strictly before position i,
        # so the top token always has sum 0 < p and at least one token survives.
        probs_sums = tf.cumsum(probs_sort, axis=1, exclusive=True)
        # Tokens outside the nucleus are replaced by a large sentinel (+1000)
        # so reduce_min below recovers the smallest logit still inside it.
        logits_masked = tf.where(probs_sums < p, logits_sort, tf.ones_like(logits_sort) * 1000)  # [batch, vocab]
        # FIX: keepdims=True makes min_logits [batch, 1] so it broadcasts
        # row-wise against [batch, vocab]. Without it min_logits is [batch],
        # which NumPy-style broadcasting aligns with the *vocab* axis — a
        # shape error (or silently wrong compare) whenever vocab != batch.
        min_logits = tf.reduce_min(logits_masked, axis=1, keepdims=True)  # [batch, 1]
        return tf.where(
            logits < min_logits,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )
37
+
38
+
39
+ def sample_sequence (* , hparams , length , start_token = None , batch_size = None , context = None , temperature = 1 , top_k = 0 , top_p = 0.0 ):
26
40
if start_token is None :
27
41
assert context is not None , 'Specify exactly one of start_token and context!'
28
42
else :
@@ -49,7 +63,10 @@ def step(hparams, tokens, past=None):
49
63
def body (past , prev , output ):
50
64
next_outputs = step (hparams , prev [:, tf .newaxis ], past = past )
51
65
logits = next_outputs ['logits' ][:, - 1 , :] / tf .to_float (temperature )
52
- logits = top_k_logits (logits , k = top_k )
66
+ if top_p > 0.0 :
67
+ logits = top_p_logits (logits , p = top_p )
68
+ else :
69
+ logits = top_k_logits (logits , k = top_k )
53
70
samples = tf .multinomial (logits , num_samples = 1 , output_dtype = tf .int32 )
54
71
return [
55
72
tf .concat ([past , next_outputs ['presents' ]], axis = - 2 ),
0 commit comments