
Commit 8eaae6b

keetrap and MekkCyber authored

Added Support for Custom Quantization (#35915)

* Added Support for Custom Quantization
* Update code
* code reformatted
* Updated Changes
* Updated Changes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

1 parent 07182b2 commit 8eaae6b

4 files changed: +116 additions, -3 deletions
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
import json
from typing import Any, Dict

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer
from transformers.utils.quantization_config import QuantizationConfigMixin


@register_quantization_config("custom")
class CustomConfig(QuantizationConfigMixin):
    def __init__(self):
        self.quant_method = "custom"
        self.bits = 8

    def to_dict(self) -> Dict[str, Any]:
        output = {
            "num_bits": self.bits,
        }
        return output

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

    def to_diff_dict(self) -> Dict[str, Any]:
        config_dict = self.to_dict()

        default_config_dict = CustomConfig().to_dict()

        serializable_config_dict = {}

        for key, value in config_dict.items():
            if value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict


@register_quantizer("custom")
class CustomQuantizer(HfQuantizer):
    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config
        self.scale_map = {}
        self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.torch_dtype = kwargs.get("torch_dtype", torch.float32)

    def _process_model_before_weight_loading(self, model, **kwargs):
        return True

    def _process_model_after_weight_loading(self, model, **kwargs):
        return True

    def is_serializable(self) -> bool:
        return True

    def is_trainable(self) -> bool:
        return False


model_8bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", quantization_config=CustomConfig(), torch_dtype="auto"
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
input_text = "once there is"
inputs = tokenizer(input_text, return_tensors="pt")
output = model_8bit.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)
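Note that both `_process_model_*` hooks in `CustomQuantizer` are no-ops, so the weights are never actually modified; the example exercises the registration and loading plumbing rather than a real quantization scheme. A quick sanity check that the custom quantizer was picked up (a sketch, not part of this commit, assuming the `hf_quantizer` attribute that `from_pretrained` attaches to quantized models):

print(type(model_8bit.hf_quantizer).__name__)        # CustomQuantizer
print(model_8bit.hf_quantizer.quantization_config)   # CustomConfig { "num_bits": 8 }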

src/transformers/modeling_utils.py

Lines changed: 4 additions & 2 deletions
@@ -3706,8 +3706,10 @@ def from_pretrained(
             device_map = hf_quantizer.update_device_map(device_map)

             # In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
-            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+            if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
+                user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+            else:
+                user_agent["quant"] = hf_quantizer.quantization_config.quant_method
         # Force-set to `True` for more mem efficiency
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
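The `hasattr` check is needed because built-in methods store `quant_method` as a `QuantizationMethod` enum member (which exposes `.value`), while configs registered through the new `register_quantization_config` path typically set it to a plain string, as `CustomConfig` does above. A minimal sketch of the two cases, assuming the `QuantizationMethod` enum from `transformers.utils.quantization_config`:

from transformers.utils.quantization_config import QuantizationMethod

for quant_method in (QuantizationMethod.BITS_AND_BYTES, "custom"):
    # Mirrors the new branch: unwrap enum members, pass plain strings through.
    user_agent_quant = quant_method.value if hasattr(quant_method, "value") else quant_method
    print(user_agent_quant)  # "bitsandbytes", then "custom"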

src/transformers/quantizers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -11,5 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .auto import AutoHfQuantizer, AutoQuantizationConfig
+from .auto import AutoHfQuantizer, AutoQuantizationConfig, register_quantization_config, register_quantizer
 from .base import HfQuantizer

src/transformers/quantizers/auto.py

Lines changed: 33 additions & 0 deletions
@@ -35,6 +35,7 @@
     TorchAoConfig,
     VptqConfig,
 )
+from .base import HfQuantizer
 from .quantizer_aqlm import AqlmHfQuantizer
 from .quantizer_awq import AwqQuantizer
 from .quantizer_bitnet import BitNetHfQuantizer
@@ -226,3 +227,35 @@ def supports_quant_method(quantization_config_dict):
         )
         return False
     return True
+
+
+def register_quantization_config(method: str):
+    """Register a custom quantization configuration."""
+
+    def register_config_fn(cls):
+        if method in AUTO_QUANTIZATION_CONFIG_MAPPING:
+            raise ValueError(f"Config '{method}' already registered")
+
+        if not issubclass(cls, QuantizationConfigMixin):
+            raise ValueError("Config must extend QuantizationConfigMixin")
+
+        AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls
+        return cls
+
+    return register_config_fn
+
+
+def register_quantizer(name: str):
+    """Register a custom quantizer."""
+
+    def register_quantizer_fn(cls):
+        if name in AUTO_QUANTIZER_MAPPING:
+            raise ValueError(f"Quantizer '{name}' already registered")
+
+        if not issubclass(cls, HfQuantizer):
+            raise ValueError("Quantizer must extend HfQuantizer")
+
+        AUTO_QUANTIZER_MAPPING[name] = cls
+        return cls
+
+    return register_quantizer_fn
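Both helpers are plain decorators over the existing lookup dicts, so registration is directly observable. A small sketch (not part of this commit), assuming the example file above has already been imported and has therefore registered `CustomConfig` and `CustomQuantizer`:

from transformers.quantizers import register_quantization_config
from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING
from transformers.utils.quantization_config import QuantizationConfigMixin

assert "custom" in AUTO_QUANTIZATION_CONFIG_MAPPING  # CustomConfig
assert "custom" in AUTO_QUANTIZER_MAPPING            # CustomQuantizer

# Re-registering an existing name raises, guarding against accidental overrides:
try:
    @register_quantization_config("custom")
    class DuplicateConfig(QuantizationConfigMixin):
        pass
except ValueError as err:
    print(err)  # Config 'custom' already registered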
