Skip to content

Commit

Permalink
lora : add support for non-llama models (ggerganov#3333)
Browse files Browse the repository at this point in the history
* lora : add support for non-llama models

ggml-ci

* avoid leaking ggml_context on failure
cleanup

ggml-ci

* lora : allow 1d tensors

* lora : include embd and output layers in size calculation

* fix style
  • Loading branch information
slaren authored and teleprint-me committed Dec 21, 2023
1 parent 5a8674d commit ae561d1
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 105 deletions.
84 changes: 44 additions & 40 deletions convert-lora-to-ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,20 @@

import json
import os
import re
import struct
import sys
from typing import Any, BinaryIO, Sequence

import numpy as np
import torch

NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}

from pathlib import Path
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf

HF_SUBLAYER_TO_GGML = {
"self_attn.q_proj": "attn_q",
"self_attn.k_proj": "attn_k",
"self_attn.v_proj": "attn_v",
"self_attn.o_proj": "attn_output",
"mlp.gate_proj": "ffn_gate",
"mlp.down_proj": "ffn_down",
"mlp.up_proj": "ffn_up",
"input_layernorm": "attn_norm",
"post_attention_layernorm": "ffn_norm",
}


def translate_tensor_name(t: str) -> str:
match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
if match:
nn = match.group(1)
sub_layer = match.group(2)
lora_type = match.group(3)

sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
if sub_layer_renamed is None:
print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
sys.exit(1)

output_string = (
f"blk.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.weight.lora{lora_type}"
)
return output_string
else:
print(f"Error: unrecognized tensor {t}")
sys.exit(1)
NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}


def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
Expand All @@ -61,9 +32,7 @@ def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
fout.write(struct.pack("i", int(params["lora_alpha"])))


def write_tensor_header(
self, name: str, shape: Sequence[int], data_type: np.dtype[Any]
) -> None:
def write_tensor_header(fout: BinaryIO, name: str, shape: Sequence[int], data_type: np.dtype[Any]) -> None:
sname = name.encode("utf-8")
fout.write(
struct.pack(
Expand All @@ -78,18 +47,27 @@ def write_tensor_header(
fout.seek((fout.tell() + 31) & -32)


if len(sys.argv) != 2:
print(f"Usage: python {sys.argv[0]} <path>")
if len(sys.argv) < 2:
print(f"Usage: python {sys.argv[0]} <path> [arch]")
print(
"Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
)
print(f"Arch must be one of {list(gguf.MODEL_ARCH_NAMES.values())} (default: llama)")
sys.exit(1)

input_json = os.path.join(sys.argv[1], "adapter_config.json")
input_model = os.path.join(sys.argv[1], "adapter_model.bin")
output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")

model = torch.load(input_model, map_location="cpu")
arch_name = sys.argv[2] if len(sys.argv) == 3 else "llama"

if arch_name not in gguf.MODEL_ARCH_NAMES.values():
print(f"Error: unsupported architecture {arch_name}")
sys.exit(1)

arch = list(gguf.MODEL_ARCH_NAMES.keys())[list(gguf.MODEL_ARCH_NAMES.values()).index(arch_name)]
name_map = gguf.TensorNameMap(arch, 200) # 200 layers ought to be enough for anyone

with open(input_json, "r") as f:
params = json.load(f)
Expand Down Expand Up @@ -117,6 +95,7 @@ def write_tensor_header(

write_file_header(fout, params)
for k, v in model.items():
orig_k = k
if k.endswith(".default.weight"):
k = k.replace(".default.weight", ".weight")
if k in ["llama_proj.weight", "llama_proj.bias"]:
Expand All @@ -129,7 +108,32 @@ def write_tensor_header(
v = v.float()

t = v.detach().numpy()
tname = translate_tensor_name(k)

prefix = "base_model.model."
if k.startswith(prefix):
k = k[len(prefix) :]

lora_suffixes = (".lora_A.weight", ".lora_B.weight")
if k.endswith(lora_suffixes):
suffix = k[-len(lora_suffixes[0]):]
k = k[: -len(lora_suffixes[0])]
else:
print(f"Error: unrecognized tensor name {orig_k}")
sys.exit(1)

tname = name_map.get_name(k)
if tname is None:
print(f"Error: could not map tensor name {orig_k}")
print(" Note: the arch parameter must be specified if the model is not llama")
sys.exit(1)

if suffix == ".lora_A.weight":
tname += ".weight.loraA"
elif suffix == ".lora_B.weight":
tname += ".weight.loraB"
else:
assert False

print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
write_tensor_header(fout, tname, t.shape, t.dtype)
t.tofile(fout)
Expand Down
Loading

0 comments on commit ae561d1

Please sign in to comment.