support splits in convert.py

christianazinn committed Apr 27, 2024
1 parent 928e0b7 commit 874c341
Showing 1 changed file with 74 additions and 4 deletions.

convert.py
@@ -44,9 +44,16 @@
 
 DEFAULT_CONCURRENCY = 8
 
+DEFAULT_SPLIT_TENSORS = 128
+
 ADDED_TOKENS_FILE = 'added_tokens.json'
 FAST_TOKENIZER_FILE = 'tokenizer.json'
 
+LLM_KV_SPLIT_NO = "split.no"
+LLM_KV_SPLIT_COUNT = "split.count"
+LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
+SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
+
 #
 # data types
 #
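[Editor's note: SHARD_NAME_FORMAT zero-pads shard indices to five digits, the same pattern used by llama.cpp's gguf-split tool. A quick illustrative check, not part of the commit (the stem is made up):

SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
print(SHARD_NAME_FORMAT.format("llama-7b-f16", 2, 3))
# -> llama-7b-f16-00002-of-00003.gguf
]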
@@ -1235,6 +1242,49 @@ def write_all(
 
         of.close()
 
+    @staticmethod
+    def write_split(
+        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
+        total_tensors: int, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+        pad_vocab: bool = False, tensors_per_shard: int = DEFAULT_SPLIT_TENSORS, small_first_shard: bool = True,
+    ) -> None:
+        check_vocab_size(params, vocab, pad_vocab=pad_vocab)
+
+        model_list = list(model.items())
+        total_shards = math.ceil(total_tensors / tensors_per_shard) + small_first_shard
+        shard_files = [fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + 1, total_shards)) for i in range(total_shards)]
+
+        for i, shard in enumerate(shard_files):
+            of = OutputFile(shard, endianess=endianess)
+
+            if i == 0:
+                of.add_meta_arch(params)
+                if isinstance(vocab, Vocab):
+                    of.add_meta_vocab(vocab)
+                    of.add_meta_special_vocab(svocab)
+                else:  # NoVocab
+                    of.gguf.add_tokenizer_model(vocab.tokenizer_model)
+
+            of.gguf.add_uint16(LLM_KV_SPLIT_NO, i)
+            of.gguf.add_uint16(LLM_KV_SPLIT_COUNT, total_shards)
+            of.gguf.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, total_tensors)
+
+            # have the option to write a first shard with only the metadata
+            if small_first_shard and i == 0:
+                of.write_meta()
+                of.close()
+                continue
+
+            stop = min((i + 1 - small_first_shard) * tensors_per_shard, total_tensors)
+            shard_models = model_list[(i - small_first_shard) * tensors_per_shard:stop]
+            for name, lazy_tensor in shard_models:
+                of.add_tensor_info(name, lazy_tensor)
+
+            of.write_meta()
+            of.write_tensor_info()
+            of.write_tensor_data(ftype, dict(shard_models), concurrency)
+            of.close()
+
 
 def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
     wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type
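[Editor's note: the shard arithmetic in write_split can be sanity-checked in isolation. A minimal sketch, not part of the commit and with made-up counts, showing how a metadata-only first shard adds one shard and shifts every tensor slice back by one:

import math

def shard_slices(total_tensors: int, tensors_per_shard: int, small_first_shard: bool):
    # mirrors write_split: small_first_shard (a bool, so 0 or 1 in arithmetic)
    # adds one extra shard and offsets the slice indices of the rest
    total_shards = math.ceil(total_tensors / tensors_per_shard) + small_first_shard
    for i in range(total_shards):
        if small_first_shard and i == 0:
            yield (0, 0)  # metadata-only shard carries no tensors
            continue
        start = (i - small_first_shard) * tensors_per_shard
        stop = min((i + 1 - small_first_shard) * tensors_per_shard, total_tensors)
        yield (start, stop)

print(list(shard_slices(300, 128, True)))
# -> [(0, 0), (0, 128), (128, 256), (256, 300)]
]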
@@ -1473,6 +1523,9 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine")
     parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
     parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
+    parser.add_argument("--split", action="store_true", help="split the converted model into multiple files")
+    parser.add_argument("--split-max-tensors", type=int, help=f"maximum number of tensors per file when splitting (default: {DEFAULT_SPLIT_TENSORS})", default=DEFAULT_SPLIT_TENSORS)
+    parser.add_argument("--large-first-shard", action="store_true", help="include tensors in the first shard when splitting (default is to only include metadata)")
 
     args = parser.parse_args(args_in)
     if args.no_vocab and args.vocab_only:
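[Editor's note: since main() takes an argv-style list, the new flags can also be exercised programmatically. An illustrative call, with a hypothetical model path and output file:

main(["./models/llama-7b", "--outfile", "llama-7b-f16.gguf", "--split", "--split-max-tensors", "256"])
]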
@@ -1544,11 +1597,28 @@ def main(args_in: list[str] | None = None) -> None:
     outfile = args.outfile or default_outfile(model_plus.paths, ftype)
 
     params.ftype = ftype
-    print(f"Writing {outfile}, format {ftype}")
 
-    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
-                         concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
-    print(f"Wrote {outfile}")
+    if args.split:
+        total_tensors = len(model)
+        if total_tensors < args.split_max_tensors:
+
+            print("Model has fewer tensors than the split threshold, not splitting")
+            print(f"Writing {outfile}, format {ftype}")
+            OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
+                                 concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
+        else:
+            print(f"Writing {outfile} as shards, format {ftype}")
+            OutputFile.write_split(outfile, ftype, params, model, vocab, special_vocab, total_tensors,
+                                   concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab,
+                                   tensors_per_shard=args.split_max_tensors, small_first_shard=not args.large_first_shard)
+        print(f"Wrote {outfile}")
+
+    else:
+        print(f"Writing {outfile}, format {ftype}")
+
+        OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
+                             concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
+        print(f"Wrote {outfile}")
 
 
 if __name__ == '__main__':