11
11
import numpy as np
12
12
from safetensors .torch import load_file , save_file , safe_open
13
13
import torch
14
- from transformers import PretrainedConfig
15
14
from tqdm .auto import tqdm
16
15
16
+ from vllm .config import ModelConfig
17
17
from vllm .logger import init_logger
18
18
from vllm .model_executor .layers .quantization import (get_quantization_config ,
19
19
QuantizationConfig )
@@ -83,25 +83,22 @@ def convert_bin_to_safetensor_file(
83
83
84
84
85
85
# TODO(woosuk): Move this to other place.
86
- def get_quant_config (
87
- quantization : str ,
88
- model_name_or_path : str ,
89
- hf_config : PretrainedConfig ,
90
- cache_dir : Optional [str ] = None ,
91
- ) -> QuantizationConfig :
92
- quant_cls = get_quantization_config (quantization )
86
+ def get_quant_config (model_config : ModelConfig ) -> QuantizationConfig :
87
+ quant_cls = get_quantization_config (model_config .quantization )
93
88
# Read the quantization config from the HF model config, if available.
94
- hf_quant_config = getattr (hf_config , "quantization_config" , None )
89
+ hf_quant_config = getattr (model_config .hf_config , "quantization_config" ,
90
+ None )
95
91
if hf_quant_config is not None :
96
92
return quant_cls .from_config (hf_quant_config )
97
-
93
+ model_name_or_path = model_config . model
98
94
is_local = os .path .isdir (model_name_or_path )
99
95
if not is_local :
100
96
# Download the config files.
101
- with get_lock (model_name_or_path , cache_dir ):
97
+ with get_lock (model_name_or_path , model_config . download_dir ):
102
98
hf_folder = snapshot_download (model_name_or_path ,
99
+ revision = model_config .revision ,
103
100
allow_patterns = "*.json" ,
104
- cache_dir = cache_dir ,
101
+ cache_dir = model_config . download_dir ,
105
102
tqdm_class = Disabledtqdm )
106
103
else :
107
104
hf_folder = model_name_or_path
@@ -112,10 +109,12 @@ def get_quant_config(
112
109
f .endswith (x ) for x in quant_cls .get_config_filenames ())
113
110
]
114
111
if len (quant_config_files ) == 0 :
115
- raise ValueError (f"Cannot find the config file for { quantization } " )
112
+ raise ValueError (
113
+ f"Cannot find the config file for { model_config .quantization } " )
116
114
if len (quant_config_files ) > 1 :
117
- raise ValueError (f"Found multiple config files for { quantization } : "
118
- f"{ quant_config_files } " )
115
+ raise ValueError (
116
+ f"Found multiple config files for { model_config .quantization } : "
117
+ f"{ quant_config_files } " )
119
118
120
119
quant_config_file = quant_config_files [0 ]
121
120
with open (quant_config_file , "r" ) as f :
0 commit comments