Skip to content

Commit

Permalink
Fix indentation weirdness in GPT-2 example.
Browse files Browse the repository at this point in the history
  • Loading branch information
cynthia committed Apr 21, 2019
1 parent 68a889e commit 14b1f71
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions examples/run_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,25 +107,23 @@ def run_model():
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)
if args.unconditional:
generated = 0
for _ in range(args.nsamples // args.batch_size):
out = sample_sequence(
model=model, length=args.length,
context=None,
start_token=enc.encoder['<|endoftext|>'],
batch_size=args.batch_size,
temperature=args.temperature, top_k=args.top_k, device=device
)
out = out[:,1:].tolist()
for i in range(args.batch_size):
generated += 1
text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)
if args.unconditional:
break
else:
generated = 0
for _ in range(args.nsamples // args.batch_size):
out = sample_sequence(
model=model, length=args.length,
context=None,
start_token=enc.encoder['<|endoftext|>'],
batch_size=args.batch_size,
temperature=args.temperature, top_k=args.top_k, device=device
)
out = out[:,1:].tolist()
for i in range(args.batch_size):
generated += 1
text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)

if __name__ == '__main__':
run_model()
Expand Down

0 comments on commit 14b1f71

Please sign in to comment.