Skip to content

Commit

Permalink
py : clean up the code
Browse files Browse the repository at this point in the history
- use f-strings where possible
- drop the explicit "utf-8" first argument of str.encode()/bytes.decode() calls, since "utf-8" is already the default encoding
  • Loading branch information
prusnak committed Mar 31, 2023
1 parent 9733104 commit cbef542
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 29 deletions.
16 changes: 8 additions & 8 deletions convert-ggml-to-pth.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def read_tokens(fin, vocab_size):
text_len = struct.unpack("i", fin.read(4))[0]
text_bytes = fin.read(text_len)
try:
text = text_bytes.decode("utf-8")
text = text_bytes.decode()
except UnicodeDecodeError:
text = text_bytes.decode("utf-8", "replace")
text = text_bytes.decode(errors="replace")
score = struct.unpack("f", fin.read(4))[0]
tokens.append((text, score))
return tokens
Expand Down Expand Up @@ -82,7 +82,7 @@ def read_variables(fin):

shape = tuple(struct.unpack("i" * n_dims, fin.read(4 * n_dims)))
shape = shape[::-1]
name = fin.read(name_length).decode("utf-8")
name = fin.read(name_length).decode()

# ensure tensor data is aligned
tensor_data_offset = fin.tell()
Expand Down Expand Up @@ -199,19 +199,19 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops
device = torch.device("cpu")
llama = llama.to(device)

ctx = """You are AI.
ctx = """You are AI.
This is a dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, respectful, direct, concise, should try to protect User's privacy, and knows its own limits. Also, AI must answer User and AI cannot stop the conversation by itself.
User: Hello, AI.
AI: Hello! How can I assist you today?
"""
print(ctx.rstrip("\n"))
while True:
print("-" * 60)
prompt = input(f"User: ")
prompt = input("User: ")
if ctx != "":
ctx = ctx + "User: " + prompt + "\n"
ctx = f"{ctx}User: {prompt}\n"
else:
ctx = prompt + "\nAI:"
ctx = f"{prompt}\nAI:"

ctx = (ctx[-1920:]) if len(ctx) >= 2048 else ctx

Expand All @@ -236,7 +236,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops
)
s = generation_output.sequences[0]
decoded = tokenizer.decode(s)
ctx = decoded + "\n"
ctx = f"{decoded}\n"


def main():
Expand Down
6 changes: 3 additions & 3 deletions convert-gpt4all-to-ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def write_header(f_out, header):
def write_tokens(fout, tokenizer):
for i in range(tokenizer.vocab_size()):
if tokenizer.is_unknown(i):
text = " \u2047 ".encode("utf-8")
text = " \u2047 ".encode()
elif tokenizer.is_control(i):
text = b""
elif tokenizer.is_byte(i):
Expand All @@ -60,13 +60,13 @@ def write_tokens(fout, tokenizer):
byte_value = int(piece[3:-1], 16)
text = struct.pack("B", byte_value)
else:
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode()
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", tokenizer.get_score(i)))

# TODO: GPT4All - add extra <pad> token
text = "<pad>".encode("utf-8")
text = "<pad>".encode()
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", 0.0))
Expand Down
14 changes: 7 additions & 7 deletions convert-gptq-to-ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# This loop unchanged from convert-pth-to-ggml.py:
for i in range(tokenizer.vocab_size()):
if tokenizer.is_unknown(i):
text = " \u2047 ".encode("utf-8")
text = " \u2047 ".encode()
elif tokenizer.is_control(i):
text = b""
elif tokenizer.is_byte(i):
Expand All @@ -61,13 +61,13 @@
byte_value = int(piece[3:-1], 16)
text = struct.pack("B", byte_value)
else:
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode()
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", tokenizer.get_score(i)))

def write_header(shape, dst_name, ftype_cur):
sname = dst_name.encode('utf-8')
sname = dst_name.encode()
fout.write(struct.pack("iii", len(shape), len(sname), ftype_cur))
fout.write(struct.pack("i" * len(shape), *shape[::-1]))
fout.write(sname)
Expand All @@ -80,7 +80,7 @@ def write_header(shape, dst_name, ftype_cur):
def convert_non_q4(src_name, dst_name):
v = model[src_name]
shape = v.shape
print("Processing non-Q4 variable: " + src_name + " with shape: ", shape, " and type: ", v.dtype)
print(f"Processing non-Q4 variable: {src_name} with shape: {shape} and type: {v.dtype}")
if len(shape) == 1:
print(" Converting to float32")
v = v.to(torch.float32)
Expand All @@ -105,7 +105,7 @@ def convert_q4(src_name, dst_name, permute=False):
# Each int32 item is actually 8 int4 items packed together, and it's transposed.
shape = (qweight.shape[0], qweight.shape[1] * 8)

print("Processing Q4 variable: " + src_name + " with shape: ", shape)
print(f"Processing Q4 variable: {src_name} with shape: {shape}")

# The output format has the int4 weights in groups of 32 rather than 8.
# It looks like this:
Expand Down Expand Up @@ -168,5 +168,5 @@ def convert_q4(src_name, dst_name, permute=False):

fout.close()

print("Done. Output file: " + fname_out)
print("")
print(f"Done. Output file: {fname_out}")
print()
6 changes: 3 additions & 3 deletions convert-pth-to-ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def write_header(fout, hparams, ftype):
def write_tokens(fout, tokenizer):
for i in range(tokenizer.vocab_size()):
if tokenizer.is_unknown(i):
text = " \u2047 ".encode("utf-8")
text = " \u2047 ".encode()
elif tokenizer.is_control(i):
text = b""
elif tokenizer.is_byte(i):
Expand All @@ -131,7 +131,7 @@ def write_tokens(fout, tokenizer):
byte_value = int(piece[3:-1], 16)
text = struct.pack("B", byte_value)
else:
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode()
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", tokenizer.get_score(i)))
Expand Down Expand Up @@ -191,7 +191,7 @@ def process_and_write_variables(fout, model, ftype, part_id, n_parts):
fullshape = list(partshape)
if n_dims > 1:
fullshape[split_dim] *= n_parts
sname = name.encode('utf-8')
sname = name.encode()
fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
for dim in reversed(fullshape):
fout.write(struct.pack("i", dim))
Expand Down
4 changes: 2 additions & 2 deletions convert-unversioned-ggml-to-ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def write_header(f_out, header):
def write_tokens(fout, tokenizer):
for i in range(tokenizer.vocab_size()):
if tokenizer.is_unknown(i):
text = " \u2047 ".encode("utf-8")
text = " \u2047 ".encode()
elif tokenizer.is_control(i):
text = b""
elif tokenizer.is_byte(i):
Expand All @@ -55,7 +55,7 @@ def write_tokens(fout, tokenizer):
byte_value = int(piece[3:-1], 16)
text = struct.pack("B", byte_value)
else:
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode()
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", tokenizer.get_score(i)))
Expand Down
10 changes: 4 additions & 6 deletions migrate-ggml-2023-03-30-pr613.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,21 +272,19 @@ def main():
tokens = read_tokens(fin, hparams)

if hparams['magic'] == 0x67676a74: # ggjt
print("%s: input ggml has already been converted to 'ggjt' magic\n" %
(args.fin_path))
print(f"{args.fin_path}: input ggml has already been converted to 'ggjt' magic\n")
sys.exit(1)

if hparams['magic'] != 0x67676d66: # ggmf
print("%s: input ggml file doesn't have expected 'ggmf' magic: %#x\n" %
(args.fin_path, hparams['magic']))
print(f"{args.fin_path}: input ggml file doesn't have expected 'ggmf' magic: {hparams['magic']:#x}\n")
sys.exit(1)

hparams['magic'] = 0x67676a74 # ggjt

# count number of multipart files by convention
n_parts = 1
while True:
if os.path.exists("%s.%d" % (args.fin_path, n_parts)):
if os.path.exists(f"{args.fin_path}.{n_parts}"):
n_parts += 1
else:
break
Expand All @@ -302,7 +300,7 @@ def main():
print(f"Processing part {part_id+1} of {n_parts}\n")
fin_path = args.fin_path
if part_id > 0:
fin_path += ".%d" % (part_id)
fin_path += f".{part_id}"
with open(fin_path, "rb") as fin:
read_tokens(fin, read_hparams(fin))
copy_tensors(fin, fout, part_id, n_parts)
Expand Down

0 comments on commit cbef542

Please sign in to comment.