File tree 2 files changed +25
-4
lines changed
tests/quantization/torchao_integration
2 files changed +25
-4
lines changed Original file line number Diff line number Diff line change @@ -943,13 +943,14 @@ def _load_state_dict_into_meta_model(
943
943
old_param = model
944
944
splits = param_name .split ("." )
945
945
for split in splits :
946
- old_param = getattr (old_param , split )
947
- # Not all the attributes of a module are Parameters/Tensor
948
- if not isinstance (old_param , (torch .nn .Parameter , torch .Tensor )):
949
- old_param = None
946
+ # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
947
+ old_param = getattr (old_param , split , None )
950
948
if old_param is None :
951
949
break
952
950
951
+ if not isinstance (old_param , (torch .nn .Parameter , torch .Tensor )):
952
+ old_param = None
953
+
953
954
if old_param is not None :
954
955
if dtype is None :
955
956
param = param .to (old_param .dtype )
Original file line number Diff line number Diff line change @@ -208,6 +208,26 @@ def test_int4wo_offload(self):
208
208
209
209
self .assertEqual (tokenizer .decode (output [0 ], skip_special_tokens = True ), EXPECTED_OUTPUT )
210
210
211
+ def test_int8_dynamic_activation_int8_weight_quant (self ):
212
+ """
213
+ Simple LLM model testing int8_dynamic_activation_int8_weight
214
+ """
215
+ quant_config = TorchAoConfig ("int8_dynamic_activation_int8_weight" )
216
+
217
+ # Note: we quantize the bfloat16 model on the fly to int4
218
+ quantized_model = AutoModelForCausalLM .from_pretrained (
219
+ self .model_name ,
220
+ device_map = torch_device ,
221
+ quantization_config = quant_config ,
222
+ )
223
+ tokenizer = AutoTokenizer .from_pretrained (self .model_name )
224
+
225
+ input_ids = tokenizer (self .input_text , return_tensors = "pt" ).to (torch_device )
226
+
227
+ output = quantized_model .generate (** input_ids , max_new_tokens = self .max_new_tokens )
228
+ EXPECTED_OUTPUT = "What are we having for dinner?\n \n Jessica: (smiling)"
229
+ self .assertEqual (tokenizer .decode (output [0 ], skip_special_tokens = True ), EXPECTED_OUTPUT )
230
+
211
231
212
232
if __name__ == "__main__" :
213
233
unittest .main ()
You can’t perform that action at this time.
0 commit comments