@@ -151,8 +151,8 @@ def test_8da4w_quantizer(self):
         m(*example_inputs)
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_gptq_quantizer(self):
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder
+    def test_8da4w_gptq_quantizer(self):
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder, TransformerEvalWrapper
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
@@ -161,6 +161,7 @@ def test_gptq_quantizer(self):
         checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
         model.load_state_dict(checkpoint, assign=True)
         model = model.to(dtype=precision, device=device)
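+        # put the model in eval mode before calibration and evaluation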
+        model.eval()
         tokenizer_path = checkpoint_path.parent / "tokenizer.model"
         assert tokenizer_path.is_file(), tokenizer_path
         tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
@@ -190,12 +191,60 @@ def test_gptq_quantizer(self):
             blocksize,
             percdamp,
             groupsize,
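+            # keep the quantizer's working dtype in line with the model precision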
+            precision=precision,
         )
         model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
         model = quantizer.quantize(model, inputs)
-        compiled = torch.compile(model, mode="max-autotune")
-        with torch.no_grad():
-            compiled(inputs[0].values[0], inputs[1].values[0])
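+        # evaluate the quantized model on wikitext through the lm-eval harness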
+        result = TransformerEvalWrapper(
+            model,
+            tokenizer,
+            model.config.block_size,
+            prepare_inputs_for_model,
+            device,
+        ).run_eval(
+            ["wikitext"],
+            1,
+        )
+
+        assert result['results']['wikitext']['word_perplexity,none'] < 7.88, (
+            f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}"
+        )
+
+    @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_8da4w_quantizer_eval(self):
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+        from torchao.quantization.GPTQ import TransformerEvalWrapper
+
+        precision = torch.bfloat16
+        device = "cpu"
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        model = Transformer.from_name(checkpoint_path.parent.name)
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+        model.load_state_dict(checkpoint, assign=True)
+        model = model.to(dtype=precision, device=device)
+        model.eval()
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
+
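+        # int8 dynamic-activation / int4 grouped-weight quantization, groupsize 128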
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=precision)
+        q_model = quantizer.quantize(model)
+        result = TransformerEvalWrapper(
+            q_model,
+            tokenizer,
+            q_model.config.block_size,
+            prepare_inputs_for_model,
+            device,
+        ).run_eval(
+            ["wikitext"],
+            1,
+        )
+        assert result['results']['wikitext']['word_perplexity,none'] < 8.24, (
+            f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
+        )
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_gptq_quantizer_gpt_fast(self):
@@ -248,5 +297,95 @@ def test_gptq_quantizer_gpt_fast(self):
         with torch.no_grad():
             compiled(inputs[0].values[0], inputs[1].values[0])
 
+    @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    def test_gptq_quantizer_int4wo(self):
+        from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
+        # should be similar to TorchCompileDynamicQuantizer
+        precision = torch.bfloat16
+        device = "cuda"
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        model = Transformer.from_name(checkpoint_path.parent.name)
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+        model.load_state_dict(checkpoint, assign=True)
+        model = model.to(dtype=precision, device="cpu")
+        model.eval()
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
+        blocksize = 128
+        percdamp = 0.01
+        groupsize = 128
+        calibration_tasks = ["wikitext"]
+        calibration_limit = 1
+        calibration_seq_length = 100
+        input_prep_func = prepare_inputs_for_model
+        pad_calibration_inputs = False
+
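+        # record wikitext calibration inputs (on CPU) for GPTQ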
+        inputs = InputRecorder(
+            tokenizer,
+            calibration_seq_length,
+            input_prep_func,
+            pad_calibration_inputs,
+            model.config.vocab_size,
+            device="cpu",
+        ).record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        ).get_inputs()
+
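+        # int4 weight-only GPTQ quantizer (per-group quantization, groupsize=128)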
+        quantizer = Int4WeightOnlyGPTQQuantizer(
+            blocksize,
+            percdamp,
+            groupsize,
+        )
+        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+
+        model = quantizer.quantize(model, inputs).cuda()
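+        # quantization ran with CPU inputs; eval the quantized model on CUDA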
+        result = TransformerEvalWrapper(
+            model.cuda(),
+            tokenizer,
+            model.config.block_size,
+            prepare_inputs_for_model,
+            device,
+        ).run_eval(
+            ["wikitext"],
+            1,
+        )
+        assert result['results']['wikitext']['word_perplexity,none'] < 7.77, (
+            f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
+        )
+
+    @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    def test_eval_wrapper(self):
+        from torchao.quantization.GPTQ import TransformerEvalWrapper
+        precision = torch.bfloat16
+        device = "cuda"
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        model = Transformer.from_name(checkpoint_path.parent.name)
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+        model.load_state_dict(checkpoint, assign=True)
+        model = model.to(dtype=precision, device=device)
+        model.eval()
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
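+        # baseline: eval the unquantized bf16 model to sanity-check the wrapper itself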
+        result = TransformerEvalWrapper(
+            model,
+            tokenizer,
+            model.config.block_size,
+            prepare_inputs_for_model,
+            device,
+        ).run_eval(
+            ["wikitext"],
+            1,
+        )
+        assert result['results']['wikitext']['word_perplexity,none'] < 7.77, (
+            f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
+        )
+
 if __name__ == "__main__":
     unittest.main()