
Commit b37d2b5

Generate tokens works, LLM seems sane!
1 parent b95628b

File tree: 9 files changed, +85 and -199 lines

autoencoder/generate_tokens.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+import torch
+import logging
+import argparse
+from resource_loader import ResourceLoader
+
+def main(gpt_ckpt_dir: str, prompt: str):
+    resourceloader = ResourceLoader(
+        dataset='shakespeare_char',
+        gpt_ckpt_dir=gpt_ckpt_dir,
+        device='cpu',
+        mode="prepare",
+    )
+    enc_fxn, dec_fxn = resourceloader.load_tokenizer()
+    tokens = torch.Tensor([enc_fxn(prompt)]).long()
+    logging.info(tokens)
+    generated = resourceloader.transformer.generate(
+        idx=tokens,
+        max_new_tokens=100,
+    )
+    generated = dec_fxn(generated.squeeze().tolist())
+    print(generated)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--gpt_ckpt_dir', type=str, default='')
+    parser.add_argument('--prompt', type=str, help='Try "def run(" or "oh romeo!"')
+    args = parser.parse_args()
+    main(**vars(args))

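A brief usage sketch, not part of the diff: the checkpoint directory name below is a placeholder, and the direct call assumes the working directory is autoencoder/ so the resource_loader import resolves.

    # Hypothetical invocation, equivalent to:
    #   python generate_tokens.py --gpt_ckpt_dir out-shakespeare-char --prompt "oh romeo!"
    from generate_tokens import main

    main(gpt_ckpt_dir='out-shakespeare-char', prompt='oh romeo!')  # placeholder checkpoint dir
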
autoencoder/resource_loader.py

Lines changed: 1 addition & 5 deletions
@@ -177,9 +177,5 @@ def load_tokenizer(self):
             encode = lambda s: [stoi[c] for c in s]
             decode = lambda l: ''.join([itos[i] for i in l])
         else:
-            # ok let's assume gpt-2 encodings by default
-            print("No meta.pkl found, assuming GPT-2 encodings...")
-            enc = tiktoken.get_encoding("gpt2")
-            encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
-            decode = lambda l: enc.decode(l)
+            raise DeprecationWarning('must load from dataset dir')
         return encode, decode

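An aside, not part of the diff: DeprecationWarning is an Exception subclass, so the raise above makes load_tokenizer fail outright when no dataset tokenizer metadata is found, rather than emitting a warning. A hypothetical caller-side sketch, mirroring the resourceloader name used in generate_tokens.py:

    # Hypothetical handling of the missing-tokenizer case.
    try:
        encode, decode = resourceloader.load_tokenizer()
    except DeprecationWarning as err:
        raise SystemExit(f"tokenizer metadata not found: {err}")
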
transformer/data/openwebtext/prepare.py

Lines changed: 0 additions & 80 deletions
This file was deleted.

transformer/data/openwebtext/readme.md

Lines changed: 0 additions & 15 deletions
This file was deleted.

transformer/data/shakespeare/prepare.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

transformer/data/shakespeare/readme.md

Lines changed: 0 additions & 9 deletions
This file was deleted.

transformer/data/shakespeare_char/prepare.py

Lines changed: 22 additions & 5 deletions
@@ -8,6 +8,7 @@
 import pickle
 import requests
 import numpy as np
+import pandas as pd
 
 # download the tiny shakespeare dataset
 input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
@@ -18,6 +19,22 @@
 
 with open(input_file_path, 'r') as f:
     data = f.read()
+
+# Add in some Python code training data so the model learns both Shakespare and Python
+df = pd.read_parquet(
+    "hf://datasets/matlok/python-text-copilot-training-instruct-ai-research-2024-02-10/schema/train-0022-qwen-agent-qwen_agent.parquet"
+)
+python_code = '\n###\n'.join(df['code'].dropna().astype(str))
+python_code = python_code.encode('ascii', 'ignore').decode()  # there's a few non-ascii characters but I don't want to deal with them
+
+train_split = python_code[:int(len(python_code) * 0.9)]
+val_split = python_code[int(len(python_code) * 0.9):]
+
+# Add the train split to the beginning, and then the val split at the end, so that
+# the code below to create the train/val splits works as expected.
+# In the industry, this is what we call a "insane awful hack".
+data = train_split + data + val_split
+
 print(f"length of dataset in characters: {len(data):,}")
 
 # get all the unique characters that occur in this text
@@ -60,9 +77,9 @@ def decode(l):
 with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f:
     pickle.dump(meta, f)
 
-# length of dataset in characters: 1115394
+# length of dataset in characters: 1,217,175
 # all the unique characters:
-# !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
-# vocab size: 65
-# train has 1003854 tokens
-# val has 111540 tokens
+# !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~
+# vocab size: 96
+# train has 1,095,457 tokens
+# val has 121,718 tokens

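To make the ordering trick concrete, a short sketch, not part of the diff. It assumes the code further down in prepare.py splits `data` roughly 90/10 by character position, which is what the reported train/val token counts suggest:

    # data = python_train + shakespeare + python_val
    # train = data[:0.9 * n]  -> starts with the Python "train" text
    # val   = data[0.9 * n:]  -> ends with the Python "val" text
    # so both resulting splits contain some Python code as well as Shakespeare.
    python_code = "def run():\n    return 1\n"   # stand-in for the parquet text
    shakespeare = "O Romeo, Romeo! wherefore art thou Romeo?\n"
    train_split = python_code[:int(len(python_code) * 0.9)]
    val_split = python_code[int(len(python_code) * 0.9):]
    data = train_split + shakespeare + val_split
    n = len(data)
    train_data, val_data = data[:int(n * 0.9)], data[int(n * 0.9):]
    assert train_data.startswith(train_split) and val_data.endswith(val_split)
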
transformer/model.py

Lines changed: 7 additions & 6 deletions
@@ -260,7 +260,7 @@ def from_pretrained(cls, model_type, override_args=None):
 
         return model
 
-    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
+    def configure_optimizers(self, weight_decay, learning_rate, betas):
         # start with all of the candidate parameters
         param_dict = {pn: p for pn, p in self.named_parameters()}
         # filter out those that do not require grad
@@ -277,12 +277,13 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
         num_nodecay_params = sum(p.numel() for p in nodecay_params)
         print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+
         # Create AdamW optimizer and use the fused version if it is available
-        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
-        use_fused = fused_available and device_type == 'cuda'
-        extra_args = dict(fused=True) if use_fused else dict()
-        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
-        print(f"using fused AdamW: {use_fused}")
+        # fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+        # use_fused = fused_available and device_type == 'cuda'
+        # extra_args = dict(fused=True) if use_fused else dict()
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=False)
+        # print(f"using fused AdamW: {use_fused}")
 
         return optimizer
 
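Not shown in the diff: any nanoGPT-style training script that calls this method would need the matching signature change. A hypothetical call site, with placeholder hyperparameter values:

    # Hypothetical caller after this commit; the device_type argument is no longer accepted.
    optimizer = model.configure_optimizers(
        weight_decay=1e-1,
        learning_rate=6e-4,
        betas=(0.9, 0.95),
    )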