Skip to content

Commit

Permalink
[Bugfix] Fix Mistral v0.3 Weight Loading (vllm-project#5005)
Browse files Browse the repository at this point in the history
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
  • Loading branch information
2 people authored and dtrifiro committed May 31, 2024
1 parent 05968c9 commit e1bf3c3
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 3 deletions.
1 change: 1 addition & 0 deletions tests/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.3",
]


Expand Down
17 changes: 15 additions & 2 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from vllm.model_executor.model_loader.utils import (get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, filter_files_not_needed_for_inference,
download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
Expand Down Expand Up @@ -188,7 +189,19 @@ def _prepare_weights(self, model_name_or_path: str,
use_safetensors = True
break

if not use_safetensors:
if use_safetensors:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path, self.load_config.download_dir,
revision)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder)
else:
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)

Expand Down
64 changes: 63 additions & 1 deletion vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import HfFileSystem, snapshot_download
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger
Expand Down Expand Up @@ -218,6 +219,67 @@ def download_weights_from_hf(
return hf_folder


def download_safetensors_index_file_from_hf(
model_name_or_path: str,
cache_dir: Optional[str],
revision: Optional[str] = None,
) -> None:
"""Download hf safetensors index file from Hugging Face Hub.
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
revision (Optional[str]): The revision of the model.
"""
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
try:
# Download the safetensors index file.
hf_hub_download(
repo_id=model_name_or_path,
filename=SAFE_WEIGHTS_INDEX_NAME,
cache_dir=cache_dir,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
# If file not found on remote or locally, we should not fail since
# only some models will have SAFE_WEIGHTS_INDEX_NAME.
except huggingface_hub.utils.EntryNotFoundError:
logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
except huggingface_hub.utils.LocalEntryNotFoundError:
logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)


# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the SAFE_WEIGHTS_INDEX_NAME to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(hf_weights_files: List[str],
hf_folder: str) -> List[str]:
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
if not os.path.isfile(index_file_name):
return hf_weights_files

# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
with open(index_file_name) as index_file:
weight_map = json.load(index_file)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(
os.path.join(hf_folder, weight_map[weight_name]))
# Filter out any fields that are not found in the index file.
hf_weights_files = [
f for f in hf_weights_files if f in weight_files_in_index
]
return hf_weights_files


def filter_files_not_needed_for_inference(
hf_weights_files: List[str]) -> List[str]:
"""
Expand Down

0 comments on commit e1bf3c3

Please sign in to comment.