From 7d508d8b4d60f37cdef4cbb53bc79279ee2014a1 Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Sat, 29 Jun 2024 09:12:58 -0400
Subject: [PATCH] [ CI/Build ] Added E2E Test For Compressed Tensors (#5839)

Co-authored-by: Michael Goin
Co-authored-by: Robert Shaw
---
 requirements-test.txt                        |  2 +
 tests/conftest.py                            |  4 ++
 tests/models/test_compressed_tensors.py      | 49 +++++++++++++++++++
 .../compressed_tensors/compressed_tensors.py |  3 +-
 4 files changed, 57 insertions(+), 1 deletion(-)
 create mode 100644 tests/models/test_compressed_tensors.py

diff --git a/requirements-test.txt b/requirements-test.txt
index 3ebfc16547e44..a7604d2e1015e 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -14,6 +14,8 @@ peft
 requests
 ray
 sentence-transformers # required for embedding
+sparseml==1.8.0 # required for compressed-tensors
+compressed-tensors==0.4.0 # required for compressed-tensors
 
 # Benchmarking
 aiohttp
diff --git a/tests/conftest.py b/tests/conftest.py
index 9d00c76766943..b429d8d0b5600 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -176,6 +176,7 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
         is_embedding_model: bool = False,
         is_vision_model: bool = False,
+        is_sparseml_model: bool = False,
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -193,6 +194,9 @@ def __init__(
         else:
             if is_vision_model:
                 auto_cls = AutoModelForVision2Seq
+            elif is_sparseml_model:
+                from sparseml.transformers import SparseAutoModelForCausalLM
+                auto_cls = SparseAutoModelForCausalLM
             else:
                 auto_cls = AutoModelForCausalLM
 
diff --git a/tests/models/test_compressed_tensors.py b/tests/models/test_compressed_tensors.py
new file mode 100644
index 0000000000000..9a0054c5aff53
--- /dev/null
+++ b/tests/models/test_compressed_tensors.py
@@ -0,0 +1,49 @@
+"""Compares vllm vs sparseml for compressed-tensors
+
+Note: vllm and sparseml do not have bitwise correctness,
+so in this test we just confirm that the top selected
+tokens of each are in the top 5 selections of the other.
+"""
+
+import pytest
+
+from tests.quantization.utils import is_quant_method_supported
+
+from .utils import check_logprobs_close
+
+MODELS = [
+    "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test",
+]
+
+MAX_TOKENS = 32
+NUM_LOGPROBS = 5
+
+
+@pytest.mark.skipif(
+    not is_quant_method_supported("compressed-tensors"),
+    reason="compressed-tensors is not supported on this machine type.")
+@pytest.mark.parametrize("model_name", MODELS)
+def test_models(
+    vllm_runner,
+    hf_runner,
+    example_prompts,
+    model_name,
+) -> None:
+    # Run sparseml.
+    with hf_runner(model_name=model_name,
+                   is_sparseml_model=True) as sparseml_model:
+
+        sparseml_outputs = sparseml_model.generate_greedy_logprobs_limit(
+            example_prompts, MAX_TOKENS, NUM_LOGPROBS)
+
+    # Run vllm.
+    with vllm_runner(model_name=model_name) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, MAX_TOKENS, NUM_LOGPROBS)
+
+    check_logprobs_close(
+        outputs_0_lst=sparseml_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="sparseml",
+        name_1="vllm",
+    )
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index c69e2f3bcf9fa..0cf224cc05479 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -34,7 +34,8 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]
 
     # Need to figure it out
-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         return 60
 
     def get_name(self) -> str:
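For context on the comparison this test relies on: since the two runtimes are not bitwise identical, check_logprobs_close tolerates mismatched greedy picks as long as each run's chosen token still appears among the other run's top-NUM_LOGPROBS candidates at the same step. Below is a minimal sketch of that cross-check; the output tuple layout and the helper name are assumptions for illustration, not vllm's actual implementation.

from typing import Dict, List, Sequence, Tuple

# Simplified stand-in for a runner's per-sequence output: the generated
# token ids plus, for every generation step, a dict mapping each top-k
# candidate token id to its logprob. This shape is assumed, not vllm's.
StepTopK = Dict[int, float]
SeqOutput = Tuple[List[int], List[StepTopK]]


def assert_top_tokens_close(outputs_0: Sequence[SeqOutput],
                            outputs_1: Sequence[SeqOutput],
                            name_0: str, name_1: str) -> None:
    # Walk both runs in lockstep, sequence by sequence, step by step.
    for seq_idx, ((ids_0, top_0), (ids_1, top_1)) in enumerate(
            zip(outputs_0, outputs_1)):
        for step, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
            if tok_0 == tok_1:
                continue  # identical greedy picks trivially agree
            # A mismatch is tolerated only when each run's pick is still
            # among the other run's top-k candidates for that step.
            assert tok_0 in top_1[step] and tok_1 in top_0[step], (
                f"seq {seq_idx}, step {step}: {name_0} picked {tok_0}, "
                f"{name_1} picked {tok_1}, outside each other's top-k")


# Toy usage: step 0 agrees; step 1 differs but stays within both top-k sets,
# so the check passes even without bitwise-identical outputs.
run_a = [([5, 7], [{5: -0.1, 9: -2.0}, {7: -0.3, 8: -0.4}])]
run_b = [([5, 8], [{5: -0.2, 9: -1.9}, {8: -0.2, 7: -0.5}])]
assert_top_tokens_close(run_a, run_b, "sparseml", "vllm")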