4 files changed, +57 -1 lines

Requirements (test dependencies):

 requests
 ray
 sentence-transformers # required for embedding
+sparseml==1.8.0 # required for compressed-tensors
+compressed-tensors==0.4.0 # required for compressed-tensors

 # Benchmarking
 aiohttp
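Both packages are pinned exactly. If drift between the pins and the installed environment is a concern, a guard like the following can fail fast before the tests run (illustrative sketch, not part of this change):

    # Illustrative guard, not part of this change: fail fast when the
    # installed versions do not match the pins above.
    from importlib.metadata import version

    for pkg, pinned in (("sparseml", "1.8.0"), ("compressed-tensors", "0.4.0")):
        installed = version(pkg)
        assert installed == pinned, f"{pkg}: expected {pinned}, got {installed}"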
hf_runner test harness, __init__:

@@ -176,6 +176,7 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
         is_embedding_model: bool = False,
         is_vision_model: bool = False,
+        is_sparseml_model: bool = False,
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -193,6 +194,9 @@ def __init__(
         else:
             if is_vision_model:
                 auto_cls = AutoModelForVision2Seq
+            elif is_sparseml_model:
+                from sparseml.transformers import SparseAutoModelForCausalLM
+                auto_cls = SparseAutoModelForCausalLM
             else:
                 auto_cls = AutoModelForCausalLM
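With the new flag, the harness swaps AutoModelForCausalLM for sparseml's drop-in class. Outside the harness, the same load can be reproduced directly; a minimal sketch, assuming sparseml 1.8.0 is installed (the checkpoint name is taken from the test below):

    # Minimal sketch: load a compressed-tensors checkpoint the same way
    # hf_runner does when is_sparseml_model=True.
    from sparseml.transformers import SparseAutoModelForCausalLM

    model = SparseAutoModelForCausalLM.from_pretrained(
        "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test")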
New test file (49 lines, all added):

"""Compares vllm vs sparseml for compressed-tensors

Note: vllm and sparseml are not bitwise identical, so this test
only confirms that the greedy token selected by each appears in
the top-5 candidates of the other.
"""

import pytest

from tests.quantization.utils import is_quant_method_supported

from .utils import check_logprobs_close

MODELS = [
    "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test",
]

MAX_TOKENS = 32
NUM_LOGPROBS = 5


@pytest.mark.skipif(
    not is_quant_method_supported("compressed-tensors"),
    reason="compressed-tensors is not supported on this machine type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(
    vllm_runner,
    hf_runner,
    example_prompts,
    model_name,
) -> None:
    # Run sparseml.
    with hf_runner(model_name=model_name,
                   is_sparseml_model=True) as sparseml_model:
        sparseml_outputs = sparseml_model.generate_greedy_logprobs_limit(
            example_prompts, MAX_TOKENS, NUM_LOGPROBS)

    # Run vllm.
    with vllm_runner(model_name=model_name) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, MAX_TOKENS, NUM_LOGPROBS)

    check_logprobs_close(
        outputs_0_lst=sparseml_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="sparseml",
        name_1="vllm",
    )
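The docstring's acceptance criterion (greedy tokens need only appear in each other's top-5) maps onto check_logprobs_close roughly as below; this is a simplified sketch of the idea, not vllm's actual implementation:

    # Simplified sketch (assumption, not vllm's code): positions where the
    # greedy tokens differ still pass as long as each token appears in the
    # other run's top-NUM_LOGPROBS candidates.
    def logprobs_close(tokens_0, logprobs_0, tokens_1, logprobs_1):
        for t0, t1, lp0, lp1 in zip(tokens_0, tokens_1, logprobs_0, logprobs_1):
            if t0 != t1 and (t0 not in lp1 or t1 not in lp0):
                return False
        return True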
vllm/model_executor/layers/quantization/compressed_tensors:

@@ -34,7 +34,8 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]

     # Need to figure it out
-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         return 60

     def get_name(self) -> str:
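Making get_min_capability a classmethod means the capability floor can be queried before any config instance exists, which is typically how quantization methods are screened at startup. An illustrative sketch (class name assumed from the file path):

    # Illustrative sketch: with @classmethod, the minimum CUDA compute
    # capability can be read off the class itself, no instance required.
    class CompressedTensorsConfig:
        @classmethod
        def get_min_capability(cls) -> int:
            return 60  # compute capability 6.0 (Pascal) or newer

    assert CompressedTensorsConfig.get_min_capability() == 60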