
Commit bc68386

younesbelkada authored and ArthurZucker committed
🚨🚨🚨 [Quantization] Store the original dtype in the config as a private attribute 🚨🚨🚨 (huggingface#26761)
* First step
* fix
* add adjustements for gptq
* change to `_pre_quantization_dtype`
* Update src/transformers/modeling_utils.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix serialization
* Apply suggestions from code review
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent 38f683f commit bc68386

5 files changed, +67 -2 lines changed


src/transformers/configuration_utils.py

Lines changed: 6 additions & 0 deletions
@@ -854,6 +854,9 @@ def to_diff_dict(self) -> Dict[str, Any]:
                 else self.quantization_config
             )
 
+            # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
+            _ = serializable_config_dict.pop("_pre_quantization_dtype", None)
+
         self.dict_torch_dtype_to_str(serializable_config_dict)
 
         if "_flash_attn_2_enabled" in serializable_config_dict:
@@ -896,6 +899,9 @@ def to_dict(self) -> Dict[str, Any]:
                 else self.quantization_config
             )
 
+            # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
+            _ = output.pop("_pre_quantization_dtype", None)
+
         self.dict_torch_dtype_to_str(output)
 
         return output
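
The pop above is needed because a torch.dtype value cannot be written to JSON, so `to_dict()` and `to_diff_dict()` must drop `_pre_quantization_dtype` before serialization. A minimal sketch (not part of the commit) reproducing the failure this change guards against:

import json

import torch

# torch.float16 is a torch.dtype object, not a JSON-compatible type, so a config
# dict that still contains `_pre_quantization_dtype` cannot be dumped as-is.
try:
    json.dumps({"_pre_quantization_dtype": torch.float16})
except TypeError as err:
    print(f"Not JSON-serializable: {err}")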

src/transformers/modeling_utils.py

Lines changed: 25 additions & 2 deletions
@@ -2178,8 +2178,25 @@ def to(self, *args, **kwargs):
                 "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
                 " model has already been set to the correct devices and casted to the correct `dtype`."
             )
-        else:
-            return super().to(*args, **kwargs)
+        elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
+            # For GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours.
+            # the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
+            dtype_present_in_args = False
+
+            if "dtype" not in kwargs:
+                for arg in args:
+                    if isinstance(arg, torch.dtype):
+                        dtype_present_in_args = True
+                        break
+            else:
+                dtype_present_in_args = True
+
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
+                    " `dtype` by passing the correct `torch_dtype` argument."
+                )
+        return super().to(*args, **kwargs)
 
     def half(self, *args):
         # Checks if the model is quantized
@@ -3165,6 +3182,12 @@ def from_pretrained(
         if hasattr(model, "quantization_method"):
             model.is_quantized = True
 
+            # We store the original dtype for quantized models as we cannot easily retrieve it
+            # once the weights have been quantized
+            # Note that once you have loaded a quantized model, you can't change its dtype so this will
+            # remain a single source of truth
+            config._pre_quantization_dtype = torch_dtype
+
         if isinstance(device_map, str):
             special_dtypes = {}
             if load_in_8bit or load_in_4bit:
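
A minimal usage sketch of the new behavior (the checkpoint name and the 4-bit loading flag are illustrative assumptions, not part of this commit): after `from_pretrained`, the dtype used before quantization is kept on the config, and GPTQ models now reject dtype casts through `.to`.

import torch

from transformers import AutoModelForCausalLM

# Hypothetical example checkpoint; any model loaded through a supported quantization
# path gets `config._pre_quantization_dtype` set by `from_pretrained`.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    load_in_4bit=True,
    torch_dtype=torch.float16,
)
print(model.config._pre_quantization_dtype)  # torch.float16

# On a GPTQ-quantized model, casting after load now raises a ValueError; the supported
# path is to pass the desired `torch_dtype` to `from_pretrained` instead.
# model.to(torch.bfloat16)  # would raise ValueError per the new `to()` check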

tests/quantization/bnb/test_4bit.py

Lines changed: 8 additions & 0 deletions
@@ -156,6 +156,14 @@ def test_memory_footprint(self):
         linear = get_some_linear_layer(self.model_4bit)
         self.assertTrue(linear.weight.__class__ == Params4bit)
 
+    def test_original_dtype(self):
+        r"""
+        A simple test to check if the model succesfully stores the original dtype
+        """
+        self.assertTrue(hasattr(self.model_4bit.config, "_pre_quantization_dtype"))
+        self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
+        self.assertTrue(self.model_4bit.config._pre_quantization_dtype == torch.float16)
+
     def test_linear_are_4bit(self):
         r"""
         A simple test to check if the model conversion has been done correctly by checking on the

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 8 additions & 0 deletions
@@ -186,6 +186,14 @@ def test_quantization_config_json_serialization(self):
 
         _ = config.to_json_string()
 
+    def test_original_dtype(self):
+        r"""
+        A simple test to check if the model succesfully stores the original dtype
+        """
+        self.assertTrue(hasattr(self.model_8bit.config, "_pre_quantization_dtype"))
+        self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
+        self.assertTrue(self.model_8bit.config._pre_quantization_dtype == torch.float16)
+
     def test_memory_footprint(self):
         r"""
         A simple test to check if the model conversion has been done correctly by checking on the

tests/quantization/gptq/test_gptq.py

Lines changed: 20 additions & 0 deletions
@@ -145,6 +145,26 @@ def test_memory_footprint(self):
 
         self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE)
 
+    def test_device_and_dtype_assignment(self):
+        r"""
+        Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
+        Checks also if other models are casted correctly.
+        """
+        # This should work
+        _ = self.quantized_model.to(0)
+
+        with self.assertRaises(ValueError):
+            # Tries with a `dtype``
+            self.quantized_model.to(torch.float16)
+
+    def test_original_dtype(self):
+        r"""
+        A simple test to check if the model succesfully stores the original dtype
+        """
+        self.assertTrue(hasattr(self.quantized_model.config, "_pre_quantization_dtype"))
+        self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
+        self.assertTrue(self.quantized_model.config._pre_quantization_dtype == torch.float16)
+
     def test_quantized_layers_class(self):
         """
         Simple test to check if the model conversion has been done correctly by checking on
