Skip to content

Commit da47671

Browse files
committed
Load encoder and hparams, then encode input prompt
1 parent ad7ec53 commit da47671

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

pytorch/gpt_pytorch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import torch
22
import argparse
3+
import sys
4+
import os
5+
from utils import load_encoder_hparams
6+
37

48
if __name__ == "__main__":
59
parser = argparse.ArgumentParser()
@@ -17,3 +21,13 @@
1721

1822
state_dict = torch.load(args.model_path)
1923
print(f"state_dict: {len(state_dict.keys())} params")
24+
25+
model_size = "124M"
26+
models_dir = "models"
27+
encoder, hparams = load_encoder_hparams(model_size, models_dir)
28+
print("hparams:", hparams)
29+
30+
print("prompt:", args.prompt)
31+
input_ids = encoder.encode(args.prompt)
32+
input_text = encoder.decode(input_ids)
33+
print("input_ids:", input_ids)

pytorch/utils.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import os
2+
import json
3+
import sys
4+
import requests
5+
from tqdm import tqdm
6+
7+
# import picoGPT
8+
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
9+
sys.path.append(os.path.join(os.path.dirname(__file__), "../picoGPT"))
10+
from encoder import get_encoder
11+
12+
13+
# Copy from picoGPT because picoGPT/utils.py import tensorflow
14+
def download_gpt2_files(model_size, model_dir):
    """Download the GPT-2 checkpoint and tokenizer files for *model_size*.

    Streams each required file from OpenAI's public blob storage into
    ``model_dir``, showing a tqdm progress bar per file.

    Args:
        model_size: One of "124M", "355M", "774M", "1558M".
        model_dir: Destination directory (must already exist).

    Raises:
        requests.HTTPError: If any download returns a non-2xx status.
    """
    assert model_size in ["124M", "355M", "774M", "1558M"]
    for filename in [
        "checkpoint",
        "encoder.json",
        "hparams.json",
        "model.ckpt.data-00000-of-00001",
        "model.ckpt.index",
        "model.ckpt.meta",
        "vocab.bpe",
    ]:
        url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
        # BUG FIX: the request URL must include the file being fetched;
        # without {filename} every iteration requested the same bad path.
        r = requests.get(f"{url}/{model_size}/{filename}", stream=True)
        r.raise_for_status()

        with open(os.path.join(model_dir, filename), "wb") as f:
            file_size = int(r.headers["content-length"])
            # 1k chunk_size, since Ethernet packet size is around 1500 bytes
            chunk_size = 1000
            with tqdm(
                ncols=100,
                desc="Fetching " + filename,
                total=file_size,
                unit_scale=True,
                unit="b",
            ) as pbar:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    # BUG FIX: the final chunk is usually shorter than
                    # chunk_size; advance by the bytes actually written so
                    # the bar ends exactly at file_size.
                    pbar.update(len(chunk))
43+
44+
45+
def load_encoder_hparams(model_size, models_dir):
    """Load the GPT-2 BPE encoder and hyperparameters for *model_size*.

    Downloads the model files first when the per-size model directory is
    not present on disk.

    Args:
        model_size: One of "124M", "355M", "774M", "1558M".
        models_dir: Root directory containing one subdirectory per size.

    Returns:
        tuple: ``(encoder, hparams)`` — the BPE encoder object from
        ``get_encoder`` and the dict parsed from ``hparams.json``.
    """
    assert model_size in ["124M", "355M", "774M", "1558M"]

    model_dir = os.path.join(models_dir, model_size)
    # BUG FIX: `if not model_dir` tested a non-empty string (always truthy),
    # so the download branch was dead code. Check the directory on disk.
    if not os.path.isdir(model_dir):  # download files if necessary
        os.makedirs(model_dir, exist_ok=True)
        download_gpt2_files(model_size, model_dir)

    encoder = get_encoder(model_size, models_dir)
    # Context manager closes the file handle deterministically
    # (json.load(open(...)) leaked it).
    with open(os.path.join(model_dir, "hparams.json")) as f:
        hparams = json.load(f)

    return encoder, hparams

0 commit comments

Comments
 (0)