added conversion script and example #1

Open · wants to merge 2 commits into base: master
48 changes: 48 additions & 0 deletions conversion/README.md
@@ -0,0 +1,48 @@
# Conversion

The following example demonstrates how to convert a GPTQ model from the HF hub to Marlin format.

### Install

In addition to Marlin and PyTorch, install the following:

```bash
pip install -U transformers accelerate auto-gptq optimum
```

### Convert GPTQ Model to Marlin Format

The following command converts the model from GPTQ to Marlin format. Note that the source GPTQ model must have been quantized with:
- `bits=4`
- `sym=true`
- `group_size=128`
- `desc_act=false`

(A quick way to check these fields up front is shown after the command below.)

```bash
python3 convert.py --model-id "TheBloke/Llama-2-7B-Chat-GPTQ" --save-path "./marlin-model" --do-generation
```
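
To check those fields before converting, one option is to read the quantization config straight from the Hub. This is a minimal sketch; the field names assume the GPTQ `quantization_config` that gets written into `config.json` (the same layout `load.py` relies on), so treat them as assumptions rather than a guaranteed schema.

```python
from transformers import AutoConfig

# Sketch of a pre-check: compare the hub checkpoint's quantization config
# against what convert.py requires. Field names are assumed from the GPTQ
# quantization_config layout used by auto-gptq/transformers checkpoints.
qcfg = AutoConfig.from_pretrained("TheBloke/Llama-2-7B-Chat-GPTQ").quantization_config
expected = {"quant_method": "gptq", "bits": 4, "group_size": 128, "sym": True, "desc_act": False}
for key, want in expected.items():
    print(f"{key}: {qcfg.get(key)} (expected {want})")
```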

### Load Marlin Model

The following loads the Marlin model from disk.

```python
from load import load_model
from transformers import AutoTokenizer

# Load model from disk.
model_path = "./marlin-model"
model = load_model(model_path).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_path)


# Run inference to confirm it is working.
inputs = tokenizer("My favorite song is", return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
print(tokenizer.batch_decode(outputs)[0])
```

Output:
```bash
<s> My favorite song is "Bohemian Rhapsody" by Queen. I love the operatic vocals, the guitar solo, and the way the song builds from a slow ballad to a full-on rock anthem. I've been listening to it
```
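
After conversion, the saved checkpoint should advertise the updated quantization config that `convert.py` writes (`quant_method: "marlin"`). The following is a small sanity-check sketch, assuming the default `config.json` layout produced by `save_pretrained`:

```python
import json
import os

# Sketch: read the converted checkpoint's config.json and print the
# quantization_config that convert.py wrote when saving.
with open(os.path.join("./marlin-model", "config.json")) as f:
    config = json.load(f)
print(config.get("quantization_config"))
# Expected something like: {"group_size": 128, "quant_method": "marlin"}
```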
161 changes: 161 additions & 0 deletions conversion/convert.py
@@ -0,0 +1,161 @@
import torch, argparse, copy
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
from marlin import Layer as MarlinLayer
import gc

parser = argparse.ArgumentParser()
parser.add_argument("--model-id", type=str)
parser.add_argument("--save-path", type=str)
parser.add_argument("--do-generation", action="store_true")

def _validate_compatibility(model):
    if not hasattr(model.config, "quantization_config"):
        raise ValueError("Must be a quantized model to convert to Marlin Format")
    quantization_config = model.config.quantization_config
    if quantization_config.quant_method != "gptq":
        raise ValueError(f"Only GPTQ models can be converted to Marlin format. You passed a model with quant_method={quantization_config.quant_method}")
    if quantization_config.bits != 4:
        raise ValueError(f"Only 4 bit quantized models can be converted to Marlin format. You passed a model with bits={quantization_config.bits}")
    if quantization_config.group_size != 128:
        raise ValueError(f"Only group size 128 models can be converted to Marlin format. You passed a model with group_size={quantization_config.group_size}")
    if not quantization_config.sym:
        raise ValueError(f"Only models with symmetric quantization can be converted to Marlin Format. You passed a model with sym={quantization_config.sym}")
    if quantization_config.desc_act:
        raise ValueError(f"Models with act order quantization cannot be converted to Marlin Format. You passed a model with desc_act={quantization_config.desc_act}")

@torch.no_grad()
def unpack_4bit_to_32bit_signed(qweight, qzeros):
    # Unpack 4-bit values and interpret them as signed integers
    unpacked_weights = torch.zeros((qweight.shape[0]*8, qweight.shape[1]), dtype=torch.int8, device=qweight.device, requires_grad=False)
    unpacked_zeros = torch.zeros((qzeros.shape[0], qzeros.shape[1]*8), dtype=torch.int8, device=qzeros.device, requires_grad=False)

    for row in range(unpacked_weights.shape[0]):
        i = row % 8
        unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF

    for col in range(unpacked_zeros.shape[1]):
        i = col % 8
        unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF

    # auto-gptq stores zero points minus one, so add the offset back.
    return unpacked_weights, unpacked_zeros + 1

@torch.no_grad()
def dequantize_weight(layer):
    qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
    group_size = unpacked_qweight.shape[0] // scales.shape[0]
    scales = scales.repeat_interleave(group_size, dim=0)
    unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales

    return unpacked_qweight.T

@torch.no_grad()
def convert_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Converting Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module).to(torch.float16)
        linear_module = torch.nn.Linear(
            in_features=dequantized_weight.shape[1],
            out_features=dequantized_weight.shape[0],
            bias=False,
            dtype=torch.float16,
            device="cuda")
        linear_module.weight.data.copy_(dequantized_weight)

        # Create new linear method and copy to model.
        new_module = MarlinLayer(
            infeatures=linear_module.in_features,
            outfeatures=linear_module.out_features,
            groupsize=model.config.quantization_config.group_size)
        new_module.pack(linear_module, scales=copy.deepcopy(module.scales.data.t()))

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del dequantized_weight, module
        torch.cuda.empty_cache()
        gc.collect()

    return model

@torch.no_grad()
def dequantize_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Dequantizing Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module)
        dequantized_weight_cpu = dequantized_weight.to("cpu")

        # Create new linear method and copy to model.
        new_module = torch.nn.Linear(
            in_features=dequantized_weight_cpu.shape[1],
            out_features=dequantized_weight_cpu.shape[0],
            bias=False,
            dtype=torch.float16)
        new_module.weight.data.copy_(dequantized_weight_cpu)
        new_module.scales = torch.nn.Parameter(copy.deepcopy(module.scales.data))

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del dequantized_weight, dequantized_weight_cpu, module
        torch.cuda.empty_cache()

    return model

if __name__ == "__main__":
    args = parser.parse_args()
    model_id = args.model_id
    save_path = args.save_path
    do_generation = args.do_generation

    print("Loading gptq model...")
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Validate that this model is compatible with Marlin.
    print("Validating compatibility...")
    _validate_compatibility(model)

    # Convert the model to Marlin format.
    print("Converting model...")
    model = convert_model(model).to("cpu")

    # Save after updating quantization config.
    print("Saving marlin model...")
    model.config.quantization_config = {
        "group_size": model.config.quantization_config.group_size,
        "quant_method": "marlin"
    }
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    if do_generation:
        print("Generating sample text...")
        model.to("cuda")
        prompt = "My favorite song is"
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
        print(tokenizer.batch_decode(outputs)[0])
172 changes: 172 additions & 0 deletions conversion/load.py
@@ -0,0 +1,172 @@
import torch
import numpy as np
from huggingface_hub import snapshot_download
from safetensors.torch import safe_open
from typing import Optional, Tuple, List, Iterator
import os, filelock, json, glob
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM, AutoConfig
import marlin

# Adapted from https://github.com/vllm-project/vllm/blob/14cc317ba48229d93ee2417822d96ccb8db56abe/vllm/model_executor/weight_utils.py#L191

def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
    return lock

def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    # Download model weights from huggingface.
    is_local = os.path.isdir(model_name_or_path)
    use_safetensors = False
    # Some quantized models use .pt files for storing the weights.
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    if fall_back_to_pt:
        allow_patterns += ["*.pt"]

    if not is_local:
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path
    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
        if len(hf_weights_files) > 0:
            if pattern == "*.safetensors":
                use_safetensors = True
            break
    if not use_safetensors:
        # Exclude files that are not needed for inference.
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors

def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
    fall_back_to_pt: Optional[bool] = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        load_format=load_format,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)

    if load_format == "npcache":
        # Currently np_cache only support *.bin checkpoints
        assert use_safetensors is False

        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        # Use file lock to prevent multiple processes from
        # dumping the same model weights to numpy at the same time.
        with get_lock(model_name_or_path, cache_dir):
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_weights_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():  # noqa: SIM118
                    param = f.get_tensor(name)
                    yield name, param
    else:
        for bin_file in hf_weights_files:
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()

@torch.no_grad()
def load_model(model_path):
    with init_empty_weights():
        config = AutoConfig.from_pretrained(model_path)

        if not hasattr(config, "quantization_config"):
            raise ValueError("Must be a Marlin quantized model, but your config has no quantization config.")
        if "quant_method" not in config.quantization_config:
            raise ValueError("Must be a Marlin quantized model, but your quantization config has no quant_method.")
        if config.quantization_config["quant_method"] != "marlin":
            raise ValueError(f"Must be a Marlin model, but you passed a model with quant_method = {config.quantization_config['quant_method']}")

        model = AutoModelForCausalLM.from_config(config)
        marlin.replace_linear(
            model.model,
            groupsize=config.quantization_config["group_size"]
        )

    module_dict = dict(model.named_modules())
    for name, loaded_weight in hf_model_weights_iterator(model_path):
        module_name = ".".join(name.split(".")[:-1])
        param_name = name[len(module_name) + 1:]
        module = module_dict[module_name]

        if not hasattr(module, param_name):
            raise ValueError("Key mismatch.")

        setattr(module, param_name, torch.nn.Parameter(loaded_weight, requires_grad=False))

    return model