|
| 1 | +import json |
| 2 | +from typing import Any, Dict, List, Optional |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | +import torch.nn.functional as F |
| 7 | +from accelerate import init_empty_weights |
| 8 | +from huggingface_hub import HfApi |
| 9 | + |
| 10 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
| 11 | +from transformers.quantizers import HfQuantizer, get_module_from_name, register_quantization_config, register_quantizer |
| 12 | +from transformers.utils.quantization_config import QuantizationConfigMixin |
| 13 | + |
| 14 | + |
| 15 | +# Implement INT8 Symmetric Linear layer |
| 16 | +class Int8SymmetricLinear(torch.nn.Module): |
| 17 | + def __init__(self, in_features, out_features, bias, dtype=torch.float32): |
| 18 | + super().__init__() |
| 19 | + self.in_features = in_features |
| 20 | + self.out_features = out_features |
| 21 | + |
| 22 | + self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.int8)) |
| 23 | + self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=dtype)) |
| 24 | + |
| 25 | + if bias: |
| 26 | + self.register_buffer("bias", torch.zeros((self.out_features), dtype=dtype)) |
| 27 | + else: |
| 28 | + self.bias = None |
| 29 | + |
| 30 | + def forward(self, x): |
| 31 | + dequant_weight = self.weight * self.weight_scale |
| 32 | + output = F.linear(x, dequant_weight) |
| 33 | + if self.bias is not None: |
| 34 | + output = output + self.bias |
| 35 | + return output |
| 36 | + |
| 37 | + |
| 38 | +# Function to replace standard linear layers with INT8 symmetric quantized layers |
| 39 | +def _replace_with_int8_symmetric_linear( |
| 40 | + model, |
| 41 | + modules_to_not_convert=None, |
| 42 | + current_key_name=None, |
| 43 | + quantization_config=None, |
| 44 | + has_been_replaced=False, |
| 45 | + pre_quantized=False, |
| 46 | +): |
| 47 | + """ |
| 48 | + Recursively replaces nn.Linear modules with Int8SymmetricLinear modules. |
| 49 | + """ |
| 50 | + if current_key_name is None: |
| 51 | + current_key_name = [] |
| 52 | + |
| 53 | + for name, module in model.named_children(): |
| 54 | + current_key_name.append(name) |
| 55 | + |
| 56 | + if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert: |
| 57 | + # Check if the current key is not in the `modules_to_not_convert` |
| 58 | + current_key_name_str = ".".join(current_key_name) |
| 59 | + if not any( |
| 60 | + (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert |
| 61 | + ): |
| 62 | + with init_empty_weights(include_buffers=True): |
| 63 | + in_features = module.in_features |
| 64 | + out_features = module.out_features |
| 65 | + model._modules[name] = Int8SymmetricLinear( |
| 66 | + in_features, out_features, module.bias is not None, dtype=module.weight.dtype |
| 67 | + ) |
| 68 | + has_been_replaced = True |
| 69 | + model._modules[name].requires_grad_(False) |
| 70 | + |
| 71 | + if len(list(module.children())) > 0: |
| 72 | + _, has_been_replaced = _replace_with_int8_symmetric_linear( |
| 73 | + module, |
| 74 | + modules_to_not_convert, |
| 75 | + current_key_name, |
| 76 | + quantization_config, |
| 77 | + has_been_replaced=has_been_replaced, |
| 78 | + pre_quantized=pre_quantized, |
| 79 | + ) |
| 80 | + # Remove the last key for recursion |
| 81 | + current_key_name.pop(-1) |
| 82 | + return model, has_been_replaced |
| 83 | + |
| 84 | + |
| 85 | +def replace_with_int8_symmetric_linear( |
| 86 | + model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False |
| 87 | +): |
| 88 | + """ |
| 89 | + Main function to replace model layers with INT8 symmetric quantized versions. |
| 90 | + """ |
| 91 | + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert |
| 92 | + |
| 93 | + if quantization_config.modules_to_not_convert is not None: |
| 94 | + modules_to_not_convert.extend(quantization_config.modules_to_not_convert) |
| 95 | + modules_to_not_convert = list(set(modules_to_not_convert)) |
| 96 | + |
| 97 | + model, has_been_replaced = _replace_with_int8_symmetric_linear( |
| 98 | + model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized |
| 99 | + ) |
| 100 | + |
| 101 | + if not has_been_replaced: |
| 102 | + raise ValueError( |
| 103 | + "You are loading your model using INT8 symmetric quantization but no linear modules were found in your model." |
| 104 | + ) |
| 105 | + |
| 106 | + return model |
| 107 | + |
| 108 | + |
| 109 | +@register_quantization_config("int8_symmetric") |
| 110 | +class Int8SymmetricConfig(QuantizationConfigMixin): |
| 111 | + """ |
| 112 | + Configuration for INT8 symmetric quantization. |
| 113 | + """ |
| 114 | + |
| 115 | + def __init__(self, modules_to_not_convert: Optional[List[str]] = None, **kwargs): |
| 116 | + self.quant_method = "int8_symmetric" |
| 117 | + self.modules_to_not_convert = modules_to_not_convert |
| 118 | + |
| 119 | + def __repr__(self): |
| 120 | + config_dict = self.to_dict() |
| 121 | + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" |
| 122 | + |
| 123 | + def to_diff_dict(self) -> Dict[str, Any]: |
| 124 | + config_dict = self.to_dict() |
| 125 | + default_config_dict = Int8SymmetricConfig().to_dict() |
| 126 | + |
| 127 | + serializable_config_dict = {} |
| 128 | + for key, value in config_dict.items(): |
| 129 | + if value != default_config_dict[key]: |
| 130 | + serializable_config_dict[key] = value |
| 131 | + |
| 132 | + return serializable_config_dict |
| 133 | + |
| 134 | + |
| 135 | +@register_quantizer("int8_symmetric") |
| 136 | +class Int8SymmetricQuantizer(HfQuantizer): |
| 137 | + """ |
| 138 | + Implementation of INT8 symmetric quantization. |
| 139 | +
|
| 140 | + """ |
| 141 | + |
| 142 | + requires_calibration = False |
| 143 | + requires_parameters_quantization = True |
| 144 | + |
| 145 | + def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): |
| 146 | + super().__init__(quantization_config, **kwargs) |
| 147 | + self.quantization_config = quantization_config |
| 148 | + |
| 149 | + def _process_model_before_weight_loading(self, model, **kwargs): |
| 150 | + """ |
| 151 | + Replace model's linear layers with quantized versions before loading weights. |
| 152 | + """ |
| 153 | + self.modules_to_not_convert = self.quantization_config.modules_to_not_convert |
| 154 | + |
| 155 | + model = replace_with_int8_symmetric_linear( |
| 156 | + model, |
| 157 | + modules_to_not_convert=self.modules_to_not_convert, |
| 158 | + quantization_config=self.quantization_config, |
| 159 | + pre_quantized=self.pre_quantized, |
| 160 | + ) |
| 161 | + |
| 162 | + def check_quantized_param( |
| 163 | + self, |
| 164 | + model, |
| 165 | + param_value: "torch.Tensor", |
| 166 | + param_name: str, |
| 167 | + state_dict: Dict[str, Any], |
| 168 | + **kwargs, |
| 169 | + ): |
| 170 | + module, tensor_name = get_module_from_name(model, param_name) |
| 171 | + |
| 172 | + if isinstance(module, Int8SymmetricLinear): |
| 173 | + if self.pre_quantized or tensor_name == "bias": |
| 174 | + if tensor_name == "weight" and param_value.dtype != torch.int8: |
| 175 | + raise ValueError("Expect quantized weights but got an unquantized weight") |
| 176 | + return False |
| 177 | + else: |
| 178 | + if tensor_name == "weight_scale": |
| 179 | + raise ValueError("Expect unquantized weights but got a quantized weight_scale") |
| 180 | + return True |
| 181 | + return False |
| 182 | + |
| 183 | + def create_quantized_param( |
| 184 | + self, |
| 185 | + model, |
| 186 | + param_value: "torch.Tensor", |
| 187 | + param_name: str, |
| 188 | + target_device: "torch.device", |
| 189 | + state_dict: Dict[str, Any], |
| 190 | + unexpected_keys: Optional[List[str]] = None, |
| 191 | + ): |
| 192 | + """ |
| 193 | + Quantizes weights to INT8 symmetric format. |
| 194 | + """ |
| 195 | + abs_max_per_row = torch.max(torch.abs(param_value), dim=1, keepdim=True)[0].clamp(min=1e-5) |
| 196 | + |
| 197 | + weight_scale = abs_max_per_row / 127.0 |
| 198 | + |
| 199 | + weight_quantized = torch.round(param_value / weight_scale).clamp(-128, 127).to(torch.int8) |
| 200 | + |
| 201 | + module, tensor_name = get_module_from_name(model, param_name) |
| 202 | + module._buffers[tensor_name] = weight_quantized.to(target_device) |
| 203 | + module._buffers["weight_scale"] = weight_scale.to(target_device) |
| 204 | + |
| 205 | + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: |
| 206 | + not_missing_keys = [] |
| 207 | + for name, module in model.named_modules(): |
| 208 | + if isinstance(module, Int8SymmetricLinear): |
| 209 | + for missing in missing_keys: |
| 210 | + if ( |
| 211 | + (name in missing or name in f"{prefix}.{missing}") |
| 212 | + and not missing.endswith(".weight") |
| 213 | + and not missing.endswith(".bias") |
| 214 | + ): |
| 215 | + not_missing_keys.append(missing) |
| 216 | + return [k for k in missing_keys if k not in not_missing_keys] |
| 217 | + |
| 218 | + def _process_model_after_weight_loading(self, model, **kwargs): |
| 219 | + """ |
| 220 | + Post-processing after weights are loaded. |
| 221 | + """ |
| 222 | + return True |
| 223 | + |
| 224 | + def is_serializable(self, safe_serialization=None): |
| 225 | + return True |
| 226 | + |
| 227 | + @property |
| 228 | + def is_trainable(self) -> bool: |
| 229 | + return False |
| 230 | + |
| 231 | + |
| 232 | +# Example usage |
| 233 | +if __name__ == "__main__": |
| 234 | + model_int8 = AutoModelForCausalLM.from_pretrained( |
| 235 | + "meta-llama/Llama-3.2-1B", quantization_config=Int8SymmetricConfig(), torch_dtype=torch.float, device_map="cpu" |
| 236 | + ) |
| 237 | + |
| 238 | + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") |
| 239 | + input_text = "once there is" |
| 240 | + inputs = tokenizer(input_text, return_tensors="pt").to("cpu") |
| 241 | + output = model_int8.generate( |
| 242 | + **inputs, |
| 243 | + max_length=100, |
| 244 | + num_return_sequences=1, |
| 245 | + no_repeat_ngram_size=2, |
| 246 | + ) |
| 247 | + generated_text = tokenizer.decode(output[0], skip_special_tokens=True) |
| 248 | + print(generated_text) |
| 249 | + |
| 250 | + # Save and upload to HUB |
| 251 | + output_model_dir = "Llama-3.2-1B-INT8-CUSTOM" |
| 252 | + model_int8.save_pretrained(output_model_dir) |
| 253 | + tokenizer.save_pretrained(output_model_dir) |
| 254 | + api = HfApi() |
| 255 | + repo_id = "medmekk/Llama-3.2-1B-INT8-CUSTOM" |
| 256 | + api.create_repo(repo_id, private=False) |
| 257 | + api.upload_folder(folder_path=output_model_dir, repo_id=repo_id, repo_type="model") |
0 commit comments