Skip to content

Commit a110937

Browse files
committed
Generate tokens
1 parent 25cdefb commit a110937

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

pytorch/gpt_pytorch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from utils import load_encoder_hparams_and_params
66
from model import GPT2
77
import torchsummaryX
8+
from tqdm import tqdm
89

910

1011
if __name__ == "__main__":
@@ -31,3 +32,13 @@
3132

3233
model = GPT2(params, hparams, drop_p=0.1)
3334
torchsummaryX.summary(model, torch.ones(1, len(input_ids), dtype=torch.long))
35+
36+
for _ in tqdm(range(args.n_tokens_to_generate), "generating"):
37+
logits = model(torch.tensor(input_ids).unsqueeze(0))
38+
next_id = torch.argmax(logits[0][-1], dim=-1)
39+
input_ids.append(next_id.item())
40+
print("Input text:\n", input_text)
41+
print(
42+
"Generated:\n",
43+
encoder.decode(input_ids[len(input_ids) - args.n_tokens_to_generate :]),
44+
)

0 commit comments

Comments
 (0)