[FIX] save_quantized() #296

Merged · 5 commits · Jul 25, 2024
99 changes: 86 additions & 13 deletions gptqmodel/models/base.py
@@ -17,7 +17,7 @@
from transformers.modeling_utils import no_init_weights, shard_checkpoint
from transformers.utils.generic import ContextManagers

from ..nn_modules.qlinear.qlinear_qbits import qbits_dtype
from ..nn_modules.qlinear.qlinear_qbits import qbits_dtype, QBitsQuantLinear
from ..quantization import GPTQ, QuantizeConfig
from ..quantization.config import (FORMAT, FORMAT_FIELD_JSON, META_FIELD_QUANTIZER, META_QUANTIZER_GPTQMODEL,
MIN_VERSION_WITH_V2, QUANTIZE_BLACK_LIST, AutoRoundQuantizeConfig)
@@ -79,11 +79,13 @@ def __init__(
quantized: bool,
quantize_config: QuantizeConfig,
qlinear_kernel: nn.Module = None,
load_quantized_model: bool = False,
):
super().__init__()

self.model = model
self._quantized = quantized
self.load_quantized_model = load_quantized_model
self.quantize_config = quantize_config
self.config = self.model.config

@@ -603,7 +605,6 @@ def save_quantized(
# The config, quantize_config and model may be edited in place in save_quantized.
config = copy.deepcopy(self.model.config)
quantize_config = copy.deepcopy(self.quantize_config)
model = self.model

if not self.quantized:
raise ValueError("Save aborted as model is not quantized. Please call `quantize()` first.")
@@ -619,17 +620,20 @@
f"Using 'format = {FORMAT.GPTQ_V2}': the serialized model is only supported by GPTQModel version >= {MIN_VERSION_WITH_V2}."
)

# internal is always gptq v2 but allow users to pass gptq (v1) via config
if quantize_config.format == FORMAT.GPTQ:
# Model qzeros may be edited in place.
# TODO: avoid inplace modification of the weights
model = copy.deepcopy(self.model)
model = convert_gptq_v2_to_v1_format(
model, quantize_config=quantize_config, qlinear_kernel=self.qlinear_kernel
)

# The model saved during bitblas format quantization uses BitblasQuantLinear, which can be used directly.
if not self.load_quantized_model or quantize_config.format == FORMAT.BITBLAS:
model = self.model
# internal is always gptq v2 but allow users to pass gptq (v1) via config
if quantize_config.format == FORMAT.GPTQ:
# Model qzeros may be edited in place.
# TODO: avoid inplace modification of the weights
model = copy.deepcopy(self.model)
model = convert_gptq_v2_to_v1_format(
model, quantize_config=quantize_config, qlinear_kernel=self.qlinear_kernel
)
else:
model = self.get_model_with_quantize(quantize_config)
model.to(CPU)

state_dict = model.state_dict()

if quantize_config.model_file_base_name is None:
@@ -759,6 +763,70 @@ def save_quantized(
quantize_config.model_file_base_name = model_base_name
quantize_config.save_pretrained(save_dir)

def get_model_with_quantize(self, quantize_config):
config = AutoConfig.from_pretrained(
quantize_config.model_name_or_path,
trust_remote_code=True,
)

def skip(*args, **kwargs):
pass

torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
transformers.modeling_utils._init_weights = False
init_contexts = [no_init_weights()]
with ContextManagers(init_contexts):
model = AutoModelForCausalLM.from_config(
config, torch_dtype=torch.float16
)

if self.dynamic_expert_index is not None:
num_experts = getattr(config, self.dynamic_expert_index)
self.layer_modules = get_moe_layer_modules(layer_modules=self.layer_modules,
num_experts=num_experts)

layers = find_layers(model)
ignore_layers = [self.lm_head] + self.base_modules

for name in list(layers.keys()):
# allow loading of quantized lm_head
if quantize_config.lm_head and name == self.lm_head:
continue

if any(name.startswith(ignore_layer) for ignore_layer in ignore_layers) or all(
not name.endswith(ignore_layer) for sublist in self.layer_modules for ignore_layer in sublist
):
# log non-lm-head quantized layers only
if name != self.lm_head:
logger.info(f"The layer {name} is not quantized.")
del layers[name]

make_quant(
model,
layers,
quantize_config.bits,
quantize_config.group_size,
backend=BACKEND.AUTO,
format=quantize_config.format,
desc_act=quantize_config.desc_act,
pack=True,
)
model.tie_weights()

accelerate.load_checkpoint_in_model(
model,
dtype=torch.float16,
# This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
checkpoint=self.checkpoint_file_name,
# device_map=device_map,
# offload_state_dict=True,
# offload_buffers=True,
)
torch.cuda.empty_cache()
return model

def save_pretrained(
self,
save_dir: str,
@@ -1033,6 +1101,7 @@ def from_quantized(
)

quantize_config.model_file_base_name = true_model_basename
quantize_config.runtime_format = quantize_config.format

model_save_name = resolved_archive_file # In case a model is sharded, this would be `model.safetensors.index.json` which may later break.
if verify_hash:
@@ -1059,6 +1128,7 @@ def skip(*args, **kwargs):
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype
)
model.checkpoint_file_name = model_save_name

if cls.dynamic_expert_index is not None:
num_experts = getattr(config, cls.dynamic_expert_index)
@@ -1090,6 +1160,8 @@ def skip(*args, **kwargs):
format=FORMAT.GPTQ_V2,
desc_act=quantize_config.desc_act,
)
if preload_qlinear_kernel == QBitsQuantLinear:
quantize_config.runtime_format = FORMAT.QBITS
model.tie_weights()

# == step3: load checkpoint and dispatch == #
@@ -1152,7 +1224,7 @@ def skip(*args, **kwargs):
qlinear_kernel=preload_qlinear_kernel,
)
load_checkpoint_in_model = True
quantize_config.format = FORMAT.GPTQ_V2
quantize_config.runtime_format = FORMAT.GPTQ_V2

if backend == BACKEND.MARLIN:
if is_sharded:
@@ -1251,6 +1323,7 @@ def skip(*args, **kwargs):
quantized=True,
quantize_config=quantize_config,
qlinear_kernel=qlinear_kernel,
load_quantized_model=True,
)

def __getattr__(self, item):
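Taken together, the base.py changes make from_quantized() mark the wrapper with load_quantized_model=True and remember the resolved checkpoint file, so that save_quantized() can rebuild a packable copy via get_model_with_quantize() instead of serializing the kernel-converted in-memory weights. A minimal round-trip sketch of the scenario this PR fixes, reusing the checkpoint from the new test below (the model ID and backend are illustrative):

import tempfile

from gptqmodel import BACKEND, GPTQModel

# Load an already-quantized checkpoint; from_quantized() now passes
# load_quantized_model=True and stores checkpoint_file_name on the model.
model = GPTQModel.from_quantized(
    "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", backend=BACKEND.AUTO
)

with tempfile.TemporaryDirectory() as tmpdir:
    # save_quantized() repacks from the original checkpoint (unless the
    # checkpoint format is BitBLAS), so the re-saved model reloads correctly.
    model.save_quantized(tmpdir)
    reloaded = GPTQModel.from_quantized(tmpdir, backend=BACKEND.AUTO)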
1 change: 1 addition & 0 deletions gptqmodel/quantization/config.py
@@ -91,6 +91,7 @@ class QuantizeConfig():
# default to gptq v1 format for maximum compat with 3rd party inference libs with minimal loss vs v2
# if you inference with gptqmodel, save to gptq_v2 format for best result
format: FORMAT = field(default=FORMAT.GPTQ)
runtime_format: FORMAT = field(default=FORMAT.GPTQ)

# TODO: remove
model_name_or_path: Optional[str] = field(default=None)
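The new runtime_format field separates the serialized checkpoint format (format) from the representation the loaded kernel actually uses in memory; the bitblas.py and marlin.py changes below now update runtime_format instead of mutating format. A rough illustration, assuming the same checkpoint as above and a Marlin-capable GPU:

from gptqmodel import BACKEND, GPTQModel
from gptqmodel.quantization import FORMAT

model = GPTQModel.from_quantized(
    "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", backend=BACKEND.MARLIN
)

print(model.quantize_config.format)          # serialized format, left untouched by loading
print(model.quantize_config.runtime_format)  # FORMAT.MARLIN after the in-memory conversion
assert model.quantize_config.runtime_format == FORMAT.MARLIN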
2 changes: 1 addition & 1 deletion gptqmodel/utils/bitblas.py
@@ -132,6 +132,6 @@ def convert_to_bitblas(model, model_quantlinear, quant_config: QuantizeConfig, s
gc.collect()

# Set quantization config to be BitBLAS.
quant_config.format = FORMAT.BITBLAS
quant_config.runtime_format = FORMAT.BITBLAS

return model
2 changes: 1 addition & 1 deletion gptqmodel/utils/marlin.py
@@ -199,6 +199,6 @@ def convert_to_marlin(
gc.collect()

# Set quantization config to be Marlin.
quantization_config.format = FORMAT.MARLIN
quantization_config.runtime_format = FORMAT.MARLIN

return model
46 changes: 46 additions & 0 deletions tests/test_save_loaded_quantized_model.py
@@ -0,0 +1,46 @@
import tempfile
import unittest

import torch
from gptqmodel import BACKEND, GPTQModel
from parameterized import parameterized
from transformers import AutoTokenizer

MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"

class TestSave(unittest.TestCase):
@parameterized.expand(
[
(BACKEND.AUTO),
(BACKEND.EXLLAMA_V2),
(BACKEND.EXLLAMA),
(BACKEND.TRITON),
(BACKEND.BITBLAS),
(BACKEND.MARLIN),
(BACKEND.QBITS),
]
)
def test_save(self, backend):
prompt = "I am in Paris and"
device = torch.device("cuda:0") if backend != BACKEND.QBITS else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
inp = tokenizer(prompt, return_tensors="pt").to(device)

# the original model produces the reference output
origin_model = GPTQModel.from_quantized(MODEL_ID, backend=backend)
origin_model_res = origin_model.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
origin_model_predicted_text = tokenizer.decode(origin_model_res[0])

with tempfile.TemporaryDirectory() as tmpdir:
origin_model.save_quantized(tmpdir)

# the re-saved model should produce the same output
new_model = GPTQModel.from_quantized(tmpdir, backend=backend)

new_model_res = new_model.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
new_model_predicted_text = tokenizer.decode(new_model_res[0])

print("origin_model_predicted_text",origin_model_predicted_text)
print("new_model_predicted_text",new_model_predicted_text)

self.assertEqual(origin_model_predicted_text[:20], new_model_predicted_text[:20])
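The same round trip also covers the CPU path: with BACKEND.QBITS the test runs on CPU, which is why base.py now imports QBitsQuantLinear and records FORMAT.QBITS as the runtime format when that kernel is preloaded. A CPU-only sketch along the same lines, reusing the test's checkpoint ID:

import tempfile

from gptqmodel import BACKEND, GPTQModel

MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"

# BACKEND.QBITS runs on CPU; per the from_quantized() change above, loading
# through QBitsQuantLinear sets quantize_config.runtime_format to FORMAT.QBITS.
model = GPTQModel.from_quantized(MODEL_ID, backend=BACKEND.QBITS)

with tempfile.TemporaryDirectory() as tmpdir:
    model.save_quantized(tmpdir)
    GPTQModel.from_quantized(tmpdir, backend=BACKEND.QBITS)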
13 changes: 4 additions & 9 deletions tests/test_serialization.py
@@ -10,7 +10,7 @@
import unittest # noqa: E402

from gptqmodel import BACKEND, GPTQModel # noqa: E402
from gptqmodel.quantization import FORMAT, FORMAT_FIELD_JSON, QUANT_CONFIG_FILENAME # noqa: E402
from gptqmodel.quantization import FORMAT, FORMAT_FIELD_JSON # noqa: E402


class TestSerialization(unittest.TestCase):
@@ -24,23 +24,18 @@ def test_marlin_local_serialization(self):

self.assertTrue(os.path.isfile(os.path.join(tmpdir, "gptq_model-4bit-128g.safetensors")))

with open(os.path.join(tmpdir, QUANT_CONFIG_FILENAME), "r") as config_file:
config = json.load(config_file)

self.assertTrue(config[FORMAT_FIELD_JSON] == FORMAT.MARLIN)

model = GPTQModel.from_quantized(tmpdir, device="cuda:0", backend=BACKEND.MARLIN)

def test_marlin_hf_cache_serialization(self):
model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", backend=BACKEND.MARLIN)
self.assertEqual(model.quantize_config.format, FORMAT.MARLIN)
self.assertEqual(model.quantize_config.runtime_format, FORMAT.MARLIN)

model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", backend=BACKEND.MARLIN)
self.assertEqual(model.quantize_config.format, FORMAT.MARLIN)
self.assertEqual(model.quantize_config.runtime_format, FORMAT.MARLIN)

def test_gptq_v1_to_v2_runtime_convert(self):
model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0")
self.assertEqual(model.quantize_config.format, FORMAT.GPTQ_V2)
self.assertEqual(model.quantize_config.runtime_format, FORMAT.GPTQ_V2)

def test_gptq_v1_serialization(self):
model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0")