Commit e5cea20

Add Example for Custom quantization (#36286)
* add example
* rename
1 parent e3d99ec commit e5cea20
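For context on what the new example implements: each nn.Linear weight matrix is quantized row by row with a symmetric INT8 scheme. The scale of a row is its absolute maximum divided by 127, the stored weight is the rounded, clamped ratio, and the forward pass multiplies the INT8 weight by the scale to dequantize. A minimal standalone sketch of that math (the helper name int8_symmetric_quantize is only illustrative, it is not part of the committed file):

import torch

def int8_symmetric_quantize(weight: torch.Tensor):
    # Per-output-row scale: map each row's abs-max onto the INT8 range,
    # mirroring create_quantized_param in the example file below.
    abs_max_per_row = weight.abs().max(dim=1, keepdim=True).values.clamp(min=1e-5)
    scale = abs_max_per_row / 127.0
    q = torch.round(weight / scale).clamp(-128, 127).to(torch.int8)
    return q, scale

w = torch.randn(4, 8)
q, scale = int8_symmetric_quantize(w)
w_hat = q.float() * scale  # dequantized weight, as used in Int8SymmetricLinear.forward
print((w - w_hat).abs().max())  # per-element error is at most about half a quantization step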

2 files changed: +258 -0 lines changed
Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
import json
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import init_empty_weights
from huggingface_hub import HfApi

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.quantizers import HfQuantizer, get_module_from_name, register_quantization_config, register_quantizer
from transformers.utils.quantization_config import QuantizationConfigMixin


# Implement INT8 Symmetric Linear layer
class Int8SymmetricLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias, dtype=torch.float32):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.int8))
        self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=dtype))

        if bias:
            self.register_buffer("bias", torch.zeros((self.out_features), dtype=dtype))
        else:
            self.bias = None

    def forward(self, x):
        dequant_weight = self.weight * self.weight_scale
        output = F.linear(x, dequant_weight)
        if self.bias is not None:
            output = output + self.bias
        return output


# Function to replace standard linear layers with INT8 symmetric quantized layers
def _replace_with_int8_symmetric_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
    pre_quantized=False,
):
    """
    Recursively replaces nn.Linear modules with Int8SymmetricLinear modules.
    """
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)

        if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                with init_empty_weights(include_buffers=True):
                    in_features = module.in_features
                    out_features = module.out_features
                    model._modules[name] = Int8SymmetricLinear(
                        in_features, out_features, module.bias is not None, dtype=module.weight.dtype
                    )
                    has_been_replaced = True
                    model._modules[name].requires_grad_(False)

        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_int8_symmetric_linear(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
                pre_quantized=pre_quantized,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def replace_with_int8_symmetric_linear(
    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
):
    """
    Main function to replace model layers with INT8 symmetric quantized versions.
    """
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

    if quantization_config.modules_to_not_convert is not None:
        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
    modules_to_not_convert = list(set(modules_to_not_convert))

    model, has_been_replaced = _replace_with_int8_symmetric_linear(
        model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
    )

    if not has_been_replaced:
        raise ValueError(
            "You are loading your model using INT8 symmetric quantization but no linear modules were found in your model."
        )

    return model


@register_quantization_config("int8_symmetric")
class Int8SymmetricConfig(QuantizationConfigMixin):
    """
    Configuration for INT8 symmetric quantization.
    """

    def __init__(self, modules_to_not_convert: Optional[List[str]] = None, **kwargs):
        self.quant_method = "int8_symmetric"
        self.modules_to_not_convert = modules_to_not_convert

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

    def to_diff_dict(self) -> Dict[str, Any]:
        config_dict = self.to_dict()
        default_config_dict = Int8SymmetricConfig().to_dict()

        serializable_config_dict = {}
        for key, value in config_dict.items():
            if value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict


@register_quantizer("int8_symmetric")
class Int8SymmetricQuantizer(HfQuantizer):
    """
    Implementation of INT8 symmetric quantization.
    """

    requires_calibration = False
    requires_parameters_quantization = True

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def _process_model_before_weight_loading(self, model, **kwargs):
        """
        Replace model's linear layers with quantized versions before loading weights.
        """
        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        model = replace_with_int8_symmetric_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
        )

    def check_quantized_param(
        self,
        model,
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        module, tensor_name = get_module_from_name(model, param_name)

        if isinstance(module, Int8SymmetricLinear):
            if self.pre_quantized or tensor_name == "bias":
                if tensor_name == "weight" and param_value.dtype != torch.int8:
                    raise ValueError("Expect quantized weights but got an unquantized weight")
                return False
            else:
                if tensor_name == "weight_scale":
                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
                return True
        return False

    def create_quantized_param(
        self,
        model,
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
    ):
        """
        Quantizes weights to INT8 symmetric format.
        """
        abs_max_per_row = torch.max(torch.abs(param_value), dim=1, keepdim=True)[0].clamp(min=1e-5)

        weight_scale = abs_max_per_row / 127.0

        weight_quantized = torch.round(param_value / weight_scale).clamp(-128, 127).to(torch.int8)

        module, tensor_name = get_module_from_name(model, param_name)
        module._buffers[tensor_name] = weight_quantized.to(target_device)
        module._buffers["weight_scale"] = weight_scale.to(target_device)

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, Int8SymmetricLinear):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing_keys.append(missing)
        return [k for k in missing_keys if k not in not_missing_keys]

    def _process_model_after_weight_loading(self, model, **kwargs):
        """
        Post-processing after weights are loaded.
        """
        return True

    def is_serializable(self, safe_serialization=None):
        return True

    @property
    def is_trainable(self) -> bool:
        return False


# Example usage
if __name__ == "__main__":
    model_int8 = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B", quantization_config=Int8SymmetricConfig(), torch_dtype=torch.float, device_map="cpu"
    )

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    input_text = "once there is"
    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
    output = model_int8.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(generated_text)

    # Save and upload to HUB
    output_model_dir = "Llama-3.2-1B-INT8-CUSTOM"
    model_int8.save_pretrained(output_model_dir)
    tokenizer.save_pretrained(output_model_dir)
    api = HfApi()
    repo_id = "medmekk/Llama-3.2-1B-INT8-CUSTOM"
    api.create_repo(repo_id, private=False)
    api.upload_folder(folder_path=output_model_dir, repo_id=repo_id, repo_type="model")
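Since the config and quantizer above are registered under the "int8_symmetric" name, a checkpoint saved or pushed by this script should be loadable again through the normal from_pretrained path, as long as the registration code has been imported first so transformers can resolve quant_method="int8_symmetric" from the saved config. A hedged sketch, not part of the commit (the repo id is the one the example pushes to):

# Assumes the example module above has already been imported, so that
# @register_quantization_config("int8_symmetric") and @register_quantizer("int8_symmetric") have run.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "medmekk/Llama-3.2-1B-INT8-CUSTOM"
reloaded = AutoModelForCausalLM.from_pretrained(repo_id, device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained(repo_id)
inputs = tokenizer("once there is", return_tensors="pt")
output = reloaded.generate(**inputs, max_length=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))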

src/transformers/quantizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@
 # limitations under the License.
 from .auto import AutoHfQuantizer, AutoQuantizationConfig, register_quantization_config, register_quantizer
 from .base import HfQuantizer
+from .quantizers_utils import get_module_from_name

0 commit comments
