
Commit b05d2c4

SunMarc and MekkCyber authored
Fix dtype quantizer (#42882)
* fix dtype quantizer
* fix
* rm print
* fix
* style
* fix
* revert
* bitnet
* fix
* gogo
* Update src/transformers/modeling_utils.py

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* warn instead
* fix
* fix

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
1 parent 0001b3e commit b05d2c4

38 files changed: +124 −281 lines

src/transformers/configuration_utils.py

Lines changed: 0 additions & 4 deletions
@@ -1019,10 +1019,6 @@ def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
         Checks and removes if there are any keys in the dict that should not be serialized when saving the config.
         Runs recursive check on the dict, to remove from all sub configs.
         """
-        if hasattr(self, "quantization_config"):
-            # Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
-            _ = d.pop("_pre_quantization_dtype", None)
-
         if "_auto_class" in d:
             del d["_auto_class"]
         if "_output_attentions" in d:

src/transformers/integrations/bitsandbytes.py

Lines changed: 14 additions & 12 deletions
@@ -233,7 +233,7 @@ def replace_with_bnb_linear(
 
 
 # Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
-def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
     """
     Helper function to dequantize 4bit or 8bit bnb weights.
 
@@ -248,10 +248,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
 
     if cls_name == "Params4bit":
         output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
-        logger.warning_once(
-            f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
-        )
-        return output_tensor.to(dtype)
+        return output_tensor
 
     if state.SCB is None:
         state.SCB = weight.SCB
@@ -263,7 +260,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
     # Multiply by (scale/127) to dequantize.
     dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
 
-    return dequantized.to(dtype)
+    return dequantized
 
 
 def _create_accelerate_new_hook(old_hook):
@@ -283,10 +280,7 @@ def _create_accelerate_new_hook(old_hook):
     return new_hook
 
 
-def dequantize_and_replace(
-    model,
-    quantization_config=None,
-):
+def dequantize_and_replace(model, quantization_config=None, dtype=None):
     """
     Converts a quantized model into its dequantized original version. The newly converted model will have
     some performance drop compared to the original model before quantization - use it only for specific usecases
@@ -297,14 +291,22 @@
     quant_method = quantization_config.quantization_method()
 
     target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
-
     for module_name, module in model.named_modules():
         if isinstance(module, target_cls):
             with init_empty_weights():
                 bias = getattr(module, "bias", None)
                 new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
             state = module.state if quant_method == "llm_int8" else None
-            new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, model.dtype, state))
+            new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+            weight = dequantize_bnb_weight(module.weight, state)
+            if dtype is None:
+                logger.warning_once(
+                    f"The modules are dequantized in {weight.dtype}. If you want to change the dtype, please specify `dtype` in `dequantize`. "
+                )
+            else:
+                logger.warning_once(f"The modules are dequantized in {weight.dtype} and casted to {dtype}.")
+                weight = weight.to(dtype)
+            new_module.weight = torch.nn.Parameter(weight)
             if bias is not None:
                 new_module.bias = bias
             if hasattr(module, "_hf_hook"):

src/transformers/integrations/flash_attention.py

Lines changed: 2 additions & 2 deletions
@@ -20,8 +20,8 @@ def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtype:
                 else torch.get_autocast_gpu_dtype()
             )
         # Handle the case where the model is quantized
-        elif hasattr(module.config, "_pre_quantization_dtype"):
-            return module.config._pre_quantization_dtype
+        elif hasattr(module.config, "quantization_config"):
+            return module.config.dtype
         else:
             return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
     return None
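For quantized models, the flash-attention upcast target now comes from `config.dtype` whenever a `quantization_config` is present, instead of the removed `_pre_quantization_dtype`. A minimal sketch of that resolution order (the helper and the `SimpleNamespace` stand-in for a model config are illustrative assumptions):

from types import SimpleNamespace
import torch

def resolve_target_dtype(config, fallback_linear_dtype: torch.dtype) -> torch.dtype:
    # Quantized model: trust the dtype recorded on the config.
    if hasattr(config, "quantization_config"):
        return config.dtype
    # Otherwise fall back to the first nn.Linear weight dtype, as get_target_dtype does.
    return fallback_linear_dtype

quantized_cfg = SimpleNamespace(quantization_config={"load_in_4bit": True}, dtype=torch.bfloat16)
print(resolve_target_dtype(quantized_cfg, torch.float32))  # torch.bfloat16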

src/transformers/modeling_utils.py

Lines changed: 23 additions & 19 deletions
@@ -792,6 +792,7 @@ def _get_dtype(
     sharded_metadata: Optional[dict],
     state_dict: Optional[dict],
     weights_only: bool,
+    hf_quantizer: Optional[HfQuantizer] = None,
 ) -> tuple[PreTrainedConfig, torch.dtype]:
     """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
     inferred dtype. We do the following:
@@ -840,6 +841,9 @@ def _get_dtype(
         # set torch.get_default_dtype() (usually fp32) as the default dtype if `None` is provided
         dtype = torch.get_default_dtype()
 
+    if hf_quantizer is not None:
+        hf_quantizer.update_dtype(dtype)
+
     # Get the main dtype
     if isinstance(dtype, dict):
         main_dtype = dtype.get("", torch.get_default_dtype())
@@ -1433,7 +1437,7 @@ def tp_plan(self, plan: dict[str, str] | None):
     def pp_plan(self, plan: dict[str, tuple[str, str]]):
         self._pp_plan = plan
 
-    def dequantize(self):
+    def dequantize(self, dtype=None):
         """
         Potentially dequantize the model in case it has been quantized by a quantization method that support
         dequantization.
@@ -1443,7 +1447,7 @@ def dequantize(self):
         if hf_quantizer is None:
             raise ValueError("You need to first quantize your model in order to dequantize it")
 
-        return hf_quantizer.dequantize(self)
+        return hf_quantizer.dequantize(self, dtype=dtype)
 
     def _backward_compatibility_gradient_checkpointing(self):
         if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
@@ -3875,8 +3879,8 @@ def from_pretrained(
         if "attn_implementation" in kwargs:
             config._attn_implementation = kwargs.pop("attn_implementation")
 
-        hf_quantizer, config, dtype, device_map = get_hf_quantizer(
-            config, quantization_config, dtype, device_map, weights_only, user_agent
+        hf_quantizer, config, device_map = get_hf_quantizer(
+            config, quantization_config, device_map, weights_only, user_agent
         )
 
         if gguf_file:
@@ -3923,7 +3927,9 @@ def from_pretrained(
         ]
 
         # Find the correct dtype based on current state
-        config, dtype = _get_dtype(dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only)
+        config, dtype = _get_dtype(
+            dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only, hf_quantizer
+        )
 
         config.name_or_path = pretrained_model_name_or_path
         model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called)
@@ -3932,22 +3938,18 @@ def from_pretrained(
         # Let's make sure we don't run the init function of buffer modules
         model = cls(config, *model_args, **model_kwargs)
 
+        if hf_quantizer is not None:  # replace module with quantized modules (does not touch weights)
+            hf_quantizer.preprocess_model(
+                model=model,
+                dtype=dtype,
+                device_map=device_map,
+                checkpoint_files=checkpoint_files,
+                use_kernels=use_kernels,
+            )
+
         # Obtain the weight conversion mapping for this model if any are registered
         weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
 
-        # make sure we use the model's config since the __init__ call might have copied it
-        config = model.config
-
-        if hf_quantizer is not None:  # replace module with quantized modules (does not touch weights)
-            hf_quantizer.preprocess_model(
-                model=model,
-                device_map=device_map,
-                keep_in_fp32_modules=model._keep_in_fp32_modules,  # TODO prob no longer needed?
-                config=config,
-                checkpoint_files=checkpoint_files,
-                use_kernels=use_kernels,
-            )
-
         if _torch_distributed_available and device_mesh is not None:  # add hooks to nn.Modules: no weights
             model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
 
@@ -3994,7 +3996,9 @@ def from_pretrained(
 
         if hf_quantizer is not None:
             model.hf_quantizer = hf_quantizer
-            hf_quantizer.postprocess_model(model, config=config)  # usually a no-op but sometimes needed
+            hf_quantizer.postprocess_model(
+                model
+            )  # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing
 
         if _adapter_model_path is not None:
             adapter_kwargs["key_mapping"] = key_mapping

src/transformers/models/diffllama/modeling_diffllama.py

Lines changed: 2 additions & 2 deletions
@@ -361,8 +361,8 @@ def forward(
                     else torch.get_autocast_gpu_dtype()
                 )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype

src/transformers/models/diffllama/modular_diffllama.py

Lines changed: 2 additions & 2 deletions
@@ -236,8 +236,8 @@ def forward(
                     else torch.get_autocast_gpu_dtype()
                 )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
243243

src/transformers/models/falcon/modeling_falcon.py

Lines changed: 2 additions & 2 deletions
@@ -521,8 +521,8 @@ def forward(
                     else torch.get_autocast_gpu_dtype()
                 )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.query_key_value.weight.dtype

src/transformers/models/falcon_mamba/modeling_falcon_mamba.py

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ def cuda_kernels_forward(
 
         # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
         # at the price of a small overhead.
-        if hasattr(self.config, "_pre_quantization_dtype"):
+        if hasattr(self.config, "quantization_config"):
            discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
         else:
            discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
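The condition changes (quantization is now detected via `quantization_config`), but the hack itself is unchanged: a quantized `dt_proj` cannot be used through its packed `weight`, so the module is called and its bias subtracted. The sketch below only illustrates that, for a plain `nn.Linear`, the two paths agree (shapes are illustrative, not from the model):

import torch
from torch import nn

dt_proj = nn.Linear(8, 16, bias=True)
time_step = torch.randn(2, 5, 8)  # (batch, seq_len, dt_rank), illustrative

via_module = (dt_proj(time_step) - dt_proj.bias).transpose(1, 2)   # quantized-model path
via_weight = dt_proj.weight @ time_step.transpose(1, 2)            # direct matmul path

print(torch.allclose(via_module, via_weight, atol=1e-6))  # True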

src/transformers/models/falcon_mamba/modular_falcon_mamba.py

Lines changed: 1 addition & 1 deletion
@@ -357,7 +357,7 @@ def cuda_kernels_forward(
 
         # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
         # at the price of a small overhead.
-        if hasattr(self.config, "_pre_quantization_dtype"):
+        if hasattr(self.config, "quantization_config"):
            discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
         else:
            discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)

src/transformers/models/gpt_neo/modeling_gpt_neo.py

Lines changed: 2 additions & 2 deletions
@@ -237,8 +237,8 @@ def forward(
                     else torch.get_autocast_gpu_dtype()
                 )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
