
Commit 7f551db

new model export: versions 0 (legacy) and 1
1 parent bd18228 commit 7f551db

3 files changed: +245 -53 lines changed


export.py

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
"""
This script has functions and utilities for model export.
Basically, we have a bunch of versions of the model, and we
want to export them to .bin files to be read from and inferenced in C.

Among the "input" versions of PyTorch files/models:
- Official Llama 2 weights released by Meta
- Huggingface weights available on the hub
- llama2.c (this repo) trained models

Among the "output" versions of .bin files:
- v0: Legacy files of the original llama2.c repo (will eventually be DEPRECATED)
- v1-vN: Improved .bin files with a proper header, cache alignment, etc.

This script aspires to provide all of these conversions.
"""
import struct
import argparse
import torch
import numpy as np

from model import ModelArgs, Transformer

# -----------------------------------------------------------------------------
# common utilities

def serialize_fp32(file, tensor):
    """ writes one fp32 tensor to file that is open in wb mode """
    d = tensor.detach().cpu().view(-1).numpy().astype(np.float32)
    b = struct.pack(f'{len(d)}f', *d)
    file.write(b)

def serialize_int8(file, tensor):
    """ writes one int8 tensor to file that is open in wb mode """
    d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
    b = struct.pack(f'{len(d)}b', *d)
    file.write(b)

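# Illustrative sketch (not part of this commit): the read-side counterpart of the
# helpers above, round-tripping one small tensor through an in-memory buffer to
# show the byte layout they produce. The tensor shape is a made-up example.
import io
_buf = io.BytesIO()
_t = torch.randn(4, 8)
serialize_fp32(_buf, _t)
_back = np.frombuffer(_buf.getvalue(), dtype=np.float32).reshape(4, 8)
assert np.allclose(_back, _t.numpy())
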
def quantize_q80(w, group_size):
    """
    takes a tensor and returns the Q8_0 quantized version
    i.e. symmetric quantization into int8, range [-127,127]
    """
    assert w.numel() % group_size == 0
    ori_shape = w.shape
    w = w.float() # convert to float32
    w = w.reshape(-1, group_size)
    # find the max in each group
    wmax = torch.abs(w).max(dim=1).values
    # calculate the scaling factor such that float = quant * scale
    scale = wmax / 127.0
    # scale into range [-127, 127]
    quant = w / scale[:,None]
    # round to nearest integer
    int8val = torch.round(quant).to(torch.int8)
    # dequantize by rescaling
    fp32val = (int8val.float() * scale[:,None]).view(-1)
    fp32valr = fp32val.reshape(-1, group_size)
    # calculate the max error in each group
    err = torch.abs(fp32valr - w).max(dim=1).values
    # find the max error across all groups
    maxerr = err.max().item()
    return int8val, scale, maxerr

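# Illustrative sketch (not part of this commit): one call to quantize_q80 and the
# matching dequantization using the returned per-group scales. The tensor shape
# and group size are made-up examples.
_w = torch.randn(256, 64)
_q, _s, _err = quantize_q80(_w, group_size=64)
_deq = (_q.float() * _s[:, None]).reshape(256, 64)  # reverse of float = quant * scale
print(f"max round-trip error: {_err}")  # bounded by 0.5 * scale, i.e. group max / 254
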
# -----------------------------------------------------------------------------
# legacy

def legacy_export(model, filepath):
    """ Original export of llama2.c bin files, i.e. version v0 """
    out_file = open(filepath, 'wb')

    # first write out the header
    hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
    p = model.params
    n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
    header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
                         n_kv_heads, p.vocab_size, p.max_seq_len)
    out_file.write(header)

    # next write out the embedding weights
    serialize_fp32(out_file, model.tok_embeddings.weight)

    # now all the layers
    # attention weights
    for layer in model.layers:
        serialize_fp32(out_file, layer.attention_norm.weight)
    for layer in model.layers:
        serialize_fp32(out_file, layer.attention.wq.weight)
    for layer in model.layers:
        serialize_fp32(out_file, layer.attention.wk.weight)
    for layer in model.layers:
        serialize_fp32(out_file, layer.attention.wv.weight)
    for layer in model.layers:
        serialize_fp32(out_file, layer.attention.wo.weight)
    # ffn weights
    for layer in model.layers:
        serialize_fp32(out_file, layer.ffn_norm.weight)
    for layer in model.layers:
        serialize_fp32(out_file, layer.feed_forward.w1.weight)
    for layer in model.layers:
        serialize_fp32(out_file, layer.feed_forward.w2.weight)
    for layer in model.layers:
        serialize_fp32(out_file, layer.feed_forward.w3.weight)
    # final rmsnorm
    serialize_fp32(out_file, model.norm.weight)
    # note: no need to write final classifier weights due to weight sharing
    # freqs_cis
    serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
    serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])

    # write to binary file
    out_file.close()
    print(f"wrote {filepath}")

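# Illustrative sketch (not part of this commit): reading back the v0 header that
# legacy_export writes, i.e. seven int32s (28 bytes, native byte order) followed
# directly by the fp32 payload. The filename is a made-up example.
with open("model.bin", "rb") as _f:
    _dim, _hidden_dim, _n_layers, _n_heads, _n_kv_heads, _vocab_size, _max_seq_len = \
        struct.unpack('iiiiiii', _f.read(28))
print(_dim, _hidden_dim, _n_layers, _n_heads, _n_kv_heads, _vocab_size, _max_seq_len)
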
# -----------------------------------------------------------------------------
# new version

def version1_export(model, filepath, group_size=64):
    """
    Export the model weights in Q8_0 into .bin file to be read from C.
    That is:
    - quantize all weights to symmetric int8, in range [-127, 127]
    - all other tensors (the rmsnorm params) are kept and exported in fp32
    - quantization is done in groups of group_size to reduce the effects of any outliers
    """
    version = 1

    # let's first do some validation for this export type
    while model.params.dim % group_size != 0:
        group_size //= 2
        print(f"BACKOFF: reducing group size to {group_size} to fit hidden_dim")
    weights = [
        model.tok_embeddings.weight,
        *[layer.attention.wq.weight for layer in model.layers],
        *[layer.attention.wk.weight for layer in model.layers],
        *[layer.attention.wv.weight for layer in model.layers],
        *[layer.attention.wo.weight for layer in model.layers],
        *[layer.feed_forward.w1.weight for layer in model.layers],
        *[layer.feed_forward.w2.weight for layer in model.layers],
        *[layer.feed_forward.w3.weight for layer in model.layers],
    ]
    for i, w in enumerate(weights):
        assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"

    # write
    out_file = open(filepath, 'wb')
    # first write out the header. the header will be 256 bytes
    nbytes = 0
    # 1) write magic, which will be uint32 of "ak42" in ASCII
    out_file.write(struct.pack('I', 0x616b3432))
    nbytes += 4
    # 2) write version, which will be int
    out_file.write(struct.pack('i', version))
    nbytes += 4
    # 3) write the params, which will be 7 ints
    p = model.params
    hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
    n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
    header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
                         n_kv_heads, p.vocab_size, p.max_seq_len)
    out_file.write(header)
    nbytes += 7*4
    # 4) write some other flags
    shared_classifier = 1 # we do share a classifier, write flag as a byte
    out_file.write(struct.pack('B', shared_classifier))
    nbytes += 1
    out_file.write(struct.pack('i', group_size)) # group size used for quantization
    nbytes += 4
    pad = 256 - nbytes # pad the rest with zeros
    assert pad >= 0
    out_file.write(b'\0' * pad)
    # now that the header is done, let's write out the model

    # first let's write out all the params that we are keeping in fp32: the norms
    for layer in model.layers: # attention norms
        serialize_fp32(out_file, layer.attention_norm.weight)
    for layer in model.layers: # MLP norms
        serialize_fp32(out_file, layer.ffn_norm.weight)
    serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm

    # now let's write out all the params that we are quantizing to Q8_0
    # note we skip classifier weights, which are shared with the embedding
    ew = []
    scales = []
    for i, w in enumerate(weights):
        # quantize this weight
        q, s, err = quantize_q80(w, group_size)
        # save the int8 weights to file
        serialize_int8(out_file, q) # save the tensor in int8
        scales.append(s) # we'll do all the scales after all the qs
        # logging
        ew.append((err, w.shape))
        print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")

    # save the scaling factors in fp32 here
    # this is done to keep all the weights contiguous, making pointer arithmetic easier in C
    for s in scales:
        serialize_fp32(out_file, s)

    # print the highest error across all weights, should be very small, e.g. O(~0.001)
    ew.sort(reverse=True)
    print(f"max quantization group error across all weights: {ew[0][0]}")

    # write to binary file
    out_file.close()
    print(f"wrote {filepath}")

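# Illustrative sketch (not part of this commit): parsing the 256-byte v1 header back
# out of a file written by version1_export. Offsets follow the writes above:
# magic (4) | version (4) | 7 params (28) | shared_classifier (1) | group_size (4),
# then zero padding up to 256 bytes. The filename is a made-up example.
with open("model_q80.bin", "rb") as _f:
    _header = _f.read(256)
_magic, = struct.unpack_from('I', _header, 0)    # expect 0x616b3432, "ak42"
_version, = struct.unpack_from('i', _header, 4)  # expect 1
_dim, _hidden_dim, _n_layers, _n_heads, _n_kv_heads, _vocab_size, _max_seq_len = \
    struct.unpack_from('7i', _header, 8)
_shared_classifier, = struct.unpack_from('B', _header, 36)
_group_size, = struct.unpack_from('i', _header, 37)
assert _magic == 0x616b3432 and _version == 1
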
# -----------------------------------------------------------------------------
# API entrypoint

def model_export(model, filepath, version):
    if version == 0:
        legacy_export(model, filepath)
    elif version == 1:
        version1_export(model, filepath)
    else:
        raise ValueError(f"unknown version {version}")

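# Illustrative sketch (not part of this commit): driving the dispatcher with a tiny,
# randomly initialized model. Assumes ModelArgs provides defaults for any field not
# passed here; the sizes and file names are made-up examples.
_cfg = ModelArgs(dim=64, n_layers=2, n_heads=2, n_kv_heads=2, vocab_size=512, max_seq_len=128)
_tiny = Transformer(_cfg)
model_export(_tiny, "tiny_v0.bin", version=0)   # legacy fp32 layout
model_export(_tiny, "tiny_q80.bin", version=1)  # 256-byte header + Q8_0 weights
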
# -----------------------------------------------------------------------------
# CLI entrypoint

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("filepath", type=str, help="the output filepath")
    parser.add_argument("--checkpoint", default="", type=str, help="model checkpoint, .pt file")
    parser.add_argument("--version", default=0, type=int, help="the version to export with")
    args = parser.parse_args()

    # load the provided model checkpoint
    checkpoint_dict = torch.load(args.checkpoint, map_location='cpu')
    gptconf = ModelArgs(**checkpoint_dict['model_args'])
    model = Transformer(gptconf)
    state_dict = checkpoint_dict['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    # export
    model_export(model, args.filepath, args.version)
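As a usage example (the checkpoint path below is a made-up example), a trained llama2.c checkpoint could be exported to the new quantized format with:

python export.py model_q80.bin --checkpoint out/ckpt.pt --version 1

Leaving --version at its default of 0 writes the legacy v0 layout instead.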

model.py

Lines changed: 0 additions & 52 deletions
@@ -338,55 +338,3 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
             idx = torch.cat((idx, idx_next), dim=1)

         return idx
-
-    def export(self, filepath='model.bin'):
-        """export the model weights in fp32 into .bin file to be read from C"""
-        f = open(filepath, 'wb')
-
-        def serialize(t):
-            d = t.detach().cpu().view(-1).numpy().astype(np.float32)
-            b = struct.pack(f'{len(d)}f', *d)
-            f.write(b)
-
-        # first write out the header
-        hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
-        p = self.params
-        n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
-        header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
-                             n_kv_heads, p.vocab_size, p.max_seq_len)
-        f.write(header)
-
-        # next write out the embedding weights
-        serialize(self.tok_embeddings.weight)
-
-        # now all the layers
-        # attention weights
-        for layer in self.layers:
-            serialize(layer.attention_norm.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wq.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wk.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wv.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wo.weight)
-        # ffn weights
-        for layer in self.layers:
-            serialize(layer.ffn_norm.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w1.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w2.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w3.weight)
-        # final rmsnorm
-        serialize(self.norm.weight)
-        # note: no need to write final classifier weights due to weight sharing
-        # freqs_cis
-        serialize(self.freqs_cos[:p.max_seq_len])
-        serialize(self.freqs_sin[:p.max_seq_len])
-
-        # write to binary file
-        f.close()
-        print(f"wrote {filepath}")

train.py

Lines changed: 2 additions & 1 deletion
@@ -29,6 +29,7 @@
 from torch.nn.parallel import DistributedDataParallel as DDP

 from tinystories import Task
+from export import model_export

 # -----------------------------------------------------------------------------
 # I/O
@@ -287,7 +288,7 @@ def get_lr(it):
                 }
                 print(f"saving checkpoint to {out_dir}")
                 torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
-                raw_model.export(os.path.join(out_dir, "model.bin"))
+                model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0)
     if iter_num == 0 and eval_only:
         break