
Commit a908e78

jainapurva authored and facebook-github-bot committed
Skip tests on fbcode (#1532)
Summary: Pull Request resolved: #1532. Skip tests on fbcode; missing model checkpoints. Differential Revision: D67982501
1 parent 8259a38 commit a908e78

File tree

1 file changed

test/quantization/test_gptq_mt.py

Lines changed: 91 additions & 81 deletions
@@ -8,6 +8,13 @@
 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
 from torchao.quantization.utils import _lm_eval_available
+from torchao.utils import is_fbcode
+from torch.testing._internal.common_utils import run_tests
+
+if is_fbcode():
+    pytest.skip(
+        "Skipping the test in fbcode due to missing model and tokenizer files"
+    )
 
 if _lm_eval_available:
     hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")
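A side note on the pattern in the hunk above: when pytest.skip is called at import time rather than inside a test, pytest expects allow_module_level=True so that collection skips the whole file instead of raising a usage error. A minimal sketch of that module-level guard follows; it is general pytest usage under that assumption, not code from this commit:

    import pytest

    from torchao.utils import is_fbcode

    # Skip every test in this file when running inside fbcode, where the
    # Llama checkpoints and tokenizer files are not available.
    if is_fbcode():
        pytest.skip(
            "Skipping the test in fbcode due to missing model and tokenizer files",
            allow_module_level=True,
        )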
@@ -246,89 +253,92 @@ def run_eval(self, tasks, limit):
 
         return result
 
+def test_gptq_mt():
+    precision = torch.bfloat16
+    device = "cuda"
+    print("Loading model")
+    checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+    model = Transformer.from_name(checkpoint_path.parent.name)
+    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    model.load_state_dict(checkpoint, assign=True)
+    model = model.to(dtype=precision, device="cpu")
+    model.eval()
+    print("Model loaded")
+    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+    assert tokenizer_path.is_file(), tokenizer_path
+    tokenizer = get_tokenizer( # pyre-ignore[28]
+        tokenizer_path,
+        "Llama-2-7b-chat-hf",
+    )
+    print("Tokenizer loaded")
+
+
+    blocksize = 128
+    percdamp = 0.01
+    groupsize = 64
+    calibration_tasks = ["wikitext"]
+    calibration_limit = None
+    calibration_seq_length = 100
+    input_prep_func = prepare_inputs_for_model
+    pad_calibration_inputs = False
+    print("Recording inputs")
+    inputs = (
+        InputRecorder(
+            tokenizer,
+            calibration_seq_length,
+            input_prep_func,
+            pad_calibration_inputs,
+            model.config.vocab_size,
+            device="cpu",
+        )
+        .record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        )
+        .get_inputs()
+    )
+    print("Inputs recorded")
+    quantizer = Int4WeightOnlyGPTQQuantizer(
+        blocksize,
+        percdamp,
+        groupsize,
+    )
 
-precision = torch.bfloat16
-device = "cuda"
-print("Loading model")
-checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
-model = Transformer.from_name(checkpoint_path.parent.name)
-checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-model.load_state_dict(checkpoint, assign=True)
-model = model.to(dtype=precision, device="cpu")
-model.eval()
-print("Model loaded")
-tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-assert tokenizer_path.is_file(), tokenizer_path
-tokenizer = get_tokenizer( # pyre-ignore[28]
-    tokenizer_path,
-    "Llama-2-7b-chat-hf",
-)
-print("Tokenizer loaded")
-
-
-blocksize = 128
-percdamp = 0.01
-groupsize = 64
-calibration_tasks = ["wikitext"]
-calibration_limit = None
-calibration_seq_length = 100
-input_prep_func = prepare_inputs_for_model
-pad_calibration_inputs = False
-print("Recording inputs")
-inputs = (
-    InputRecorder(
+    model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+    multi = [
+        MultiTensor([inp for inp, _ in inputs]),
+        MultiTensor([inds for _, inds in inputs]),
+    ]
+    print("Quantizing model")
+    model = quantizer.quantize(model, multi).cuda()
+    print("Model quantized")
+    print("Saving model and fixing state dict")
+    regular_state_dict = model.state_dict() # defaultdict(torch.tensor)
+    for key, value in model.state_dict().items():
+        if isinstance(value, MultiTensor):
+            regular_state_dict[key] = value.values[0]
+        else:
+            regular_state_dict[key] = value
+
+    model = Transformer.from_name(checkpoint_path.parent.name)
+    remove = [k for k in regular_state_dict if "kv_cache" in k]
+    for k in remove:
+        del regular_state_dict[k]
+
+    model.load_state_dict(regular_state_dict, assign=True)
+    torch.save(model.state_dict(), "model.pth")
+    print("Running evaluation")
+    result = TransformerEvalWrapper(
+        model.to(device), # quantized model needs to run on cuda
         tokenizer,
-        calibration_seq_length,
-        input_prep_func,
-        pad_calibration_inputs,
-        model.config.vocab_size,
-        device="cpu",
-    )
-    .record_inputs(
-        calibration_tasks,
-        calibration_limit,
+        model.config.block_size,
+        prepare_inputs_for_model,
+    ).run_eval(
+        ["wikitext"],
+        None,
     )
-    .get_inputs()
-)
-print("Inputs recorded")
-quantizer = Int4WeightOnlyGPTQQuantizer(
-    blocksize,
-    percdamp,
-    groupsize,
-)
-
-model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
-multi = [
-    MultiTensor([inp for inp, _ in inputs]),
-    MultiTensor([inds for _, inds in inputs]),
-]
-print("Quantizing model")
-model = quantizer.quantize(model, multi).cuda()
-print("Model quantized")
-print("Saving model and fixing state dict")
-regular_state_dict = model.state_dict() # defaultdict(torch.tensor)
-for key, value in model.state_dict().items():
-    if isinstance(value, MultiTensor):
-        regular_state_dict[key] = value.values[0]
-    else:
-        regular_state_dict[key] = value
-
-model = Transformer.from_name(checkpoint_path.parent.name)
-remove = [k for k in regular_state_dict if "kv_cache" in k]
-for k in remove:
-    del regular_state_dict[k]
-
-model.load_state_dict(regular_state_dict, assign=True)
-torch.save(model.state_dict(), "model.pth")
-print("Running evaluation")
-result = TransformerEvalWrapper(
-    model.to(device), # quantized model needs to run on cuda
-    tokenizer,
-    model.config.block_size,
-    prepare_inputs_for_model,
-).run_eval(
-    ["wikitext"],
-    None,
-)
+
+if __name__ == "__main__":
+    run_tests()
 
 # wikitext: {'word_perplexity,none': 12.523175352665858, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6042723245990418, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.681919059499152, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
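As a quick cross-check on the figures in that trailing comment: for lm-eval's wikitext task, bits_per_byte is the base-2 log of byte_perplexity, so the two reported values should agree. A small sketch with the numbers copied from the comment (the relationship is assumed from lm-eval's standard wikitext metrics, not from this commit):

    import math

    # bits_per_byte = log2(byte_perplexity) for lm-eval's wikitext task
    byte_perplexity = 1.6042723245990418
    print(math.log2(byte_perplexity))  # ~0.68192, matching the reported 0.681919...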

0 commit comments