Skip to content

Commit

Permalink
Extract eval code from GPTQ for more general usage
Browse files Browse the repository at this point in the history
Summary: This commit extracts all the eval code from GPTQ.py.
This is the first step towards having a general eval framework
in torchao. The eventual goal is to use lm_eval to produce
reproducible benchmarks for the quantization APIs in torchao
that we can showcase on the main README. This will have the
added benefit of allowing us to add (possibly nightly)
regression test suites for important models.

Test Plan:

```
2024-05-24:14:50:32,647 INFO     [task.py:395] Building contexts for wikitext on rank 0...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1132.98it/s]
2024-05-24:14:50:32,648 INFO     [evaluator.py:362] Running loglikelihood_rolling requests
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:51<00:00, 51.39s/it]
wikitext: {'word_perplexity,none': 7.877762491958485, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.488984329919892, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.5743285710685551, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
.
----------------------------------------------------------------------
Ran 1 test in 858.105s

OK
```

python test/quantization/test_quant_api.py -k test_8da4w_gptq_quantizer
  • Loading branch information
andrewor14 committed May 28, 2024
1 parent f8f74c7 commit 4318395
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 263 deletions.
17 changes: 12 additions & 5 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,14 @@ def test_8da4w_quantizer(self):
assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
m(*example_inputs)

# TODO: save model weights as artifacts and re-enable in CI
# For now, to run this test, you will need to download the weights from HF
# and run this script to convert them:
# https://github.com/pytorch-labs/gpt-fast/blob/6253c6bb054e658d67566150f87329b87815ae63/scripts/convert_hf_checkpoint.py
@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_8da4w_gptq_quantizer(self):
from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder, TransformerEvalWrapper
from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
from torchao._eval import InputRecorder, TransformerEvalWrapper
# should be similar to TorchCompileDynamicQuantizer
precision = torch.bfloat16
device = "cpu"
Expand Down Expand Up @@ -250,7 +255,7 @@ def test_8da4w_gptq_quantizer(self):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_8da4w_quantizer_eval(self):
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
from torchao.quantization.GPTQ import TransformerEvalWrapper
from torchao._eval import TransformerEvalWrapper

precision = torch.bfloat16
device = "cpu"
Expand Down Expand Up @@ -284,7 +289,8 @@ def test_8da4w_quantizer_eval(self):

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq_quantizer_int4wo(self):
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao._eval import InputRecorder, TransformerEvalWrapper
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
Expand Down Expand Up @@ -343,7 +349,8 @@ def test_gptq_quantizer_int4wo(self):

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_quantizer_int4wo(self):
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer, TransformerEvalWrapper
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
from torchao._eval import TransformerEvalWrapper
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
Expand Down Expand Up @@ -378,7 +385,7 @@ def test_quantizer_int4wo(self):

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_eval_wrapper(self):
from torchao.quantization.GPTQ import TransformerEvalWrapper
from torchao._eval import TransformerEvalWrapper
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
Expand Down
228 changes: 228 additions & 0 deletions torchao/_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .utils import _lm_eval_available, _MultiInput

if _lm_eval_available:
try: # lm_eval version 0.4
from lm_eval.evaluator import evaluate # pyre-ignore[21]
from lm_eval.models.huggingface import HFLM as eval_wrapper # pyre-ignore[21]
from lm_eval.tasks import get_task_dict # pyre-ignore[21]
except: # lm_eval version 0.3
from lm_eval import base, evaluator, tasks

eval_wrapper = base.BaseLM
get_task_dict = tasks.get_task_dict
evaluate = evaluator.evaluate

class InputRecorder(eval_wrapper):
"""
This is a fake evaluation wrapper from the lm_eval library that just records the inputs
so that they can be used in calibration.
If pad_calibration_inputs is enabled, the input recorder will take
each input and pad/truncate it down to the calibration_seq_length.
(if using padding you should set the embeddings for the pad_token to 0
in the model)
Note: after padding/truncation, input_prep_function is called to bring
it to the proper form to be inserted into a given model.
If not, it will only truncate inputs to the desired length.
"""

def __init__(
self,
tokenizer,
calibration_seq_length,
input_prep_func=None,
pad_calibration_inputs=False,
vocab_size=32000,
pad_token=0,
device="cpu",
):
super().__init__()
self._tokenizer = tokenizer
self._device = torch.device(device)
self.vocab_size = vocab_size
self._max_seq_length = calibration_seq_length
self.calibration_seq_length = calibration_seq_length

# need to take inps and convert to corrent input
# for model
self.input_prep_func = (
input_prep_func if input_prep_func is not None
else lambda x: (x,)
)

self.pad_calibration_inputs = pad_calibration_inputs
self.pad_token = pad_token

self.inputs = None

@property
def eot_token_id(self):
try:
return self._tokenizer.eos_id()
except:
return self._tokenizer.eos_id

@property
def max_length(self):
return self._max_seq_length

@property
def max_gen_toks(self):
return 50

@property
def batch_size(self):
return 1

@property
def device(self):
return self._device

def tok_encode(self, string: str, **kwargs):
# TODO: verify this for multi-batch as well
tokens = self._tokenizer.encode(string)
if hasattr(self._tokenizer, "bos_id"):
try:
tokens = [self._tokenizer.bos_id()] + tokens
except:
tokens = [self._tokenizer.bos_id] + tokens
return tokens

def tok_decode(self, tokens):
decoded = self._tokenizer.decode(tokens)
return decoded

def add_input(self, args):
if self.inputs is None:
self.inputs = [_MultiInput([arg]) for arg in args]
else:
self.inputs = [
multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
]

def record_inputs(
self,
calibration_tasks,
calibration_limit,
):
try:
lm_eval.tasks.initialize_tasks()
except:
pass

task_dict = get_task_dict(calibration_tasks)
print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)

evaluate(
self,
task_dict,
limit=calibration_limit,
)
return self

def get_inputs(self):
return self.inputs

def _model_call(self, inps):
inps = inps.squeeze(0)
T = len(inps)
if (
# can't use inputs that are too short when padding disabled
(T < self.calibration_seq_length and not self.pad_calibration_inputs)
or
# can't use inputs that actually use token we use for padding
(self.pad_calibration_inputs and self.pad_token in inps)
):
# give random output
return torch.randn(
(1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
)

# pad or truncate to the right size
if T >= self.calibration_seq_length:
inps = inps[: self.calibration_seq_length]
else:
inps = F.pad(inps, (self.pad_token, self.calibration_seq_length - T))

inps = inps.unsqueeze(0)
model_in = self.input_prep_func(inps)

self.add_input(model_in)

# output `something` with correct shape to keep eval going
return torch.randn(
(1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
)

def _model_generate(self, context, max_length, eos_token_id):
raise Exception("unimplemented")

class TransformerEvalWrapper(InputRecorder):
"""
A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
"""
def __init__(
self,
model,
tokenizer,
max_seq_length,
input_prep_func=None,
device="cuda"
):
super().__init__(None, None)
self._model = model
self._tokenizer = tokenizer
self._device = torch.device(device)
self._max_seq_length = max_seq_length

# need to take inps and convert to corrent input
# for model
self.input_prep_func = (
input_prep_func if input_prep_func is not None
else lambda x: (x,)
)

def _model_call(self, inps):
# TODO: make batches work
input = self.input_prep_func(inps)

max_seq_length = min(inps.size(1), self.max_length)
with torch.device(self._device):
self._model.setup_caches(self.batch_size, max_seq_length)
logits = self._model(*input)
return logits

def _model_generate(self, context, max_length, eos_token_id):
raise Exception('unimplemented')

def run_eval(self, tasks, limit):
try:
lm_eval.tasks.initialize_tasks()
except:
pass

task_dict = get_task_dict(tasks)
print("Evaluating Model On: ", task_dict)
with torch.no_grad():
result = evaluate(
self,
task_dict,
limit=limit,
)
for task, res in result["results"].items():
print(f"{task}: {res}")
return result
Loading

0 comments on commit 4318395

Please sign in to comment.