Replit + MPT #145

Merged · 26 commits · May 17, 2023
Changes from 21 commits
Commits
418d844  Add replit model (lukasmoellerch, May 9, 2023)
c08ab88  Add unigram tokenization support (lukasmoellerch, May 10, 2023)
77e6c87  Remove debug log (lukasmoellerch, May 10, 2023)
c880ab6  Port alibi attn bias fix (lukasmoellerch, May 10, 2023)
1eb5cda  Remove torch input (lukasmoellerch, May 10, 2023)
0ea3532  Fix hardcoded path (lukasmoellerch, May 10, 2023)
9bea19d  Remove unsupported hyperparams (lukasmoellerch, May 10, 2023)
bf237cb  Add mpt (lukasmoellerch, May 10, 2023)
353873d  Add replit quantization script (lukasmoellerch, May 10, 2023)
ab0bc67  Remove debug print (lukasmoellerch, May 10, 2023)
b43281a  Add quantization support to mpt (lukasmoellerch, May 10, 2023)
337630d  Reformat (lukasmoellerch, May 13, 2023)
45f59c6  Remove trailing return type (lukasmoellerch, May 13, 2023)
0963049  Implement stylistic changes (lukasmoellerch, May 13, 2023)
011d2b1  use f16 in k/v memory calculations for replit/mpt (lukasmoellerch, May 13, 2023)
478fde7  Update context size calculation (lukasmoellerch, May 14, 2023)
0f64ed9  Add clip_qkv and alibi_bias_max support (lukasmoellerch, May 14, 2023)
3e7e3e0  fix clamping implementation, remove implicit conversions (lukasmoellerch, May 14, 2023)
bf2cc21  Fix qkv if condition (lukasmoellerch, May 14, 2023)
0acd380  Merge branch 'master' into replit (lukasmoellerch, May 14, 2023)
ff8de0e  Fix replit context size calculation (lukasmoellerch, May 14, 2023)
26c9e91  Potentially fix gcc compilation error (lukasmoellerch, May 14, 2023)
018d973  Merge branch 'master' into replit (lukasmoellerch, May 14, 2023)
43c3245  Fix warning (lukasmoellerch, May 14, 2023)
409da4f  Adjust object overhead (lukasmoellerch, May 14, 2023)
d404de8  Remove dead code (lukasmoellerch, May 14, 2023)
2 changes: 2 additions & 0 deletions examples/CMakeLists.txt
@@ -23,4 +23,6 @@ add_subdirectory(whisper)
 add_subdirectory(mnist)
 add_subdirectory(gpt-neox)
 add_subdirectory(dolly-v2)
+add_subdirectory(replit)
+add_subdirectory(mpt)
 add_subdirectory(starcoder)
13 changes: 13 additions & 0 deletions examples/mpt/CMakeLists.txt
@@ -0,0 +1,13 @@
#
# mpt

set(TEST_TARGET mpt)
add_executable(${TEST_TARGET} main.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)

#
# mpt-quantize

set(TEST_TARGET mpt-quantize)
add_executable(${TEST_TARGET} quantize.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
111 changes: 111 additions & 0 deletions examples/mpt/convert-h5-to-ggml.py
@@ -0,0 +1,111 @@
import sys
import struct
import json
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

if len(sys.argv) < 2:
    print("Usage: convert-h5-to-ggml.py dir-model [ftype]\n")
    print("  ftype == 0 -> float32")
    print("  ftype == 1 -> float16")
    sys.exit(1)


# output in the same directory as the model
dir_model = sys.argv[1]
fname_out = sys.argv[1] + "/ggml-model.bin"


with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

# possible data types
# ftype == 0 -> float32
# ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
    ftype = int(sys.argv[2])
    if ftype < 0 or ftype > 1:
        print("Invalid ftype: " + str(ftype))
        sys.exit(1)
    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"


tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    dir_model, low_cpu_mem_usage=True, trust_remote_code=True
)
# print (model)

# print(tokenizer.encode('I believe the meaning of life is'))

list_vars = model.state_dict()
for name in list_vars.keys():
    print(name, list_vars[name].shape, list_vars[name].dtype)

fout = open(fname_out, "wb")

print(hparams)

fout.write(struct.pack("i", 0x67676D6C)) # magic: ggml in hex
fout.write(struct.pack("i", hparams["d_model"]))
fout.write(struct.pack("i", hparams["max_seq_len"]))
fout.write(struct.pack("i", hparams["n_heads"]))
fout.write(struct.pack("i", hparams["n_layers"]))
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("f", hparams["attn_config"]["alibi_bias_max"]))
fout.write(struct.pack("f", hparams["attn_config"]["clip_qkv"] or 0.0))
fout.write(struct.pack("i", ftype))


# TODO: temporary hack to not deal with implementing the tokenizer
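# (The leading '.' anchors the decode so that pieces beginning with
# whitespace are not stripped; slicing off the first byte then recovers
# the raw bytes of token i.)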
dot_token = tokenizer.encode(".")[0]
for i in range(hparams["vocab_size"]):
    text = tokenizer.decode([dot_token, i]).encode("utf-8")
    # remove the first byte (it's always '.')
    text = text[1:]
    fout.write(struct.pack("i", len(text)))
    fout.write(text)

for name in list_vars.keys():
    data = list_vars[name].squeeze().numpy()
    print("Processing variable: " + name + " with shape: ", data.shape)

    n_dims = len(data.shape)

    # ftype == 0 -> float32, ftype == 1 -> float16
    ftype_cur = 0
    if ftype != 0:
        if name[-7:] == ".weight" and n_dims == 2:
            print("  Converting to float16")
            data = data.astype(np.float16)
            ftype_cur = 1
        else:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0
    else:
        if data.dtype != np.float32:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0

    # header: n_dims, name length, ftype, then the dims and the name
    name_bytes = name.encode("utf-8")
    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype_cur))
    for i in range(n_dims):
        # dims are written in reverse order (ggml convention)
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(name_bytes)

    # data
    data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")