Commit b0a333c

add int4 gptq and eval (#116)
* add int4 gptq and eval

  Summary: adds int4 GPTQ quantization and eval support. Also fixes a few bugs related to quantizing the activation, both during the GPTQ calculation and when computing the output.

  Test Plan: python test/quantization/test_quant_api.py

  ghstack-source-id: d29b6d7
  Pull Request resolved: #115

* remove debug from GPTQ
1 parent 046dc98 commit b0a333c
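
As a quick orientation, here is a condensed sketch of the int4 weight-only GPTQ + eval flow that the new test_gptq_quantizer_int4wo test (in the diff below) exercises. The checkpoint/tokenizer paths and the import of the gpt-fast style Transformer and prepare_inputs_for_model from test/quantization/model.py are assumptions about the local setup, not part of the torchao API.

import torch
from pathlib import Path

from sentencepiece import SentencePieceProcessor

from torchao.quantization.GPTQ import (
    Int4WeightOnlyGPTQQuantizer,
    InputRecorder,
    TransformerEvalWrapper,
)
# assumption: run from test/quantization/ so the gpt-fast style model definition is importable
from model import Transformer, prepare_inputs_for_model

# assumed local checkpoint layout, matching the paths used in the tests
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=torch.bfloat16, device="cpu").eval()
tokenizer = SentencePieceProcessor(model_file=str(checkpoint_path.parent / "tokenizer.model"))

# record calibration inputs from wikitext for GPTQ
calibration_seq_length = 100
inputs = InputRecorder(
    tokenizer,
    calibration_seq_length,
    prepare_inputs_for_model,
    False,  # pad_calibration_inputs
    model.config.vocab_size,
    device="cpu",
).record_inputs(["wikitext"], 1).get_inputs()

# run GPTQ on the int4 weight-only path, then evaluate wikitext perplexity
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
quantizer = Int4WeightOnlyGPTQQuantizer(128, 0.01, 128)  # blocksize, percdamp, groupsize
model = quantizer.quantize(model, inputs).cuda()
result = TransformerEvalWrapper(
    model,
    tokenizer,
    model.config.block_size,
    prepare_inputs_for_model,
    "cuda",
).run_eval(["wikitext"], 1)
print(result["results"]["wikitext"]["word_perplexity,none"])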

6 files changed: +522 −339 lines changed


test/quantization/model.py

Lines changed: 4 additions & 9 deletions
@@ -10,15 +10,15 @@
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import functional as F
+from torchao.quantization.utils import find_multiple

-def prepare_inputs_for_model(inps):
+def prepare_inputs_for_model(inps, max_new_tokens=1):
     # this is because input from lm-eval is 2d
-    if input.dim() != 2:
-        raise ValueError(f"Expected input to be of dim 2, but got {input.dim()}")
+    if inps.dim() != 2:
+        raise ValueError(f"Expected input to be of dim 2, but got {inps.dim()}")

     inps = inps.squeeze(0)
     # setup inputs in correct format
-    max_new_tokens = 1
     T = inps.size(0)
     T_new = T + max_new_tokens
     seq = torch.empty(T_new, dtype=inps.dtype, device=inps.device)
@@ -27,11 +27,6 @@ def prepare_inputs_for_model(inps):
     x = seq.index_select(0, input_pos).view(1, -1)
     return (x, input_pos)

-def find_multiple(n: int, k: int) -> int:
-    if n % k == 0:
-        return n
-    return n + k - (n % k)
-
 @dataclass
 class ModelArgs:
     block_size: int = 2048
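
For reference, a minimal sketch of calling the reworked helper with its new max_new_tokens parameter; the token values and the import path (running from test/quantization/ so model.py is importable) are illustrative assumptions.

import torch

from model import prepare_inputs_for_model  # test/quantization/model.py

# lm-eval style input: a 2d (batch=1, seq_len) tensor of token ids (values here are made up)
tokens = torch.randint(0, 32_000, (1, 16))
x, input_pos = prepare_inputs_for_model(tokens, max_new_tokens=1)
# x is a (1, seq_len) view of the prompt and input_pos holds the matching position indices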

test/quantization/test_quant_api.py

Lines changed: 144 additions & 5 deletions
@@ -151,8 +151,8 @@ def test_8da4w_quantizer(self):
         m(*example_inputs)

     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_gptq_quantizer(self):
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder
+    def test_8da4w_gptq_quantizer(self):
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder, TransformerEvalWrapper
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
@@ -161,6 +161,7 @@ def test_gptq_quantizer(self):
         checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
         model.load_state_dict(checkpoint, assign=True)
         model = model.to(dtype=precision, device=device)
+        model.eval()
         tokenizer_path = checkpoint_path.parent / "tokenizer.model"
         assert tokenizer_path.is_file(), tokenizer_path
         tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
@@ -190,12 +191,60 @@ def test_gptq_quantizer(self):
             blocksize,
             percdamp,
             groupsize,
+            precision=precision,
         )
         model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
         model = quantizer.quantize(model, inputs)
-        compiled = torch.compile(model, mode="max-autotune")
-        with torch.no_grad():
-            compiled(inputs[0].values[0], inputs[1].values[0])
+        result = TransformerEvalWrapper(
+            model,
+            tokenizer,
+            model.config.block_size,
+            prepare_inputs_for_model,
+            device,
+        ).run_eval(
+            ["wikitext"],
+            1,
+        )
+
+        assert result['results']['wikitext']['word_perplexity,none'] < 7.88, (
+            f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}"
+        )
+
+    @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_8da4w_quantizer_eval(self):
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+        from torchao.quantization.GPTQ import TransformerEvalWrapper
+
+        precision = torch.bfloat16
+        device = "cpu"
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        model = Transformer.from_name(checkpoint_path.parent.name)
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+        model.load_state_dict(checkpoint, assign=True)
+        model = model.to(dtype=precision, device=device)
+        model.eval()
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
+
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=precision)
+        q_model = quantizer.quantize(model)
+        result = TransformerEvalWrapper(
+            q_model,
+            tokenizer,
+            q_model.config.block_size,
+            prepare_inputs_for_model,
+            device,
+        ).run_eval(
+            ["wikitext"],
+            1,
+        )
+        assert result['results']['wikitext']['word_perplexity,none'] < 8.24, (
+            f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
+        )

     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_gptq_quantizer_gpt_fast(self):
@@ -248,5 +297,95 @@ def test_gptq_quantizer_gpt_fast(self):
         with torch.no_grad():
             compiled(inputs[0].values[0], inputs[1].values[0])

+    @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    def test_gptq_quantizer_int4wo(self):
+        from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
+        # should be similar to TorchCompileDynamicQuantizer
+        precision = torch.bfloat16
+        device = "cuda"
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        model = Transformer.from_name(checkpoint_path.parent.name)
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+        model.load_state_dict(checkpoint, assign=True)
+        model = model.to(dtype=precision, device="cpu")
+        model.eval()
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
+        blocksize = 128
+        percdamp = 0.01
+        groupsize = 128
+        calibration_tasks = ["wikitext"]
+        calibration_limit = 1
+        calibration_seq_length = 100
+        input_prep_func = prepare_inputs_for_model
+        pad_calibration_inputs = False
+
+        inputs = InputRecorder(
+            tokenizer,
+            calibration_seq_length,
+            input_prep_func,
+            pad_calibration_inputs,
+            model.config.vocab_size,
+            device="cpu",
+        ).record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        ).get_inputs()
+
+        quantizer = Int4WeightOnlyGPTQQuantizer(
+            blocksize,
+            percdamp,
+            groupsize,
+        )
+        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+
+        model = quantizer.quantize(model, inputs).cuda()
+        result = TransformerEvalWrapper(
+            model.cuda(),
+            tokenizer,
+            model.config.block_size,
+            prepare_inputs_for_model,
+            device,
+        ).run_eval(
+            ["wikitext"],
+            1,
+        )
+        assert result['results']['wikitext']['word_perplexity,none'] < 7.77, (
+            f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
+        )
+
+    @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    def test_eval_wrapper(self):
+        from torchao.quantization.GPTQ import TransformerEvalWrapper
+        precision = torch.bfloat16
+        device = "cuda"
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        model = Transformer.from_name(checkpoint_path.parent.name)
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+        model.load_state_dict(checkpoint, assign=True)
+        model = model.to(dtype=precision, device=device)
+        model.eval()
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
+        result = TransformerEvalWrapper(
+            model,
+            tokenizer,
+            model.config.block_size,
+            prepare_inputs_for_model,
+            device,
+        ).run_eval(
+            ["wikitext"],
+            1,
+        )
+        assert result['results']['wikitext']['word_perplexity,none'] < 7.77, (
+            f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
        )
+
 if __name__ == "__main__":
     unittest.main()
