11
11
import numpy as np
12
12
from safetensors .torch import load_file , save_file , safe_open
13
13
import torch
14
- from transformers import PretrainedConfig
15
14
from tqdm .auto import tqdm
16
15
16
+ from vllm .config import ModelConfig
17
17
from vllm .logger import init_logger
18
18
from vllm .model_executor .layers .quantization import (get_quantization_config ,
19
19
QuantizationConfig )
@@ -102,25 +102,22 @@ def get_sparse_config(
102
102
103
103
104
104
# TODO(woosuk): Move this to another place.
105
- def get_quant_config (
106
- quantization : str ,
107
- model_name_or_path : str ,
108
- hf_config : PretrainedConfig ,
109
- cache_dir : Optional [str ] = None ,
110
- ) -> QuantizationConfig :
111
- quant_cls = get_quantization_config (quantization )
105
+ def get_quant_config (model_config : ModelConfig ) -> QuantizationConfig :
106
+ quant_cls = get_quantization_config (model_config .quantization )
112
107
# Read the quantization config from the HF model config, if available.
113
- hf_quant_config = getattr (hf_config , "quantization_config" , None )
108
+ hf_quant_config = getattr (model_config .hf_config , "quantization_config" ,
109
+ None )
114
110
if hf_quant_config is not None :
115
111
return quant_cls .from_config (hf_quant_config )
116
-
112
+ model_name_or_path = model_config . model
117
113
is_local = os .path .isdir (model_name_or_path )
118
114
if not is_local :
119
115
# Download the config files.
120
- with get_lock (model_name_or_path , cache_dir ):
116
+ with get_lock (model_name_or_path , model_config . download_dir ):
121
117
hf_folder = snapshot_download (model_name_or_path ,
118
+ revision = model_config .revision ,
122
119
allow_patterns = "*.json" ,
123
- cache_dir = cache_dir ,
120
+ cache_dir = model_config . download_dir ,
124
121
tqdm_class = Disabledtqdm )
125
122
else :
126
123
hf_folder = model_name_or_path
@@ -131,10 +128,12 @@ def get_quant_config(
131
128
f .endswith (x ) for x in quant_cls .get_config_filenames ())
132
129
]
133
130
if len (quant_config_files ) == 0 :
134
- raise ValueError (f"Cannot find the config file for { quantization } " )
131
+ raise ValueError (
132
+ f"Cannot find the config file for { model_config .quantization } " )
135
133
if len (quant_config_files ) > 1 :
136
- raise ValueError (f"Found multiple config files for { quantization } : "
137
- f"{ quant_config_files } " )
134
+ raise ValueError (
135
+ f"Found multiple config files for { model_config .quantization } : "
136
+ f"{ quant_config_files } " )
138
137
139
138
quant_config_file = quant_config_files [0 ]
140
139
with open (quant_config_file , "r" ) as f :
0 commit comments