Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 3b1644e

Browse files
Pernekhan authored and alexm-redhat committed
Use revision when downloading the quantization config file (vllm-project#2697)
Co-authored-by: Pernekhan Utemuratov <pernekhan@deepinfra.com>
1 parent b8d26a8 commit 3b1644e

File tree

2 files changed

+15
-19
lines changed

2 files changed

+15
-19
lines changed

vllm/model_executor/model_loader.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,7 @@ def get_model(model_config: ModelConfig,
4545
# Get the (maybe sparse or quantized) linear method.
4646
linear_method = None
4747
if model_config.quantization is not None:
48-
quant_config = get_quant_config(model_config.quantization,
49-
model_config.model,
50-
model_config.hf_config,
51-
model_config.download_dir)
48+
quant_config = get_quant_config(model_config)
5249
capability = torch.cuda.get_device_capability()
5350
capability = capability[0] * 10 + capability[1]
5451
if capability < quant_config.get_min_capability():

vllm/model_executor/weight_utils.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
import numpy as np
1212
from safetensors.torch import load_file, save_file, safe_open
1313
import torch
14-
from transformers import PretrainedConfig
1514
from tqdm.auto import tqdm
1615

16+
from vllm.config import ModelConfig
1717
from vllm.logger import init_logger
1818
from vllm.model_executor.layers.quantization import (get_quantization_config,
1919
QuantizationConfig)
@@ -102,25 +102,22 @@ def get_sparse_config(
102102

103103

104104
# TODO(woosuk): Move this to other place.
105-
def get_quant_config(
106-
quantization: str,
107-
model_name_or_path: str,
108-
hf_config: PretrainedConfig,
109-
cache_dir: Optional[str] = None,
110-
) -> QuantizationConfig:
111-
quant_cls = get_quantization_config(quantization)
105+
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
106+
quant_cls = get_quantization_config(model_config.quantization)
112107
# Read the quantization config from the HF model config, if available.
113-
hf_quant_config = getattr(hf_config, "quantization_config", None)
108+
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
109+
None)
114110
if hf_quant_config is not None:
115111
return quant_cls.from_config(hf_quant_config)
116-
112+
model_name_or_path = model_config.model
117113
is_local = os.path.isdir(model_name_or_path)
118114
if not is_local:
119115
# Download the config files.
120-
with get_lock(model_name_or_path, cache_dir):
116+
with get_lock(model_name_or_path, model_config.download_dir):
121117
hf_folder = snapshot_download(model_name_or_path,
118+
revision=model_config.revision,
122119
allow_patterns="*.json",
123-
cache_dir=cache_dir,
120+
cache_dir=model_config.download_dir,
124121
tqdm_class=Disabledtqdm)
125122
else:
126123
hf_folder = model_name_or_path
@@ -131,10 +128,12 @@ def get_quant_config(
131128
f.endswith(x) for x in quant_cls.get_config_filenames())
132129
]
133130
if len(quant_config_files) == 0:
134-
raise ValueError(f"Cannot find the config file for {quantization}")
131+
raise ValueError(
132+
f"Cannot find the config file for {model_config.quantization}")
135133
if len(quant_config_files) > 1:
136-
raise ValueError(f"Found multiple config files for {quantization}: "
137-
f"{quant_config_files}")
134+
raise ValueError(
135+
f"Found multiple config files for {model_config.quantization}: "
136+
f"{quant_config_files}")
138137

139138
quant_config_file = quant_config_files[0]
140139
with open(quant_config_file, "r") as f:

0 commit comments

Comments (0)