import os
import json
import sys
import requests
from tqdm import tqdm

# Make the local picoGPT checkout importable so we can reuse its BPE encoder.
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(os.path.join(os.path.dirname(__file__), "../picoGPT"))
from encoder import get_encoder


# Copied from picoGPT, because picoGPT/utils.py imports tensorflow.
def download_gpt2_files(model_size, model_dir):
    assert model_size in ["124M", "355M", "774M", "1558M"]
    # Fetch the TensorFlow checkpoint (weights), the tokenizer files
    # (encoder.json, vocab.bpe) and the hyperparameters (hparams.json)
    # from OpenAI's public bucket.
    for filename in [
        "checkpoint",
        "encoder.json",
        "hparams.json",
        "model.ckpt.data-00000-of-00001",
        "model.ckpt.index",
        "model.ckpt.meta",
        "vocab.bpe",
    ]:
        url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
        r = requests.get(f"{url}/{model_size}/{filename}", stream=True)
        r.raise_for_status()

        with open(os.path.join(model_dir, filename), "wb") as f:
            file_size = int(r.headers["content-length"])
            # 1k chunk_size, since Ethernet packet size is around 1500 bytes
            chunk_size = 1000
            with tqdm(
                ncols=100,
                desc="Fetching " + filename,
                total=file_size,
                unit_scale=True,
                unit="b",
            ) as pbar:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    # count actual bytes; the last chunk may be shorter than chunk_size
                    pbar.update(len(chunk))
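
# Standalone usage (a sketch, not part of the original commit): the target
# directory must already exist, since the function opens files directly in it.
#
#     os.makedirs("models/124M", exist_ok=True)
#     download_gpt2_files("124M", "models/124M")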


def load_encoder_hparams(model_size, models_dir):
    assert model_size in ["124M", "355M", "774M", "1558M"]

    model_dir = os.path.join(models_dir, model_size)
    # Download the model files if they are missing. hparams.json serves as the
    # marker file: os.path.join always returns a non-empty string, so checking
    # `not model_dir` would never trigger the download.
    if not os.path.isfile(os.path.join(model_dir, "hparams.json")):
        os.makedirs(model_dir, exist_ok=True)
        download_gpt2_files(model_size, model_dir)

    encoder = get_encoder(model_size, models_dir)
    with open(os.path.join(model_dir, "hparams.json")) as f:
        hparams = json.load(f)

    return encoder, hparams
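

# A minimal usage sketch (not part of the original commit). It assumes a picoGPT
# checkout next to this repo (for encoder.get_encoder) and uses its
# Encoder.encode/decode API; "models" is an arbitrary download directory.
if __name__ == "__main__":
    encoder, hparams = load_encoder_hparams("124M", "models")
    ids = encoder.encode("Alan Turing theorized that computers would one day become")
    print(ids)                  # BPE token ids, each < hparams["n_vocab"]
    print(encoder.decode(ids))  # round-trips back to the original text
    print(hparams)              # n_vocab, n_ctx, n_embd, n_head, n_layer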