
Commit 94ed13c

Fix regression loading dtype (#34409)

* fix regression
* add test for torchao
* expected output
* better fix

SunMarc authored and ArthurZucker committed
1 parent: 72c716d

File tree: 2 files changed (+25, -4 lines)


src/transformers/modeling_utils.py
Lines changed: 5 additions & 4 deletions
@@ -943,13 +943,14 @@ def _load_state_dict_into_meta_model(
         old_param = model
         splits = param_name.split(".")
         for split in splits:
-            old_param = getattr(old_param, split)
-            # Not all the attributes of a module are Parameters/Tensor
-            if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
-                old_param = None
+            # We shouldn't hit the default value, except for quant methods like hqq that modify expected_keys.
+            old_param = getattr(old_param, split, None)
             if old_param is None:
                 break
 
+        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+            old_param = None
+
         if old_param is not None:
             if dtype is None:
                 param = param.to(old_param.dtype)
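
For context, here is a minimal standalone sketch (a toy model, not the transformers internals) of why the traversal now passes a default to getattr: quant methods such as hqq can add expected keys that have no matching attribute on the module, so a bare getattr would raise AttributeError instead of falling through to old_param = None.

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(2, 2))

    def find_old_param(model, param_name):
        old_param = model
        for split in param_name.split("."):
            # The default keeps missing hqq-style keys from raising AttributeError.
            old_param = getattr(old_param, split, None)
            if old_param is None:
                break
        # Not all attributes of a module are Parameters/Tensors.
        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
            old_param = None
        return old_param

    print(find_old_param(model, "0.weight").shape)  # torch.Size([2, 2])
    print(find_old_param(model, "0.W_q"))           # None; no exception raised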

tests/quantization/torchao_integration/test_torchao.py
Lines changed: 20 additions & 0 deletions
@@ -208,6 +208,26 @@ def test_int4wo_offload(self):
 
         self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
 
+    def test_int8_dynamic_activation_int8_weight_quant(self):
+        """
+        Simple LLM model testing int8_dynamic_activation_int8_weight
+        """
+        quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
+
+        # Note: we quantize the bfloat16 model on the fly to int8
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=torch_device,
+            quantization_config=quant_config,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
+
 
 if __name__ == "__main__":
     unittest.main()
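
To run only the new test locally, something along these lines should work from a transformers checkout (this assumes pytest and torchao are installed):

    import pytest

    pytest.main([
        "tests/quantization/torchao_integration/test_torchao.py",
        "-k", "test_int8_dynamic_activation_int8_weight_quant",
    ])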
