From 3eda4ec780d572120ed02e9e94bcef383c3e0399 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 22 Jul 2024 23:59:42 -0700 Subject: [PATCH] support ignore patterns in model loader (#6673) --- vllm/config.py | 15 +++++++++- vllm/engine/arg_utils.py | 10 +++++++ vllm/model_executor/model_loader/loader.py | 29 ++++++++++++++----- .../model_loader/weight_utils.py | 7 ++++- 4 files changed, 51 insertions(+), 10 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9d60f07579217..6f0fdf8bc67e9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -599,12 +599,16 @@ class LoadConfig: mainly for profiling. "tensorizer" will use CoreWeave's tensorizer library for fast weight loading. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO download_dir: Optional[str] = None model_loader_extra_config: Optional[Union[str, dict]] = field( default_factory=dict) + ignore_patterns: Optional[Union[List[str], str]] = None def __post_init__(self): model_loader_extra_config = self.model_loader_extra_config or {} @@ -613,6 +617,13 @@ def __post_init__(self): model_loader_extra_config) self._verify_load_format() + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + def _verify_load_format(self) -> None: if not isinstance(self.load_format, str): return @@ -801,7 +812,9 @@ def __init__(self, # for higher throughput. self.max_num_batched_tokens = max(max_model_len, 2048) if enable_chunked_prefill: - logger.info("Chunked prefill is enabled (EXPERIMENTAL).") + logger.info( + "Chunked prefill is enabled with max_num_batched_tokens=%d.", + max_num_batched_tokens) self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c34b88b53f656..05bfe7c24f978 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -95,6 +95,7 @@ class EngineArgs: num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 model_loader_extra_config: Optional[dict] = None + ignore_patterns: Optional[Union[str, List[str]]] = None preemption_mode: Optional[str] = None scheduler_delay_factor: float = 0.0 @@ -619,6 +620,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'corresponding to the chosen load_format. ' 'This should be a JSON string that will be ' 'parsed into a dictionary.') + parser.add_argument( + '--ignore-patterns', + action="append", + type=str, + default=[], + help="The pattern(s) to ignore when loading the model." + "Default to 'original/**/*' to avoid repeated loading of llama's " + "checkpoints.") parser.add_argument( '--preemption-mode', type=str, @@ -824,6 +833,7 @@ def create_engine_config(self, ) -> EngineConfig: load_format=self.load_format, download_dir=self.download_dir, model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, ) prompt_adapter_config = PromptAdapterConfig( diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index a1a2b0b323f67..88f16918b0119 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -161,6 +161,7 @@ def _maybe_download_from_modelscope( cache_dir=self.load_config.download_dir, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, revision=revision, + ignore_patterns=self.load_config.ignore_patterns, ) else: model_path = model @@ -196,9 +197,13 @@ def _prepare_weights(self, model_name_or_path: str, allow_patterns += ["*.pt"] if not is_local: - hf_folder = download_weights_from_hf(model_name_or_path, - self.load_config.download_dir, - allow_patterns, revision) + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) else: hf_folder = model_name_or_path @@ -489,9 +494,13 @@ def _prepare_weights(self, model_name_or_path: str, return model_name_or_path else: allow_patterns = ["*.safetensors"] - return download_weights_from_hf(model_name_or_path, - self.load_config.download_dir, - allow_patterns, revision) + return download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, @@ -663,8 +672,12 @@ def _get_weight_files( matching_files = fnmatch.filter(repo_files, pattern) if matching_files: hf_folder = download_weights_from_hf( - model_name_or_path, self.load_config.download_dir, - [pattern], revision) + model_name_or_path, + self.load_config.download_dir, + [pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) return glob.glob(os.path.join(hf_folder, pattern)), pattern raise RuntimeError( diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index ee3b2530880d1..dbba6ea358346 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -6,7 +6,7 @@ import os import tempfile from collections import defaultdict -from typing import Any, Generator, Iterable, List, Optional, Tuple +from typing import Any, Generator, Iterable, List, Optional, Tuple, Union import filelock import huggingface_hub.constants @@ -189,6 +189,7 @@ def download_weights_from_hf( cache_dir: Optional[str], allow_patterns: List[str], revision: Optional[str] = None, + ignore_patterns: Optional[Union[str, List[str]]] = None, ) -> str: """Download model weights from Hugging Face Hub. @@ -200,6 +201,9 @@ def download_weights_from_hf( weight files. Files matched by any of the patterns will be downloaded. revision (Optional[str]): The revision of the model. + ignore_patterns (Optional[Union[str, List[str]]]): The patterns to + filter out the weight files. Files matched by any of the patterns + will be ignored. Returns: str: The path to the downloaded model weights. @@ -223,6 +227,7 @@ def download_weights_from_hf( hf_folder = snapshot_download( model_name_or_path, allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, cache_dir=cache_dir, tqdm_class=DisabledTqdm, revision=revision,