
Commit 956e16c

GPTQ updates
Summary:

1) Reorganized GPTQ
   a) got rid of the old GPTQ and renamed GPTQ_MT to GPTQ
   b) moved the new GPTQ to prototype
   c) moved the quantized linear modules from GPTQ.py to linear_quant_modules.py
2) Removed the input_recorder's dependence on lm_eval
   a) created a new input recorder that doesn't depend on lm_eval
   b) made the lm_eval input recorder depend on the new generic input recorder
   c) made TransformerEvalWrapper the base class and made LMEvalInputRecorder inherit from it instead of vice versa
   d) updated the APIs generally to work with the new input recorder
3) Reorganized GPTQ tests
   a) moved tests from test_quant_api.py to test_gptq.py
   b) added a new test that can run in CI without depending on lm_eval or Llama weights
   c) got rid of test_gptq_mt.py
4) Added new documentation for lm_eval
5) GPTQ improvements
   a) reimplemented faster quant
   b) tested compiling the hessian calculation and parts of faster quant; they were generally slower
   c) moved helper functions out of the class; they're largely generic and this is less cluttered
   d) made the duplication checking and copying faster where possible
   e) fixed some bugs caused by this code not being in CI while the int4wo tensor subclass changed underneath it

Test Plan:

1) `python test_gptq.py` (note: the skipped test test_gptq_quantizer_int4_weight_only was also run)
2) Verified that all activations match between the old GPTQ and the current GPTQ
3)
```shell
export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder

export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-gptq-64 --calibration_limit 10

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-gptq-64 --calibration_limit 10
```
See README.md for results; they show GPTQ is working.

Reviewers:

Subscribers:

Tasks:

Tags:
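For orientation, here is a minimal sketch of the reworked flow that the new test file below exercises: calibration inputs are captured with the generic `MultiTensorInputRecorder` (no lm_eval dependency) and then passed to `Int4WeightOnlyGPTQQuantizer.quantize`. The two-layer Llama config and random token ids are illustrative only and mirror `test_gptq_with_input_recorder`; this is not the only supported setup.

```python
# Sketch only: mirrors the usage exercised in test_gptq_with_input_recorder below.
import torch
from torchao._models.llama.model import ModelArgs, Transformer, prepare_inputs_for_model
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, MultiTensorInputRecorder

torch.set_default_dtype(torch.bfloat16)
with torch.device("cuda"):
    # Small illustrative model; any Transformer from torchao._models.llama works the same way.
    model = Transformer(ModelArgs(n_layer=2))
    model.setup_caches(max_batch_size=2, max_seq_length=100)

    # Record calibration inputs with the new generic recorder (no lm_eval dependency).
    input_recorder = MultiTensorInputRecorder()
    for _ in range(10):
        idx = torch.randint(1, 10000, (2, 50), dtype=torch.int32)
        input_recorder(*prepare_inputs_for_model(idx))

    # Run int4 weight-only GPTQ over the recorded MultiTensor inputs.
    quantizer = Int4WeightOnlyGPTQQuantizer()
    quantizer.quantize(model, *input_recorder.get_recorded_inputs())
```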
1 parent 446f07d · commit 956e16c

17 files changed: +1602 −2480 lines

test/quantization/test_gptq.py

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
import unittest
from pathlib import Path

import torch
from torch.testing._internal.common_utils import TestCase

from torchao._models.llama.model import (
    ModelArgs,
    Transformer,
    prepare_inputs_for_model,
)
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_4,
)

torch.manual_seed(0)


class TestGPTQ(TestCase):
    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_gptq_quantizer_int4_weight_only(self):
        from torchao._models._eval import (
            LMEvalInputRecorder,
            TransformerEvalWrapper,
        )
        from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer

        precision = torch.bfloat16
        device = "cuda"
        checkpoint_path = Path(
            "../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"
        )
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        model = model.to(dtype=precision, device="cpu")
        model.eval()

        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = get_tokenizer(  # pyre-ignore[28]
            tokenizer_path,
            "Llama-2-7b-chat-hf",
        )
        groupsize = 64
        blocksize = 128
        percdamp = 0.01
        calibration_tasks = ["wikitext"]
        calibration_limit = 1
        calibration_seq_length = 100
        input_prep_func = prepare_inputs_for_model
        pad_calibration_inputs = False
        inputs = (
            LMEvalInputRecorder(
                tokenizer,
                calibration_seq_length,
                input_prep_func,
                model.config.vocab_size,
                pad_calibration_inputs,
                device="cpu",
            )
            .record_inputs(
                calibration_tasks,
                calibration_limit,
            )
            .get_recorded_inputs()
        )

        quantizer = Int4WeightOnlyGPTQQuantizer(
            groupsize,
            blocksize,
            percdamp,
        )
        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)

        model = quantizer.quantize(model, *inputs).cuda()

        model.reset_caches()
        with torch.device("cuda"):
            model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size)

        limit = 1
        result = TransformerEvalWrapper(
            model.cuda(),
            tokenizer,
            model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            limit,
        )

        assert result["results"]["wikitext"]["word_perplexity,none"] < 7.77, (
            f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
        )


class TestMultiTensorFlow(TestCase):
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_multitensor_add_tensors(self):
        from torchao.quantization.GPTQ import MultiTensor

        tensor1 = torch.randn(3, 3)
        tensor2 = torch.randn(3, 3)
        mt = MultiTensor(tensor1)
        mt.add_tensors(tensor2)
        self.assertEqual(mt.count, 2)
        self.assertTrue(torch.equal(mt.values[0], tensor1))
        self.assertTrue(torch.equal(mt.values[1], tensor2))

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_multitensor_pad_unpad(self):
        from torchao.quantization.GPTQ import MultiTensor

        tensor1 = torch.randn(3, 3)
        mt = MultiTensor(tensor1)
        mt.pad_to_length(3)
        self.assertEqual(mt.count, 3)
        mt.unpad()
        self.assertEqual(mt.count, 1)

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_multitensor_inplace_operation(self):
        from torchao.quantization.GPTQ import MultiTensor

        tensor1 = torch.ones(3, 3)
        mt = MultiTensor(tensor1)
        mt += 1  # In-place addition
        self.assertTrue(torch.equal(mt.values[0], torch.full((3, 3), 2)))


class TestMultiTensorInputRecorder(TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_multitensor_input_recorder(self):
        from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder

        input_recorder = MultiTensorInputRecorder()
        in1 = ([1], torch.randn(3, 3), (1, "dog", torch.randn(3, 3)), torch.float)
        in2 = ([1], torch.randn(3, 3), (1, "dog", torch.randn(3, 3)), torch.float)

        input_recorder(*in1)
        input_recorder(*in2)

        MT_input = input_recorder.get_recorded_inputs()

        self.assertEqual(MT_input[0], [1])
        self.assertTrue(isinstance(MT_input[1], MultiTensor))
        self.assertTrue(isinstance(MT_input[2], tuple))
        self.assertEqual(MT_input[2][0], 1)
        self.assertEqual(MT_input[2][1], "dog")
        self.assertTrue(isinstance(MT_input[2][2], MultiTensor))
        self.assertEqual(MT_input[3], torch.float)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_gptq_with_input_recorder(self):
        from torchao.quantization.GPTQ import (
            Int4WeightOnlyGPTQQuantizer,
            MultiTensorInputRecorder,
        )

        torch.set_default_dtype(torch.bfloat16)

        config = ModelArgs(n_layer=2)

        with torch.device("cuda"):
            model = Transformer(config)
            model.setup_caches(max_batch_size=2, max_seq_length=100)
            idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)
            test_input = prepare_inputs_for_model(idx[0])
            import copy

            model2 = copy.deepcopy(model)
            out = model(*test_input)
            quantize_(model2, Int4WeightOnlyConfig())

            outq = model2(*test_input)
            del model2

            input_recorder = MultiTensorInputRecorder()
            for i in range(10):
                input = prepare_inputs_for_model(idx[i])
                input_recorder(*input)

            args = input_recorder.get_recorded_inputs()

            quantizer = Int4WeightOnlyGPTQQuantizer()

            quantizer.quantize(model, *args)

            outgptq = model(*test_input)

            self.assertGreater(compute_error(outgptq, out), 30)
            self.assertGreater(compute_error(outgptq, out), compute_error(outq, out))


if __name__ == "__main__":
    unittest.main()
