
Commit 8dbfcd3

Robert Shaw (robertgshaw2-redhat) authored
[ CI/Build ] Added E2E Test For Compressed Tensors (#5839)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
1 parent f7dac83 · commit 8dbfcd3

File tree: 4 files changed (+57, -1 lines)

requirements-test.txt

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@ peft
 requests
 ray
 sentence-transformers # required for embedding
+sparseml==1.8.0 # required for compressed-tensors
+compressed-tensors==0.4.0 # required for compressed-tensors
 
 # Benchmarking
 aiohttp
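
The two new pins back the E2E test added below. As a quick sanity check that the pinned versions actually resolved in the test environment, something like the following works (a minimal sketch; the version-check loop is our own addition, not part of the commit):

import importlib.metadata

# Distribution names and versions exactly as pinned in requirements-test.txt.
PINS = [("sparseml", "1.8.0"), ("compressed-tensors", "0.4.0")]

for dist, expected in PINS:
    installed = importlib.metadata.version(dist)
    assert installed == expected, f"{dist}: expected {expected}, got {installed}"
    print(f"{dist}=={installed} OK")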

tests/conftest.py

Lines changed: 4 additions & 0 deletions
@@ -176,6 +176,7 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
         is_embedding_model: bool = False,
         is_vision_model: bool = False,
+        is_sparseml_model: bool = False,
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -193,6 +194,9 @@ def __init__(
         else:
             if is_vision_model:
                 auto_cls = AutoModelForVision2Seq
+            elif is_sparseml_model:
+                from sparseml.transformers import SparseAutoModelForCausalLM
+                auto_cls = SparseAutoModelForCausalLM
             else:
                 auto_cls = AutoModelForCausalLM
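
The new flag only changes which auto class materializes the checkpoint: SparseML's SparseAutoModelForCausalLM can load compressed-tensors checkpoints that a plain AutoModelForCausalLM cannot. A standalone sketch of the same dispatch (the helper name select_auto_cls is ours; the runner plumbing around it is omitted):

def select_auto_cls(is_vision_model: bool = False,
                    is_sparseml_model: bool = False):
    # Mirrors the auto_cls selection added to the runner __init__ above.
    if is_vision_model:
        from transformers import AutoModelForVision2Seq
        return AutoModelForVision2Seq
    if is_sparseml_model:
        # SparseML's wrapper handles decompressing quantized/sparse
        # checkpoints produced with compressed-tensors.
        from sparseml.transformers import SparseAutoModelForCausalLM
        return SparseAutoModelForCausalLM
    from transformers import AutoModelForCausalLM
    return AutoModelForCausalLM

Usage would then be auto_cls = select_auto_cls(is_sparseml_model=True) followed by auto_cls.from_pretrained(model_name).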

tests/models/test_compressed_tensors.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+"""Compares vllm vs sparseml for compressed-tensors
+
+Note: vllm and sparseml do not have bitwise correctness,
+so in this test, we just confirm that the top selected
+tokens of each are in the top 5 selections of the other.
+"""
+
+import pytest
+
+from tests.quantization.utils import is_quant_method_supported
+
+from .utils import check_logprobs_close
+
+MODELS = [
+    "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test",
+]
+
+MAX_TOKENS = 32
+NUM_LOGPROBS = 5
+
+
+@pytest.mark.skipif(
+    not is_quant_method_supported("compressed-tensors"),
+    reason="compressed-tensors is not supported on this machine type.")
+@pytest.mark.parametrize("model_name", MODELS)
+def test_models(
+    vllm_runner,
+    hf_runner,
+    example_prompts,
+    model_name,
+) -> None:
+    # Run sparseml.
+    with hf_runner(model_name=model_name,
+                   is_sparseml_model=True) as sparseml_model:
+
+        sparseml_outputs = sparseml_model.generate_greedy_logprobs_limit(
+            example_prompts, MAX_TOKENS, NUM_LOGPROBS)
+
+    # Run vllm.
+    with vllm_runner(model_name=model_name) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, MAX_TOKENS, NUM_LOGPROBS)
+
+    check_logprobs_close(
+        outputs_0_lst=sparseml_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="sparseml",
+        name_1="vllm",
+    )
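
The closeness criterion lives in check_logprobs_close (imported from the sibling utils module): where the two greedy token streams diverge, each model's chosen token must appear in the other model's top-NUM_LOGPROBS candidates. A simplified sketch of that check, assuming outputs shaped as (token_ids, text, per-step logprob dicts); this is an illustration, not the actual helper:

from typing import Dict, List, Tuple

# (generated token ids, generated text, per-step {token_id: logprob} dicts)
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]

def logprobs_close(out_0: TokensTextLogprobs,
                   out_1: TokensTextLogprobs) -> bool:
    tokens_0, _, logprobs_0 = out_0
    tokens_1, _, logprobs_1 = out_1
    for step, (tok_0, tok_1) in enumerate(zip(tokens_0, tokens_1)):
        if tok_0 == tok_1:
            continue  # same greedy pick: nothing to compare
        # First divergence: each pick must be in the other's top-k set.
        return tok_0 in logprobs_1[step] and tok_1 in logprobs_0[step]
    return True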

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 2 additions & 1 deletion
@@ -34,7 +34,8 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]
 
     # Need to figure it out
-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         return 60
 
     def get_name(self) -> str:
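
Making get_min_capability a @classmethod matters because callers want the minimum CUDA compute capability from the config class before any instance exists. The returned 60 encodes compute capability 6.0 as major * 10 + minor. A hedged usage sketch, assuming the config class in this file is CompressedTensorsConfig (the gating helper below is ours, not vLLM's actual code):

import torch

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig)

def device_supports_compressed_tensors() -> bool:
    # Encode e.g. (7, 5) as 75, matching get_min_capability's encoding,
    # and call the classmethod directly on the class, no instance needed.
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= CompressedTensorsConfig.get_min_capability()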
