Extract eval code from GPTQ for more general usage
Summary: This commit extracts all the eval code from GPTQ.py into
a new torchao.quantization._eval module. This is the first step
towards a general eval framework in torchao. The eventual goal is
to use lm_eval to produce reproducible benchmarks for torchao's
quantization APIs that we can showcase on the main README. This
will have the added benefit of letting us add (possibly nightly)
regression test suites for important models.
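
For a sense of how the extracted code is meant to be used, here is a
minimal sketch of driving the relocated eval wrapper. It assumes a
gpt-fast model and tokenizer are already loaded, and the constructor
and run_eval signatures shown are assumptions based on this diff, not
a confirmed API:

```
import torch
from torchao.quantization._eval import TransformerEvalWrapper

# Assumed already loaded, e.g. via gpt-fast's checkpoint loading
# (see checkpoint_path in the tests below).
model = load_model()          # hypothetical helper
tokenizer = load_tokenizer()  # hypothetical helper

wrapper = TransformerEvalWrapper(
    model,
    tokenizer,
    max_seq_length=2048,  # illustrative value
)
# Runs lm_eval under the hood and reports per-task metrics such as
# the wikitext word_perplexity shown in the test plan below.
wrapper.run_eval(tasks=["wikitext"], limit=1)
```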

Test Plan:

```
python test/quantization/test_quant_api.py -k test_8da4w_gptq_quantizer
```

```
2024-05-24:14:50:32,647 INFO     [task.py:395] Building contexts for wikitext on rank 0...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1132.98it/s]
2024-05-24:14:50:32,648 INFO     [evaluator.py:362] Running loglikelihood_rolling requests
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:51<00:00, 51.39s/it]
wikitext: {'word_perplexity,none': 7.877762491958485, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.488984329919892, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.5743285710685551, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
.
----------------------------------------------------------------------
Ran 1 test in 858.105s

OK
```
andrewor14 committed May 24, 2024
1 parent 163cb93 commit caf13a6
Showing 4 changed files with 284 additions and 263 deletions.
test/quantization/test_quant_api.py (12 additions & 5 deletions)

```diff
@@ -172,9 +172,14 @@ def test_8da4w_quantizer(self):
         assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
         m(*example_inputs)
 
+    # TODO: save model weights as artifacts and re-enable in CI
+    # For now, to run this test, you will need to download the weights from HF
+    # and run this script to convert them:
+    # https://github.com/pytorch-labs/gpt-fast/blob/6253c6bb054e658d67566150f87329b87815ae63/scripts/convert_hf_checkpoint.py
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_8da4w_gptq_quantizer(self):
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder, TransformerEvalWrapper
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
+        from torchao.quantization._eval import InputRecorder, TransformerEvalWrapper
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
@@ -236,7 +241,7 @@ def test_8da4w_gptq_quantizer(self):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_8da4w_quantizer_eval(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-        from torchao.quantization.GPTQ import TransformerEvalWrapper
+        from torchao.quantization._eval import TransformerEvalWrapper
 
         precision = torch.bfloat16
         device = "cpu"
@@ -270,7 +275,8 @@ def test_8da4w_quantizer_eval(self):
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_gptq_quantizer_int4wo(self):
-        from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
+        from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
+        from torchao.quantization._eval import InputRecorder, TransformerEvalWrapper
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -329,7 +335,8 @@ def test_gptq_quantizer_int4wo(self):
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_quantizer_int4wo(self):
-        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer, TransformerEvalWrapper
+        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+        from torchao.quantization._eval import TransformerEvalWrapper
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -364,7 +371,7 @@ def test_quantizer_int4wo(self):
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_eval_wrapper(self):
-        from torchao.quantization.GPTQ import TransformerEvalWrapper
+        from torchao.quantization._eval import TransformerEvalWrapper
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
```
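
As a companion sketch (not from the commit), this is roughly how the
relocated InputRecorder feeds calibration data into a GPTQ quantizer,
mirroring the import split in the diff above; the constructor
arguments and helper names are illustrative assumptions:

```
from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
from torchao.quantization._eval import InputRecorder

model = load_model()          # hypothetical helper, as above
tokenizer = load_tokenizer()  # hypothetical helper, as above

# Record calibration inputs by running an lm_eval task through the
# model; the recorded inputs are then consumed by GPTQ.
inputs = (
    InputRecorder(tokenizer, calibration_seq_length=100)
    .record_inputs(["wikitext"], 1)  # task list, sample limit
    .get_inputs()
)

quantizer = Int8DynActInt4WeightGPTQQuantizer()  # default hyperparameters
quantized_model = quantizer.quantize(model, inputs)
```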
(diffs for the other 3 changed files not shown)
