This repository was archived by the owner on Oct 31, 2022. It is now read-only.

Commit e007317

Author: nshepperd
Add top-p sampling
1 parent c180288 commit e007317

File tree: 2 files changed, +24 −3 lines changed


src/sample.py

Lines changed: 19 additions & 2 deletions
@@ -22,7 +22,21 @@ def _top_k():
     )


-def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
+def top_p_logits(logits, p):
+    with tf.variable_scope('top_p_logits'):
+        logits_sort = tf.sort(logits, direction='DESCENDING')
+        probs_sort = tf.nn.softmax(logits_sort)
+        probs_sums = tf.cumsum(probs_sort, axis=1, exclusive=True)
+        logits_masked = tf.where(probs_sums < p, logits_sort, tf.ones_like(logits_sort)*1000)  # [batchsize, vocab]
+        min_logits = tf.reduce_min(logits_masked, axis=1, keepdims=True)  # [batchsize, 1]
+        return tf.where(
+            logits < min_logits,
+            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
+            logits,
+        )
+
+
+def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=0.0):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
     else:
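
For illustration (not part of the commit): top_p_logits implements nucleus filtering. It sorts each row's logits descending, computes the exclusive cumulative softmax mass, keeps every token whose preceding mass is still below p, and pushes the rest to -1e10 so they receive effectively zero probability. A minimal NumPy sketch of the same logic with a worked example (function name and values are illustrative, not from the repo):

import numpy as np

def top_p_filter(logits, p):
    # Sort each row's logits in descending order.
    logits_sort = -np.sort(-logits, axis=1)
    # Softmax over the sorted logits.
    exps = np.exp(logits_sort - logits_sort.max(axis=1, keepdims=True))
    probs_sort = exps / exps.sum(axis=1, keepdims=True)
    # Exclusive cumsum: probability mass strictly before each position.
    probs_sums = np.cumsum(probs_sort, axis=1) - probs_sort
    # Keep positions whose preceding mass is < p; push the rest to +1000
    # so the min() below finds the smallest *kept* logit per row.
    logits_masked = np.where(probs_sums < p, logits_sort, 1000.0)
    min_logits = logits_masked.min(axis=1, keepdims=True)  # [batch, 1]
    # Suppress everything below the per-row threshold.
    return np.where(logits < min_logits, -1e10, logits)

logits = np.array([[3.0, 2.0, 1.0, 0.0]])
print(top_p_filter(logits, p=0.8))
# The top two tokens carry ~0.88 of the mass, so the rest are masked:
# [[ 3.e+00  2.e+00 -1.e+10 -1.e+10]]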
@@ -49,7 +63,10 @@ def step(hparams, tokens, past=None):
     def body(past, prev, output):
         next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
         logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
-        logits = top_k_logits(logits, k=top_k)
+        if top_p > 0.0:
+            logits = top_p_logits(logits, p=top_p)
+        else:
+            logits = top_k_logits(logits, k=top_k)
         samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
         return [
             tf.concat([past, next_outputs['presents']], axis=-2),
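
Note that top_p > 0.0 takes precedence over top_k in the loop body, and the -1e10 sentinel works because the multinomial draw exponentiates the logits, sending masked tokens to essentially zero probability. A tiny runnable stand-in for that draw, reusing top_p_filter from the sketch above (illustrative, not the repo's API):

rng = np.random.default_rng(0)
filtered = top_p_filter(np.array([[3.0, 2.0, 1.0, 0.0]]), p=0.8)
probs = np.exp(filtered - filtered.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
token = rng.choice(probs.shape[1], p=probs[0])  # only index 0 or 1 can be drawn
print(token)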

train.py

Lines changed: 5 additions & 1 deletion
@@ -36,6 +36,9 @@
 parser.add_argument('--optimizer', type=str, default='adam', help='Optimizer. <adam|sgd>.')
 parser.add_argument('--noise', type=float, default=0.0, help='Add noise to input training data to regularize against typos.')

+parser.add_argument('--top_k', type=int, default=40, help='K for top-k sampling.')
+parser.add_argument('--top_p', type=float, default=0.0, help='P for top-p sampling. Overrides top_k if set > 0.')
+
 parser.add_argument('--restore_from', type=str, default='latest', help='Either "latest", "fresh", or a path to a checkpoint file')
 parser.add_argument('--run_name', type=str, default='run1', help='Run id. Name of subdirectory in checkpoint/ and samples/')
 parser.add_argument('--sample_every', metavar='N', type=int, default=100, help='Generate samples every N steps')
@@ -107,7 +110,8 @@ def main():
             context=context,
             batch_size=args.batch_size,
             temperature=1.0,
-            top_k=40)
+            top_k=args.top_k,
+            top_p=args.top_p)

         all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
         train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars
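
With these flags exposed, nucleus sampling can be enabled for the periodic samples from the command line, e.g. (the --dataset flag is assumed from the surrounding train.py, and the path is hypothetical):

python train.py --dataset data.npz --top_p 0.9

Leaving --top_p at its 0.0 default keeps the previous top-k behaviour, now tunable via --top_k (default 40).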

0 commit comments