Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import gc
import unittest
from unittest import skip

from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig
from transformers.testing_utils import (
backend_empty_cache,
require_compressed_tensors,
require_deterministic_for_xpu,
require_torch,
torch_device,
)
from transformers.testing_utils import backend_empty_cache, require_compressed_tensors, require_torch, torch_device
from transformers.utils import is_torch_available


Expand All @@ -20,12 +13,12 @@
@require_compressed_tensors
@require_torch
class CompressedTensorsTest(unittest.TestCase):
tinyllama_w8a16 = "nm-testing/tinyllama-w8a16-dense"
tinyllama_w4a16 = "nm-testing/tinyllama-w4a16-compressed"
tinyllama_w8a8 = "nm-testing/tinyllama-w8a8-compressed"
llama3_8b_fp8 = "nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat"
tinyllama_w4a16 = "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-e2e"
tinyllama_int8 = "nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A8-e2e"
tinyllama_fp8 = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
tinyllama_w8a16 = "nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A16-e2e"

prompt = "Paris is the capital of which country?"
prompt = "The capital of France is Paris, the capital of Germany is Berlin"

def tearDown(self):
gc.collect()
Expand Down Expand Up @@ -53,43 +46,30 @@ def test_config_to_from_dict(self):
self.assertIsInstance(config_from_dict.quantization_config, QuantizationConfig)
self.assertIsInstance(config_from_dict.sparsity_config, SparsityCompressionConfig)

@skip("Test too flaky, depends on hardware also")
def test_tinyllama_w8a8(self):
expected_out = [
"<s> Paris is the capital of which country?\n\n**A) 10** Paris is the capital of which country?\n\n**B) 11** Paris is the capital of which country?\n\n**C) 1",
"<s> Paris is the capital of which country?\n\n** 10.** Which country is the capital of which country?\n\n** 11.** Which country is the capital of which country?\n\n** 12.", # XPU
]
self._test_quantized_model(self.tinyllama_w8a8, expected_out)

def test_tinyllama_w4a16(self):
expected_out = [
"<s> Paris is the capital of which country?\nAnswer: Paris is the capital of France.\nQuestion: Which country is the capital of which city?\nAnswer: The capital of the city of New York is New York.\nQuestion: Which"
]
self._test_quantized_model(self.tinyllama_w4a16, expected_out)
self._test_quantized_model(self.tinyllama_w4a16, 20.0)

def test_tinyllama_int8(self):
self._test_quantized_model(self.tinyllama_int8, 30.0)

def test_tinyllama_fp8(self):
self._test_quantized_model(self.tinyllama_fp8, 20.0)

def test_tinyllama_w8a16(self):
expected_out = [
"<s> Paris is the capital of which country?\nA. France\nB. Germany\nC. Spain\nD. Italy\nE. Switzerland\nQ10. Which of the following is not a country in the European Union?\nA."
]
self._test_quantized_model(self.tinyllama_w8a16, expected_out)

def test_llama_8b_fp8(self):
expected_out = [
"<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous art museum in Paris? The Louvre\nWhat is the name of the famous bridge in Paris? Pont des Arts\nWhat is the name of the famous opera? ",
"<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous art museum in Paris? The Louvre\nWhat is the name of the famous bridge in Paris? Pont des Arts\nWhat is the name of the famous opera", # XPU
]
self._test_quantized_model(self.llama3_8b_fp8, expected_out)

@require_deterministic_for_xpu
def _test_quantized_model(self, model_name: str, expected_output: list):
"""Carry out generation"""
self._test_quantized_model(self.tinyllama_w8a16, 20.0)

def _test_quantized_model(self, model_name: str, expected_perplexity: float):
# load model
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = quantized_model.device

# check config
self.assertIsNotNone(
quantized_model.config.quantization_config,
"quantization_config should not be None",
)
# check scales
self.assertTrue(
any(
key
Expand All @@ -98,9 +78,13 @@ def _test_quantized_model(self, model_name: str, expected_output: list):
),
"quantized model should load a non-trivial scale into the state dict",
)

# compute outputs with loss
inputs = tokenizer(self.prompt, return_tensors="pt").to(device)
generated_ids = quantized_model.generate(**inputs, max_length=50, do_sample=False)
outputs = tokenizer.batch_decode(generated_ids)
labels = inputs["input_ids"]
with torch.no_grad():
outputs = quantized_model(**inputs, labels=labels)

self.assertIsNotNone(outputs)
self.assertIn(outputs[0], expected_output)
# check perplexity
perplexity = torch.exp(outputs.loss)
self.assertLessEqual(perplexity, expected_perplexity)