From 69e1d2fb6922b2d388bae41286d8867976cbd6c6 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 16 Apr 2024 11:34:39 -0700 Subject: [PATCH] [Core] Refactor model loading code (#4097) --- .buildkite/test-pipeline.yaml | 2 +- examples/fp8/extract_scales.py | 4 +- examples/tensorize_vllm_model.py | 2 +- tests/lora/conftest.py | 10 +- tests/lora/test_worker.py | 10 +- tests/model_executor/weight_utils.py | 2 +- .../test_autogptq_marlin_configs.py | 4 - tests/samplers/test_sampler.py | 14 +- tests/spec_decode/utils.py | 1 + .../__init__.py | 0 .../tensorize_vllm_model_for_testing.py | 4 +- .../test_tensorizer.py | 139 ++++--- tests/test_config.py | 4 - tests/test_logits_processor.py | 7 +- tests/worker/test_model_runner.py | 39 +- tests/worker/test_swap.py | 1 + vllm/config.py | 201 ++++------ vllm/engine/arg_utils.py | 59 ++- vllm/engine/llm_engine.py | 19 +- vllm/executor/cpu_executor.py | 1 + vllm/executor/executor_base.py | 10 +- vllm/executor/gpu_executor.py | 2 +- vllm/executor/ray_gpu_executor.py | 5 +- vllm/model_executor/model_loader.py | 128 ------- vllm/model_executor/model_loader/__init__.py | 30 ++ vllm/model_executor/model_loader/loader.py | 354 ++++++++++++++++++ .../neuron.py} | 0 .../tensorizer.py} | 116 ++++-- vllm/model_executor/model_loader/utils.py | 40 ++ .../{ => model_loader}/weight_utils.py | 295 +++++++-------- vllm/model_executor/models/baichuan.py | 14 +- vllm/model_executor/models/bloom.py | 14 +- vllm/model_executor/models/chatglm.py | 14 +- vllm/model_executor/models/commandr.py | 16 +- vllm/model_executor/models/dbrx.py | 16 +- vllm/model_executor/models/decilm.py | 14 +- vllm/model_executor/models/deepseek.py | 20 +- vllm/model_executor/models/falcon.py | 14 +- vllm/model_executor/models/gemma.py | 14 +- vllm/model_executor/models/gpt2.py | 14 +- vllm/model_executor/models/gpt_bigcode.py | 14 +- vllm/model_executor/models/gpt_j.py | 14 +- vllm/model_executor/models/gpt_neox.py | 14 +- vllm/model_executor/models/internlm2.py | 14 +- vllm/model_executor/models/jais.py | 16 +- vllm/model_executor/models/llama.py | 16 +- vllm/model_executor/models/llava.py | 14 +- vllm/model_executor/models/minicpm.py | 14 +- vllm/model_executor/models/mixtral.py | 20 +- vllm/model_executor/models/mixtral_quant.py | 19 +- vllm/model_executor/models/mpt.py | 14 +- vllm/model_executor/models/olmo.py | 16 +- vllm/model_executor/models/opt.py | 14 +- vllm/model_executor/models/orion.py | 14 +- vllm/model_executor/models/phi.py | 14 +- vllm/model_executor/models/qwen.py | 14 +- vllm/model_executor/models/qwen2.py | 14 +- vllm/model_executor/models/qwen2_moe.py | 20 +- vllm/model_executor/models/stablelm.py | 14 +- vllm/model_executor/models/starcoder2.py | 14 +- vllm/model_executor/models/xverse.py | 14 +- vllm/transformers_utils/tokenizer.py | 21 +- vllm/worker/cpu_model_runner.py | 12 +- vllm/worker/cpu_worker.py | 7 +- vllm/worker/model_runner.py | 15 +- vllm/worker/neuron_model_runner.py | 2 +- vllm/worker/worker.py | 10 +- 67 files changed, 1064 insertions(+), 973 deletions(-) rename tests/{tensorizer => tensorizer_loader}/__init__.py (100%) rename tests/{tensorizer => tensorizer_loader}/tensorize_vllm_model_for_testing.py (98%) rename tests/{tensorizer => tensorizer_loader}/test_tensorizer.py (67%) delete mode 100644 vllm/model_executor/model_loader.py create mode 100644 vllm/model_executor/model_loader/__init__.py create mode 100644 vllm/model_executor/model_loader/loader.py rename vllm/model_executor/{neuron_model_loader.py => model_loader/neuron.py} (100%) rename 
vllm/model_executor/{tensorizer_loader.py => model_loader/tensorizer.py} (78%) create mode 100644 vllm/model_executor/model_loader/utils.py rename vllm/model_executor/{ => model_loader}/weight_utils.py (53%) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index aa4582bbda0c7..f39c3302ac2e9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -92,7 +92,7 @@ steps: parallelism: 4 - label: Tensorizer Test - command: apt-get install curl libsodium23 && pytest -v -s tensorizer + command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader - label: Metrics Test command: pytest -v -s metrics diff --git a/examples/fp8/extract_scales.py b/examples/fp8/extract_scales.py index 5e5b31265e3af..1eb961a5a76e3 100644 --- a/examples/fp8/extract_scales.py +++ b/examples/fp8/extract_scales.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.schema import QuantParamSchema -# Adapted from vllm/model_executor/weight_utils.py +# Adapted from vllm/model_executor/model_loader/weight_utils.py # The main differences are that we add the NPZ format and simplify # its functionality drastically for our purposes (e.g. we assume that # the quantized model exists locally and there is no need to download it) @@ -71,7 +71,7 @@ def _prepare_hf_weights( return hf_weights_files, use_safetensors -# Adapted from vllm/model_executor/weight_utils.py +# Adapted from vllm/model_executor/model_loader/weight_utils.py def _hf_tensorfile_iterator(filename: str, load_format: str, use_safetensors: bool): if load_format == "npz": diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py index 8cf8be09d0b9c..e2456168de9d5 100644 --- a/examples/tensorize_vllm_model.py +++ b/examples/tensorize_vllm_model.py @@ -16,8 +16,8 @@ from vllm.distributed import initialize_model_parallel from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.model_loader.tensorizer import TensorizerArgs from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.tensorizer_loader import TensorizerArgs # yapf conflicts with isort for this docstring # yapf: disable diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 1127cc33183c9..2dabfb6b4337c 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -153,11 +153,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() get_model_old = get_model - def get_model_patched(model_config, device_config, **kwargs): - return get_model_old(model_config, - device_config, - lora_config=LoRAConfig(max_loras=4, - max_lora_rank=8)) + def get_model_patched(*, model_config, device_config, **kwargs): + kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8) + return get_model_old(model_config=model_config, + device_config=device_config, + **kwargs) with patch("vllm.worker.model_runner.get_model", get_model_patched): engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 54594690f7922..732e91a52c0a9 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -3,8 +3,8 @@ import tempfile from unittest.mock import patch -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig) from vllm.lora.models import LoRAMapping from 
vllm.lora.request import LoRARequest from vllm.worker.worker import Worker @@ -18,12 +18,14 @@ def test_worker_apply_lora(sql_lora_files): "meta-llama/Llama-2-7b-hf", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, ), + load_config=LoadConfig( + download_dir=None, + load_format="dummy", + ), parallel_config=ParallelConfig(1, 1, False), scheduler_config=SchedulerConfig(32, 32, 32), device_config=DeviceConfig("cuda"), diff --git a/tests/model_executor/weight_utils.py b/tests/model_executor/weight_utils.py index 3154f2826d10c..b0086dd7a7d71 100644 --- a/tests/model_executor/weight_utils.py +++ b/tests/model_executor/weight_utils.py @@ -3,7 +3,7 @@ import huggingface_hub.constants import pytest -from vllm.model_executor.weight_utils import enable_hf_transfer +from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer def test_hf_transfer_auto_activation(): diff --git a/tests/quantization/test_autogptq_marlin_configs.py b/tests/quantization/test_autogptq_marlin_configs.py index cd64622e2226f..1310b4da218b5 100644 --- a/tests/quantization/test_autogptq_marlin_configs.py +++ b/tests/quantization/test_autogptq_marlin_configs.py @@ -36,8 +36,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None: model_path, tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, @@ -49,8 +47,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None: model_path, tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 26e2d29ffd04c..dbbe13b8da060 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -32,7 +32,12 @@ def _prepare_test( 1e-2, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(None, None, None, None, None) + model_runner = ModelRunner(model_config=None, + parallel_config=None, + scheduler_config=None, + device_config=None, + load_config=None, + lora_config=None) return input_tensor, fake_logits, sampler, model_runner @@ -591,7 +596,12 @@ def test_sampler_top_k_top_p(seed: int, device: str): device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(None, None, None, None, None) + model_runner = ModelRunner(model_config=None, + parallel_config=None, + scheduler_config=None, + device_config=None, + load_config=None, + lora_config=None) generation_model = GenerationMixin() generation_config = GenerationConfig(top_k=top_k, diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 4637826f254d6..edba4c226b289 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -118,6 +118,7 @@ def create_worker(cls: type, scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, cache_config=engine_config.cache_config, + load_config=engine_config.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/tests/tensorizer/__init__.py b/tests/tensorizer_loader/__init__.py similarity index 100% rename from tests/tensorizer/__init__.py rename to tests/tensorizer_loader/__init__.py diff --git a/tests/tensorizer/tensorize_vllm_model_for_testing.py b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py similarity index 98% rename from 
tests/tensorizer/tensorize_vllm_model_for_testing.py rename to tests/tensorizer_loader/tensorize_vllm_model_for_testing.py index d0be08329fd64..e4b15fd57add4 100644 --- a/tests/tensorizer/tensorize_vllm_model_for_testing.py +++ b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py @@ -16,8 +16,8 @@ from vllm.distributed import initialize_model_parallel from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.model_loader.tensorizer import TensorizerArgs from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.tensorizer_loader import TensorizerArgs # yapf conflicts with isort for this docstring # yapf: disable @@ -74,7 +74,7 @@ def parse_args(): "extremely quickly. Tensor encryption and decryption is " "also supported, although libsodium must be installed to " "use it.") - parser = EngineArgs.add_cli_args(parser) + parser = TensorizerArgs.add_cli_args(EngineArgs.add_cli_args(parser)) subparsers = parser.add_subparsers(dest='command') serialize_parser = subparsers.add_parser( diff --git a/tests/tensorizer/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py similarity index 67% rename from tests/tensorizer/test_tensorizer.py rename to tests/tensorizer_loader/test_tensorizer.py index 2ab893e95da9c..a97cc0b3706b4 100644 --- a/tests/tensorizer/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -1,16 +1,19 @@ import gc +import json +import os import subprocess from unittest.mock import MagicMock, patch +import openai import pytest +import ray import torch from tests.entrypoints.test_openai_server import ServerRunner from vllm import SamplingParams -from vllm.config import TensorizerConfig -from vllm.model_executor.tensorizer_loader import ( - EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer, - load_with_tensorizer, open_stream) +from vllm.model_executor.model_loader.tensorizer import ( + EncryptionParams, TensorizerConfig, TensorSerializer, + is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream) prompts = [ "Hello, my name is", @@ -22,6 +25,8 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) model_ref = "facebook/opt-125m" +tensorize_model_for_testing_script = os.path.join( + os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py") def is_curl_installed(): @@ -38,7 +43,7 @@ def tensorizer_config(): return config -@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent') +@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent') def test_load_with_tensorizer(mock_agent, tensorizer_config): mock_linear_method = MagicMock() mock_agent_instance = mock_agent.return_value @@ -81,11 +86,13 @@ def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path): del vllm_model, model gc.collect() torch.cuda.empty_cache() - loaded_vllm_model = vllm_runner(model_ref, - load_format="tensorizer", - tensorizer_uri=model_path, - num_readers=1, - vllm_tensorized=True) + loaded_vllm_model = vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path, + num_readers=1, + vllm_tensorized=True), + ) deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # Assumes SamplingParams being seeded ensures the outputs are deterministic @@ -97,14 +104,14 @@ def test_can_deserialize_s3(vllm_runner): model_ref = "EleutherAI/pythia-1.4b" tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" - loaded_hf_model = vllm_runner( - 
model_ref, - tensorizer_uri=tensorized_path, - load_format="tensorizer", - num_readers=1, - vllm_tensorized=False, - s3_endpoint="object.ord1.coreweave.com", - ) + loaded_hf_model = vllm_runner(model_ref, + load_format="tensorizer", + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=tensorized_path, + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + )) deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) @@ -131,11 +138,12 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( gc.collect() torch.cuda.empty_cache() loaded_vllm_model = vllm_runner(model_ref, - tensorizer_uri=model_path, load_format="tensorizer", - encryption_keyfile=key_path, - num_readers=1, - vllm_tensorized=True) + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=model_path, + encryption_keyfile=key_path, + num_readers=1, + vllm_tensorized=True)) deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) @@ -156,10 +164,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, gc.collect() torch.cuda.empty_cache() loaded_hf_model = vllm_runner(model_ref, - tensorizer_uri=model_path, load_format="tensorizer", - num_readers=1, - vllm_tensorized=False) + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=model_path, + num_readers=1, + vllm_tensorized=False)) deserialized_outputs = loaded_hf_model.generate_greedy( prompts, max_tokens=max_tokens) @@ -190,10 +199,12 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): torch.cuda.empty_cache() loaded_vllm_model = vllm_runner( model_ref, - tensorizer_uri=model_path, load_format="tensorizer", - num_readers=1, - vllm_tensorized=True, + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=model_path, + num_readers=1, + vllm_tensorized=True, + ), enable_lora=True, max_loras=1, max_lora_rank=8, @@ -208,16 +219,18 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): def test_load_without_tensorizer_load_format(vllm_runner): with pytest.raises(ValueError): - vllm_runner(model_ref, tensorizer_uri="test") + vllm_runner(model_ref, + model_loader_extra_config=TensorizerConfig( + tensorizer_uri="test", vllm_tensorized=False)) @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_tensorize_vllm_model(tmp_path): # Test serialize command serialize_args = [ - "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", - model_ref, "--dtype", "float16", "serialize", "--serialized-directory", - tmp_path, "--suffix", "tests" + "python3", tensorize_model_for_testing_script, "--model", model_ref, + "--dtype", "float16", "serialize", "--serialized-directory", tmp_path, + "--suffix", "tests" ] result = subprocess.run(serialize_args, capture_output=True, text=True) print(result.stdout) # Print the output of the serialize command @@ -229,8 +242,8 @@ def test_tensorize_vllm_model(tmp_path): # Test deserialize command deserialize_args = [ - "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", - model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors", + "python3", tensorize_model_for_testing_script, "--model", model_ref, + "--dtype", "float16", "deserialize", "--path-to-tensors", path_to_tensors ] result = subprocess.run(deserialize_args, capture_output=True, text=True) @@ -242,9 +255,9 @@ def test_tensorize_vllm_model(tmp_path): def test_openai_apiserver_with_tensorizer(tmp_path): ## Serialize model serialize_args = [ - "python3", 
"tensorizer/tensorize_vllm_model_for_testing.py", "--model", - model_ref, "--dtype", "float16", "serialize", "--serialized-directory", - tmp_path, "--suffix", "tests" + "python3", tensorize_model_for_testing_script, "--model", model_ref, + "--dtype", "float16", "serialize", "--serialized-directory", tmp_path, + "--suffix", "tests" ] result = subprocess.run(serialize_args, capture_output=True, text=True) print(result.stdout) # Print the output of the serialize command @@ -253,25 +266,47 @@ def test_openai_apiserver_with_tensorizer(tmp_path): f"\n{result.stdout}\n{result.stderr}") path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" + model_loader_extra_config = { + "tensorizer_uri": path_to_tensors, + "vllm_tensorized": True + } ## Start OpenAI API server openai_args = [ "--model", model_ref, "--dtype", "float16", "--load-format", - "tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized", - "--port", "8000" + "tensorizer", "--model-loader-extra-config", + json.dumps(model_loader_extra_config), "--port", "8000" ] server = ServerRunner.remote(openai_args) + assert ray.get(server.ready.remote()) print("Server ready.") - assert server.ready.remote() + + client = openai.OpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + completion = client.completions.create(model=model_ref, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + assert completion.choices[0].finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) def test_raise_value_error_on_invalid_load_format(vllm_runner): with pytest.raises(ValueError): vllm_runner(model_ref, load_format="safetensors", - tensorizer_uri="test") + model_loader_extra_config=TensorizerConfig( + tensorizer_uri="test", vllm_tensorized=False)) def test_tensorizer_with_tp(vllm_runner): @@ -281,22 +316,12 @@ def test_tensorizer_with_tp(vllm_runner): vllm_runner( model_ref, - tensorizer_uri=tensorized_path, load_format="tensorizer", - num_readers=1, - vllm_tensorized=False, - s3_endpoint="object.ord1.coreweave.com", + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=tensorized_path, + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + ), tensor_parallel_size=2, ) - - -@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") -def test_tensorizer_warn_quant(tmp_path): - model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" - serialize_args = [ - "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", - model_ref, "--quantization", "gptq", "--tensorizer-uri", "test", - "serialize", "--serialized-directory", tmp_path, "--suffix", "tests" - ] - result = subprocess.run(serialize_args, capture_output=True, text=True) - assert 'PerformanceWarning' in result.stderr diff --git a/tests/test_config.py b/tests/test_config.py index 13a9f76212679..19db10630bbae 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,8 +11,6 @@ def test_get_sliding_window(): "Qwen/Qwen1.5-7B", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, @@ -30,8 +28,6 @@ def test_get_sliding_window(): "mistralai/Mistral-7B-v0.1", tokenizer_mode="auto", trust_remote_code=False, - 
download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index fe321520114f7..5bb93ca74855b 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -37,7 +37,12 @@ def _prepare_test( 1e-2, dtype=input_tensor.dtype) logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - model_runner = ModelRunner(None, None, None, None, None) + model_runner = ModelRunner(model_config=None, + parallel_config=None, + scheduler_config=None, + device_config=None, + load_config=None, + lora_config=None) return input_tensor, fake_logits, logits_processor, model_runner diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index dcaae4af4a6f8..59bed2ce0dad3 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -12,7 +12,12 @@ def test_prepare_prompt(batch_size): 100000, 100000, enable_chunked_prefill=False) - model_runner = ModelRunner(None, None, scheduler_config, None, None) + model_runner = ModelRunner(model_config=None, + parallel_config=None, + scheduler_config=scheduler_config, + device_config=None, + load_config=None, + lora_config=None) model_runner.set_block_size(16) prompt_lens = [] @@ -118,8 +123,6 @@ def test_prepare_decode_cuda_graph(batch_size): "facebook/opt-125m", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, @@ -129,8 +132,12 @@ def test_prepare_decode_cuda_graph(batch_size): 100000, 100000, enable_chunked_prefill=False) - model_runner = ModelRunner(model_config, None, scheduler_config, None, - None) + model_runner = ModelRunner(model_config=model_config, + parallel_config=None, + scheduler_config=scheduler_config, + device_config=None, + load_config=None, + lora_config=None) model_runner.set_block_size(16) prompt_lens = [] @@ -205,14 +212,17 @@ def test_empty_seq_group(): "facebook/opt-125m", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, enforce_eager=False, ) - model_runner = ModelRunner(model_config, None, None, None, None) + model_runner = ModelRunner(model_config=model_config, + parallel_config=None, + scheduler_config=None, + device_config=None, + load_config=None, + lora_config=None) model_runner.set_block_size(16) seq_group_metadata_list = [] input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( @@ -251,8 +261,6 @@ def mock_get_process_group_ranks(group=None): "facebook/opt-125m", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, @@ -262,11 +270,12 @@ def mock_get_process_group_ranks(group=None): 100000, 100000, enable_chunked_prefill=True) - model_runner = ModelRunner(model_config, - None, - scheduler_config, - None, - None, + model_runner = ModelRunner(model_config=model_config, + parallel_config=None, + scheduler_config=scheduler_config, + device_config=None, + load_config=None, + lora_config=None, is_driver_worker=True) model_runner.set_block_size(16) diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 8edb1cf05c08e..1804cf78d8003 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -23,6 +23,7 @@ def test_swap() -> None: scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, cache_config=engine_config.cache_config, + 
load_config=engine_config.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/vllm/config.py b/vllm/config.py index bf31b03b7c6c4..5a29620e85ac6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,9 +1,7 @@ import enum -import io import json import os -import typing -from dataclasses import dataclass, fields +from dataclasses import dataclass, field, fields from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import torch @@ -18,10 +16,14 @@ if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup - from vllm.model_executor.tensorizer_loader import TensorizerArgs + from vllm.model_executor.model_loader.loader import BaseModelLoader logger = init_logger(__name__) +# If true, will load models from ModelScope instead of Hugging Face Hub. +VLLM_USE_MODELSCOPE = os.environ.get("VLLM_USE_MODELSCOPE", + "False").lower() == "true" + _GB = 1 << 30 @@ -35,18 +37,6 @@ class ModelConfig: available, and "slow" will always use the slow tokenizer. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. - download_dir: Directory to download and load the weights, default to the - default cache directory of huggingface. - load_format: The format of the model weights to load: - "auto" will try to load the weights in the safetensors format and - fall back to the pytorch bin format if safetensors format is - not available. - "pt" will load the weights in the pytorch bin format. - "safetensors" will load the weights in the safetensors format. - "npcache" will load the weights in pytorch format and store - a numpy cache to speed up the loading. - "dummy" will initialize the weights with random values, which is - mainly for profiling. dtype: Data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. @@ -83,8 +73,6 @@ def __init__( tokenizer: str, tokenizer_mode: str, trust_remote_code: bool, - download_dir: Optional[str], - load_format: str, dtype: Union[str, torch.dtype], seed: int, revision: Optional[str] = None, @@ -101,8 +89,6 @@ def __init__( self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode self.trust_remote_code = trust_remote_code - self.download_dir = download_dir - self.load_format = load_format self.seed = seed self.revision = revision self.code_revision = code_revision @@ -113,64 +99,16 @@ def __init__( self.max_context_len_to_capture = max_context_len_to_capture self.max_logprobs = max_logprobs - if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": - # download model from ModelScope hub, - # lazy import so that modelscope is not required for normal use. - # pylint: disable=C. 
- from modelscope.hub.snapshot_download import snapshot_download - - if not os.path.exists(model): - model_path = snapshot_download(model_id=model, - cache_dir=download_dir, - revision=revision) - else: - model_path = model - self.model = model_path - self.download_dir = model_path - self.tokenizer = model_path - self.hf_config = get_config(self.model, trust_remote_code, revision, code_revision) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_text_config, max_model_len) - self._verify_load_format() self._verify_tokenizer_mode() self._verify_quantization() self._verify_cuda_graph() - def _verify_load_format(self) -> None: - load_format = self.load_format.lower() - supported_load_format = [ - "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer" - ] - rocm_not_supported_load_format: List[str] = [] - if load_format not in supported_load_format: - raise ValueError( - f"Unknown load format: {self.load_format}. Must be one of " - "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or " - "'dummy'.") - if is_hip() and load_format in rocm_not_supported_load_format: - rocm_supported_load_format = [ - f for f in supported_load_format - if (f not in rocm_not_supported_load_format) - ] - raise ValueError( - f"load format '{load_format}' is not supported in ROCm. " - f"Supported load format are " - f"{rocm_supported_load_format}") - - # TODO: Remove this check once HF updates the pt weights of Mixtral. - architectures = getattr(self.hf_config, "architectures", []) - # architectures can be None instead of [] - if architectures and "MixtralForCausalLM" in architectures \ - and load_format == "pt": - raise ValueError( - "Currently, the 'pt' format is not supported for Mixtral. " - "Please use the 'safetensors' format instead. ") - self.load_format = load_format - def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() if tokenizer_mode not in ["auto", "slow"]: @@ -471,6 +409,65 @@ def create_config( return tokenizer_pool_config +class LoadFormat(str, enum.Enum): + AUTO = "auto" + PT = "pt" + SAFETENSORS = "safetensors" + NPCACHE = "npcache" + DUMMY = "dummy" + TENSORIZER = "tensorizer" + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. 
+ """ + + load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field( + default_factory=dict) + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads( + model_loader_extra_config) + self._verify_load_format() + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in LoadFormat.__members__ + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}") + + class ParallelConfig: """Configuration for the distributed execution. @@ -699,8 +696,6 @@ def maybe_create_spec_config( tokenizer=target_model_config.tokenizer, tokenizer_mode=target_model_config.tokenizer_mode, trust_remote_code=target_model_config.trust_remote_code, - download_dir=target_model_config.download_dir, - load_format=target_model_config.load_format, dtype=target_model_config.dtype, seed=target_model_config.seed, revision=draft_revision, @@ -887,65 +882,6 @@ def get_image_input_enum_type( f"{[x.name for x in cls.ImageInputType]}.") from e -@dataclass -class TensorizerConfig: - tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, - str, bytes, os.PathLike, int] - vllm_tensorized: bool - verify_hash: Optional[bool] = False - num_readers: Optional[int] = 1 - encryption_keyfile: Optional[str] = None - s3_access_key_id: Optional[str] = None - s3_secret_access_key: Optional[str] = None - s3_endpoint: Optional[str] = None - model_class: Optional[torch.nn.Module] = None - hf_config: Optional[PretrainedConfig] = None - dtype: Union[str, torch.dtype] = None - - def _construct_tensorizer_args(self) -> "TensorizerArgs": - from vllm.model_executor.tensorizer_loader import TensorizerArgs - tensorizer_args = { - "tensorizer_uri": self.tensorizer_uri, - "vllm_tensorized": self.vllm_tensorized, - "verify_hash": self.verify_hash, - "num_readers": self.num_readers, - "encryption_keyfile": self.encryption_keyfile, - "s3_access_key_id": self.s3_access_key_id, - "s3_secret_access_key": self.s3_secret_access_key, - "s3_endpoint": self.s3_endpoint, - } - return TensorizerArgs(**tensorizer_args) - - def verify_with_parallel_config( - self, - parallel_config: "ParallelConfig", - ) -> None: - if (parallel_config.tensor_parallel_size > 1 - and self.tensorizer_uri is not None): - raise ValueError( - "Loading to multiple GPUs is not currently supported with " - "vLLM-serialized models. Please set tensor_parallel_size=1." 
- " or use a non-vLLM-serialized model, such as a " - "serialized Hugging Face `PretrainedModel`.") - - def verify_with_model_config(self, model_config) -> None: - if (model_config.quantization is not None - and self.tensorizer_uri is not None): - from vllm.model_executor.tensorizer_loader import ( - tensorizer_warning) - tensorizer_warning( - "Loading a model using Tensorizer with quantization on vLLM" - " is unstable and may lead to errors.") - - if (model_config.load_format != "tensorizer" - and self.tensorizer_uri is not None): - raise ValueError( - "A tensorizer uri was passed for tensorizer loading, but the " - f"load format was set to {model_config.load_format}. " - "Please set the load format to 'tensorizer' to use " - f"tensorizer args.") - - _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, @@ -1105,11 +1041,11 @@ class EngineConfig: parallel_config: ParallelConfig scheduler_config: SchedulerConfig device_config: DeviceConfig + load_config: LoadConfig lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] decoding_config: Optional[DecodingConfig] - tensorizer_config: Optional[TensorizerConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1117,11 +1053,6 @@ def __post_init__(self): self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.tensorizer_config: - self.tensorizer_config.verify_with_parallel_config( - self.parallel_config) - self.tensorizer_config.verify_with_model_config(self.model_config) - if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3de74b0ac28b9..c61c0cc67d7a2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,15 +1,12 @@ import argparse import dataclasses -import io -import os from dataclasses import dataclass -from typing import BinaryIO, Optional, Union +from typing import Optional from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, - EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, TensorizerConfig, + EngineConfig, LoadConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig, VisionLanguageConfig) -from vllm.model_executor.tensorizer_loader import TensorizerArgs from vllm.utils import str_to_int_tuple @@ -60,17 +57,7 @@ class EngineArgs: ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 - - # Tensorizer configuration parameters - tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, - bytes, os.PathLike, int] = None - vllm_tensorized: bool = False - verify_hash: Optional[bool] = False - num_readers: Optional[int] = 1 - encryption_keyfile: Optional[str] = None - s3_access_key_id: Optional[str] = None - s3_secret_access_key: Optional[str] = None - s3_endpoint: Optional[str] = None + model_loader_extra_config: Optional[dict] = None # Related to Vision-language models such as llava image_input_type: Optional[str] = None @@ -429,7 +416,16 @@ def add_cli_args( default=None, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding') - parser = TensorizerArgs.add_cli_args(parser) + + parser.add_argument('--model-loader-extra-config', 
+ type=str, + default=EngineArgs.model_loader_extra_config, + help='Extra config for model loader. ' + 'This will be passed to the model loader ' + 'corresponding to the chosen load_format. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary.') + return parser @classmethod @@ -444,11 +440,11 @@ def create_engine_config(self, ) -> EngineConfig: device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, - self.trust_remote_code, self.download_dir, self.load_format, - self.dtype, self.seed, self.revision, self.code_revision, - self.tokenizer_revision, self.max_model_len, self.quantization, - self.quantization_param_path, self.enforce_eager, - self.max_context_len_to_capture, self.max_logprobs) + self.trust_remote_code, self.dtype, self.seed, self.revision, + self.code_revision, self.tokenizer_revision, self.max_model_len, + self.quantization, self.quantization_param_path, + self.enforce_eager, self.max_context_len_to_capture, + self.max_logprobs) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, @@ -492,15 +488,10 @@ def create_engine_config(self, ) -> EngineConfig: max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None - tensorizer_config = TensorizerConfig( - tensorizer_uri=self.tensorizer_uri, - vllm_tensorized=self.vllm_tensorized, - verify_hash=self.verify_hash, - num_readers=self.num_readers, - encryption_keyfile=self.encryption_keyfile, - s3_access_key_id=self.s3_access_key_id, - s3_secret_access_key=self.s3_secret_access_key, - s3_endpoint=self.s3_endpoint, + load_config = LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, ) if self.image_input_type: @@ -530,8 +521,8 @@ def create_engine_config(self, ) -> EngineConfig: lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, - decoding_config=decoding_config, - tensorizer_config=tensorizer_config) + load_config=load_config, + decoding_config=decoding_config) @dataclass diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f06c1d18ace4b..563694946d16e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,9 +4,9 @@ from transformers import PreTrainedTokenizer import vllm -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, TensorizerConfig, +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, + LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs @@ -72,11 +72,11 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], decoding_config: Optional[DecodingConfig], - tensorizer_config: Optional[TensorizerConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -92,8 +92,8 @@ def __init__( f"trust_remote_code={model_config.trust_remote_code}, " f"dtype={model_config.dtype}, " 
f"max_seq_len={model_config.max_model_len}, " - f"download_dir={model_config.download_dir!r}, " - f"load_format={model_config.load_format}, " + f"download_dir={load_config.download_dir!r}, " + f"load_format={load_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"disable_custom_all_reduce=" f"{parallel_config.disable_custom_all_reduce}, " @@ -114,8 +114,8 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.speculative_config = speculative_config + self.load_config = load_config self.decoding_config = decoding_config or DecodingConfig() - self.tensorizer_config = tensorizer_config self.log_stats = log_stats self._init_tokenizer() @@ -131,7 +131,7 @@ def __init__( lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, - tensorizer_config=tensorizer_config, + load_config=load_config, ) self._initialize_kv_caches() @@ -271,9 +271,6 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.tensorizer_config: - self.tensorizer_config.verify_with_parallel_config( - self.parallel_config) if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index f562e4e0ae3de..426e2c41d8427 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -40,6 +40,7 @@ def _init_worker(self): scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, + load_config=self.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index bbb6ec80f7b7e..8cc04c5299ca1 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -23,20 +23,20 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], - tensorizer_config: Optional[TensorizerConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config + self.load_config = load_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config self.speculative_config = speculative_config - self.tensorizer_config = tensorizer_config self._init_executor() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index bae509f48025b..3a9537effe6d9 100644 --- a/vllm/executor/gpu_executor.py +++ 
b/vllm/executor/gpu_executor.py @@ -35,12 +35,12 @@ def _init_worker(self): scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, + load_config=self.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - tensorizer_config=self.tensorizer_config, is_driver_worker=True, ) self.driver_worker.init_device() diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 7aca5e36107aa..4065c0868d79a 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -147,6 +147,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", model_config = copy.deepcopy(self.model_config) parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) + load_config = copy.deepcopy(self.load_config) device_config = copy.deepcopy(self.device_config) lora_config = copy.deepcopy(self.lora_config) cache_config = copy.deepcopy(self.cache_config) @@ -165,12 +166,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", scheduler_config=scheduler_config, device_config=device_config, cache_config=cache_config, + load_config=load_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, lora_config=lora_config, vision_language_config=vision_language_config, - tensorizer_config=self.tensorizer_config, )) # Initialize the driver worker with the Worker class. @@ -187,7 +188,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - tensorizer_config=self.tensorizer_config, + load_config=self.load_config, is_driver_worker=True, ) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py deleted file mode 100644 index c70ca48bca70a..0000000000000 --- a/vllm/model_executor/model_loader.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Utilities for selecting and loading models.""" -import contextlib -from typing import Tuple, Type - -import torch -from torch import nn - -from vllm.config import DeviceConfig, ModelConfig -from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.models.llava import LlavaForConditionalGeneration -from vllm.model_executor.tensorizer_loader import ( - ParameterizedLoadFormat, is_vllm_serialized_tensorizer, - load_with_tensorizer) -from vllm.model_executor.weight_utils import (get_quant_config, - initialize_dummy_weights) - -_VISION_MODEL_CLASSES = [ - LlavaForConditionalGeneration, -] - - -@contextlib.contextmanager -def _set_default_torch_dtype(dtype: torch.dtype): - """Sets the default torch dtype to the given dtype.""" - old_dtype = torch.get_default_dtype() - torch.set_default_dtype(dtype) - yield - torch.set_default_dtype(old_dtype) - - -def _get_model_architecture( - model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: - architectures = getattr(model_config.hf_config, "architectures", []) - # Special handling for quantized Mixtral. - # FIXME(woosuk): This is a temporary hack. 
- if (model_config.quantization is not None - and "MixtralForCausalLM" in architectures): - architectures = ["QuantMixtralForCausalLM"] - - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return (model_cls, arch) - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - -def get_architecture_class_name(model_config: ModelConfig) -> str: - return _get_model_architecture(model_config)[1] - - -def get_model(model_config: ModelConfig, device_config: DeviceConfig, - **kwargs) -> nn.Module: - lora_config = kwargs.get("lora_config", None) - vision_language_config = kwargs.get("vision_language_config", None) - tensorizer_config = kwargs.get("tensorizer_config", None) - model_class = _get_model_architecture(model_config)[0] - - # Get the (maybe quantized) linear method. - linear_method = None - if model_config.quantization is not None: - quant_config = get_quant_config(model_config) - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") - supported_dtypes = quant_config.get_supported_act_dtypes() - if model_config.dtype not in supported_dtypes: - raise ValueError( - f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") - - linear_method = quant_config.get_linear_method() - - with _set_default_torch_dtype(model_config.dtype): - # Create a model instance. - # The weights will be initialized as empty tensors. - extra_kwargs = {} - if hasattr(model_class, "supported_lora_modules"): - extra_kwargs["lora_config"] = lora_config - elif lora_config: - raise ValueError( - f"Model {model_class.__name__} does not support LoRA, " - "but LoRA is enabled. Support for this model may " - "be added in the future. If this is important to you, " - "please open an issue on github.") - elif model_class in _VISION_MODEL_CLASSES: - extra_kwargs["vision_language_config"] = vision_language_config - - with torch.device(device_config.device): - if (model_config.load_format == "tensorizer" - and is_vllm_serialized_tensorizer(tensorizer_config)): - extra_kwargs["linear_method"] = linear_method - tensorizer_config.model_class = model_class - tensorizer_config.hf_config = model_config.hf_config - tensorizer_config.dtype = model_config.dtype - model = load_with_tensorizer(tensorizer_config, **extra_kwargs) - return model.eval() - model = model_class(config=model_config.hf_config, - linear_method=linear_method, - **extra_kwargs) - if model_config.load_format == "dummy": - # NOTE(woosuk): For accurate performance evaluation, we assign - # random values to the weights. - initialize_dummy_weights(model) - else: - # Load the weights from the cached or downloaded files. - if model_config.load_format == "tensorizer": - # Provide a dynamic load format for `model.load_weights` - # to retain tensorizer args from CLI. 
- model_config.load_format = ParameterizedLoadFormat( - model_config.load_format) - model_config.load_format.params = ( - tensorizer_config._construct_tensorizer_args()) - - model.load_weights( - model_config.model, - model_config.download_dir, - model_config.load_format, - model_config.revision, - ) - return model.eval() diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py new file mode 100644 index 0000000000000..6f90e49994fb2 --- /dev/null +++ b/vllm/model_executor/model_loader/__init__.py @@ -0,0 +1,30 @@ +from typing import Optional + +from torch import nn + +from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.model_executor.model_loader.loader import (BaseModelLoader, + get_model_loader) +from vllm.model_executor.model_loader.utils import ( + get_architecture_class_name, get_model_architecture) + + +def get_model( + *, model_config: ModelConfig, load_config: LoadConfig, + device_config: DeviceConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: + loader = get_model_loader(load_config) + return loader.load_model(model_config=model_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config) + + +__all__ = [ + "get_model", "get_model_loader", "BaseModelLoader", + "get_architecture_class_name", "get_model_architecture" +] diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py new file mode 100644 index 0000000000000..3b1d125ef8a67 --- /dev/null +++ b/vllm/model_executor/model_loader/loader.py @@ -0,0 +1,354 @@ +# ruff: noqa: SIM117 +import copy +import glob +import os +from abc import ABC, abstractmethod +from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, + Type) + +import torch +from torch import nn + +from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig, + LoadFormat, LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VisionLanguageConfig) +from vllm.logger import init_logger +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer, + tensorizer_weights_iterator) +from vllm.model_executor.model_loader.utils import (get_model_architecture, + set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf, filter_files_not_needed_for_inference, + get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, + pt_weights_iterator, safetensors_weights_iterator) +from vllm.model_executor.models.llava import LlavaForConditionalGeneration + +if TYPE_CHECKING: + from vllm.model_executor.layers.linear import LinearMethodBase + +_VISION_MODEL_CLASSES = [ + LlavaForConditionalGeneration, +] + +logger = init_logger(__name__) + + +def _get_linear_method( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional["LinearMethodBase"]: + """Get the (maybe quantized) linear method.""" + linear_method = None + if model_config.quantization is not None: + quant_config = get_quant_config(model_config, load_config) + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise 
ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + + linear_method = quant_config.get_linear_method() + return linear_method + + +def _get_model_initialization_kwargs( + model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig] +) -> Dict[str, Any]: + """Get extra kwargs for model initialization.""" + extra_kwargs = {} + if hasattr(model_class, "supported_lora_modules"): + extra_kwargs["lora_config"] = lora_config + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + elif model_class in _VISION_MODEL_CLASSES: + extra_kwargs["vision_language_config"] = vision_language_config + return extra_kwargs + + +def _initialize_model( + model_config: ModelConfig, load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: + """Initialize a model with the given configurations.""" + model_class = get_model_architecture(model_config)[0] + linear_method = _get_linear_method(model_config, load_config) + + return model_class(config=model_config.hf_config, + linear_method=linear_method, + **_get_model_initialization_kwargs( + model_class, lora_config, vision_language_config)) + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + """Load a model with the given configurations.""" + ... + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. 
+ from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + revision=revision) + else: + model_path = model + return model_path + return None + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf(model_name_or_path, + self.load_config.download_dir, + allow_patterns) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if not use_safetensors: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, model_name_or_path: str, revision: Optional[str], + fall_back_to_pt: bool + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision, fall_back_to_pt) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + return np_cache_weights_iterator(model_name_or_path, + self.load_config.download_dir, + hf_folder, hf_weights_files) + if use_safetensors: + return safetensors_weights_iterator(hf_weights_files) + return pt_weights_iterator(hf_weights_files) + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config) + model.load_weights( + self._get_weights_iterator(model_config.model, + model_config.revision, + fall_back_to_pt=getattr( + model, + "fall_back_to_pt_during_load", + True)), ) + return model.eval() + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + 
if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + return model.eval() + + +class TensorizerLoader(BaseModelLoader): + """Model loader using CoreWeave's tensorizer library.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if isinstance(load_config.model_loader_extra_config, TensorizerConfig): + self.tensorizer_config = load_config.model_loader_extra_config + else: + self.tensorizer_config = TensorizerConfig( + **load_config.model_loader_extra_config) + + def _verify_config(self, model_config: ModelConfig, + parallel_config: ParallelConfig): + self.tensorizer_config.verify_with_model_config(model_config) + self.tensorizer_config.verify_with_parallel_config(parallel_config) + + def _get_weights_iterator( + self) -> Generator[Tuple[str, torch.Tensor], None, None]: + tensorizer_args = self.tensorizer_config._construct_tensorizer_args() + return tensorizer_weights_iterator(tensorizer_args) + + def _load_model_unserialized( + self, model_config: ModelConfig, device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig] + ) -> nn.Module: + """Load an unserialized model with tensorizer. + + Unserialized here means "not serialized with tensorizer". This + should still be faster than default HuggingFace loading, but will + be slower than loading a tensorizer-serialized model. + """ + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config) + + model.load_weights(self._get_weights_iterator()) + return model.eval() + + def _load_model_serialized( + self, model_config: ModelConfig, device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig] + ) -> nn.Module: + """Load a serialized model with tensorizer. 
+ + See the examples/tensorize_vllm_model.py example " + script for serializing vLLM models.""" + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model_class = get_model_architecture(model_config)[0] + linear_method = _get_linear_method(model_config, + self.load_config) + extra_kwargs = _get_model_initialization_kwargs( + model_class, lora_config, vision_language_config) + extra_kwargs["linear_method"] = linear_method + + tensorizer_config = copy.copy(self.tensorizer_config) + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + + model = load_with_tensorizer(tensorizer_config, **extra_kwargs) + return model.eval() + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + self._verify_config(model_config, parallel_config) + + if is_vllm_serialized_tensorizer(self.tensorizer_config): + return self._load_model_serialized(model_config, device_config, + lora_config, + vision_language_config) + return self._load_model_unserialized(model_config, device_config, + lora_config, + vision_language_config) + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.TENSORIZER: + return TensorizerLoader(load_config) + + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/neuron_model_loader.py b/vllm/model_executor/model_loader/neuron.py similarity index 100% rename from vllm/model_executor/neuron_model_loader.py rename to vllm/model_executor/model_loader/neuron.py diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer.py similarity index 78% rename from vllm/model_executor/tensorizer_loader.py rename to vllm/model_executor/model_loader/tensorizer.py index 8550cc97aefe8..ad554844384eb 100644 --- a/vllm/model_executor/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -4,20 +4,20 @@ import os import time import typing -import warnings from dataclasses import dataclass -from typing import Optional, Union +from typing import Generator, Optional, Tuple, Type, Union import torch from torch import nn +from transformers import PretrainedConfig -from vllm.config import TensorizerConfig +from vllm.config import ModelConfig, ParallelConfig from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -tensorizer_load_fail = False +tensorizer_load_fail = None try: from tensorizer import (DecryptionParams, EncryptionParams, @@ -25,51 +25,78 @@ from tensorizer.stream_io import open_stream from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) -except ImportError: - tensorizer_load_fail = True +except ImportError as e: + tensorizer_load_fail = e __all__ = [ 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', - 'no_init_or_tensor' + 'no_init_or_tensor', 
'TensorizerConfig' ] logger = init_logger(__name__) +@dataclass +class TensorizerConfig: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: bool + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + model_class: Optional[Type[torch.nn.Module]] = None + hf_config: Optional[PretrainedConfig] = None + dtype: Optional[Union[str, torch.dtype]] = None + + def _construct_tensorizer_args(self) -> "TensorizerArgs": + tensorizer_args = { + "tensorizer_uri": self.tensorizer_uri, + "vllm_tensorized": self.vllm_tensorized, + "verify_hash": self.verify_hash, + "num_readers": self.num_readers, + "encryption_keyfile": self.encryption_keyfile, + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + return TensorizerArgs(**tensorizer_args) + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + if (parallel_config.tensor_parallel_size > 1 + and self.tensorizer_uri is not None): + raise ValueError( + "Loading to multiple GPUs is not currently supported with " + "vLLM-serialized models. Please set tensor_parallel_size=1." + " or use a non-vLLM-serialized model, such as a " + "serialized Hugging Face `PretrainedModel`.") + + def verify_with_model_config(self, model_config: "ModelConfig") -> None: + if (model_config.quantization is not None + and self.tensorizer_uri is not None): + logger.warning( + "Loading a model using Tensorizer with quantization on vLLM" + " is unstable and may lead to errors.") + + def load_with_tensorizer(tensorizer_config: TensorizerConfig, **extra_kwargs) -> nn.Module: tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs) return tensorizer.deserialize() -def tensorizer_warning(message: str): - return warnings.warn(message, category=PerformanceWarning, stacklevel=2) - - def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool: if tensorizer_config is None: return False return tensorizer_config.vllm_tensorized -class ParameterizedLoadFormat(str): - __slots__ = "params" - - -class PerformanceWarning(UserWarning): - - def __str__(self): - return (f"{super().__str__()}" - " (set the VLLM_SILENCE_PERFORMANCE_WARNINGS" - " environment variable to hide this)") - - -if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower() - not in ("", "0", "n", "no", "off", "disable")): - warnings.simplefilter("ignore", category=PerformanceWarning) - - @dataclass class TensorizerArgs: tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, @@ -219,11 +246,17 @@ class TensorizerAgent: behavior of the TensorDeserializer when loading tensors from a serialized model. For deserializations of HuggingFace models, TensorDeserializer is instead used as an iterator directly in the func hf_model_weights_iterator - in vllm/model_executor/weight_utils.py + in vllm/model_executor/model_loader/weight_utils.py """ def __init__(self, tensorizer_config: TensorizerConfig, linear_method: LinearMethodBase, **extra_kwargs): + if tensorizer_load_fail is not None: + raise ImportError( + "Tensorizer is not installed. Please install tensorizer " + "to use this feature with `pip install vllm[tensorizer]`." 
+ ) from tensorizer_load_fail + self.tensorizer_config = tensorizer_config self.tensorizer_args = ( self.tensorizer_config._construct_tensorizer_args()) @@ -234,11 +267,6 @@ def __init__(self, tensorizer_config: TensorizerConfig, self.linear_method = linear_method self.model = self._init_model() - if tensorizer_load_fail: - raise ImportError( - "Tensorizer is not installed. Please install tensorizer " - "to use this feature with `pip install vllm[tensorizer]`.") - def _init_model(self): model_args = self.tensorizer_config.hf_config model_args.torch_dtype = self.tensorizer_config.dtype @@ -313,3 +341,23 @@ def deserialize(self): self._check_tensors_on_meta_device() self._resize_lora_embeddings() return self.model.eval() + + +def tensorizer_weights_iterator( + tensorizer_args: "TensorizerArgs" +) -> Generator[Tuple[str, torch.Tensor], None, None]: + logger.warning( + "Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. See the examples/tensorize_vllm_model.py example " + "script for serializing vLLM models.") + + deserializer_args = tensorizer_args.deserializer_params + stream_params = tensorizer_args.stream_params + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + with TensorDeserializer(stream, **deserializer_args, + device="cpu") as state: + for name, param in state.items(): + yield name, param + del state diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py new file mode 100644 index 0000000000000..a0a3b2784614d --- /dev/null +++ b/vllm/model_executor/model_loader/utils.py @@ -0,0 +1,40 @@ +"""Utilities for selecting and loading models.""" +import contextlib +from typing import Tuple, Type + +import torch +from torch import nn + +from vllm.config import ModelConfig +from vllm.model_executor.models import ModelRegistry + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def get_model_architecture( + model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. + # FIXME(woosuk): This is a temporary hack. + if (model_config.quantization is not None + and "MixtralForCausalLM" in architectures): + architectures = ["QuantMixtralForCausalLM"] + + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + raise ValueError( + f"Model architectures {architectures} are not supported for now. 
" + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_architecture_class_name(model_config: ModelConfig) -> str: + return get_model_architecture(model_config)[1] diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py similarity index 53% rename from vllm/model_executor/weight_utils.py rename to vllm/model_executor/model_loader/weight_utils.py index 08425604f0511..1798db0136868 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -4,8 +4,9 @@ import hashlib import json import os +import tempfile from collections import defaultdict -from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Any, Generator, Iterable, List, Optional, Tuple import filelock import huggingface_hub.constants @@ -15,7 +16,7 @@ from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm -from vllm.config import ModelConfig +from vllm.config import LoadConfig, ModelConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) @@ -27,8 +28,7 @@ # can share the same lock without error. # lock files in the temp directory will be automatically deleted when the # system reboots, so users will not complain about annoying lock files -temp_dir = os.environ.get('TMPDIR') or os.environ.get( - 'TEMP') or os.environ.get('TMP') or "/tmp/" +temp_dir = tempfile.gettempdir() def enable_hf_transfer(): @@ -46,7 +46,7 @@ def enable_hf_transfer(): enable_hf_transfer() -class Disabledtqdm(tqdm): +class DisabledTqdm(tqdm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, disable=True) @@ -114,7 +114,8 @@ def convert_bin_to_safetensor_file( # TODO(woosuk): Move this to other place. -def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: +def get_quant_config(model_config: ModelConfig, + load_config: LoadConfig) -> QuantizationConfig: quant_cls = get_quantization_config(model_config.quantization) # Read the quantization config from the HF model config, if available. hf_quant_config = getattr(model_config.hf_config, "quantization_config", @@ -125,12 +126,12 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. - with get_lock(model_name_or_path, model_config.download_dir): + with get_lock(model_name_or_path, load_config.download_dir): hf_folder = snapshot_download(model_name_or_path, revision=model_config.revision, allow_patterns="*.json", - cache_dir=model_config.download_dir, - tqdm_class=Disabledtqdm) + cache_dir=load_config.download_dir, + tqdm_class=DisabledTqdm) else: hf_folder = model_name_or_path config_files = glob.glob(os.path.join(hf_folder, "*.json")) @@ -153,169 +154,127 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: return quant_cls.from_config(config) -def prepare_hf_model_weights( - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - fall_back_to_pt: bool = True, - revision: Optional[str] = None, -) -> Tuple[str, List[str], bool]: - # Download model weights from huggingface. - is_local = os.path.isdir(model_name_or_path) \ - and load_format != "tensorizer" - use_safetensors = False - # Some quantized models use .pt files for storing the weights. 
- if load_format == "auto": - allow_patterns = ["*.safetensors", "*.bin"] - elif load_format == "safetensors": - use_safetensors = True - allow_patterns = ["*.safetensors"] - elif load_format == "pt": - allow_patterns = ["*.pt"] - elif load_format == "npcache": - allow_patterns = ["*.bin"] - elif load_format == "tensorizer": - allow_patterns = ["*.tensors"] - else: - raise ValueError(f"Unknown load_format: {load_format}") - - if fall_back_to_pt: - allow_patterns += ["*.pt"] - - if not is_local and load_format != "tensorizer": - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] - break - - logger.info(f"Using model weights format {allow_patterns}") - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model_name_or_path, cache_dir): - hf_folder = snapshot_download(model_name_or_path, - allow_patterns=allow_patterns, - cache_dir=cache_dir, - tqdm_class=Disabledtqdm, - revision=revision) - else: - hf_folder = model_name_or_path - hf_weights_files: List[str] = [] +def download_weights_from_hf(model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None) -> str: + """Download model weights from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (List[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + + Returns: + str: The path to the downloaded model weights. + """ + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things for pattern in allow_patterns: - hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) - if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] break - if not use_safetensors: - # Exclude files that are not needed for inference. 
- # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 - blacklist = [ - "training_args.bin", - "optimizer.bin", - "optimizer.pt", - "scheduler.pt", - "scaler.pt", - ] - hf_weights_files = [ - f for f in hf_weights_files - if not any(f.endswith(x) for x in blacklist) - ] - - if load_format == "tensorizer": - return hf_folder, hf_weights_files, use_safetensors - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") - - return hf_folder, hf_weights_files, use_safetensors - - -def hf_model_weights_iterator( - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: Union[Tuple, str] = "auto", - revision: Optional[str] = None, - fall_back_to_pt: Optional[bool] = True, -) -> Iterator[Tuple[str, torch.Tensor]]: - hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights( - model_name_or_path, - cache_dir=cache_dir, - load_format=load_format, - fall_back_to_pt=fall_back_to_pt, - revision=revision) - - if load_format == "npcache": - # Currently np_cache only support *.bin checkpoints - assert use_safetensors is False - - # Convert the model weights from torch tensors to numpy arrays for - # faster loading. - np_folder = os.path.join(hf_folder, "np") - os.makedirs(np_folder, exist_ok=True) - weight_names_file = os.path.join(np_folder, "weight_names.json") - # Use file lock to prevent multiple processes from - # dumping the same model weights to numpy at the same time. - with get_lock(model_name_or_path, cache_dir): - if not os.path.exists(weight_names_file): - weight_names = [] - for bin_file in hf_weights_files: - state = torch.load(bin_file, map_location="cpu") - for name, param in state.items(): - param_path = os.path.join(np_folder, name) - with open(param_path, "wb") as f: - np.save(f, param.cpu().detach().numpy()) - weight_names.append(name) - with open(weight_names_file, "w") as f: - json.dump(weight_names, f) - - with open(weight_names_file, "r") as f: - weight_names = json.load(f) - - for name in weight_names: - param_path = os.path.join(np_folder, name) - with open(param_path, "rb") as f: - param = np.load(f) - yield name, torch.from_numpy(param) - elif load_format == "tensorizer": - from vllm.model_executor.tensorizer_loader import (TensorDeserializer, - open_stream, - tensorizer_warning) - tensorizer_args = load_format.params - tensorizer_warning( - "Deserializing HuggingFace models is not optimized for " - "loading on vLLM, as tensorizer is forced to load to CPU. " - "Consider deserializing a vLLM model instead for faster " - "load times. See the examples/tensorize_vllm_model.py example " - "script for serializing vLLM models.") - - deserializer_args = tensorizer_args.deserializer_params - stream_params = tensorizer_args.stream_params - stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) - with TensorDeserializer(stream, **deserializer_args, - device="cpu") as state: - for name, param in state.items(): + + logger.info(f"Using model weights format {allow_patterns}") + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. 
+ with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download(model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision) + return hf_folder + + +def filter_files_not_needed_for_inference( + hf_weights_files: List[str]) -> List[str]: + """ + Exclude files that are not needed for inference. + + See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + """ + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files + if not any(f.endswith(x) for x in blacklist) + ] + return hf_weights_files + + +def np_cache_weights_iterator( + model_name_or_path: str, cache_dir: Optional[str], hf_folder: str, + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model np files. + + Will dump the model weights to numpy files if they are not already dumped. + """ + # Convert the model weights from torch tensors to numpy arrays for + # faster loading. + np_folder = os.path.join(hf_folder, "np") + os.makedirs(np_folder, exist_ok=True) + weight_names_file = os.path.join(np_folder, "weight_names.json") + # Use file lock to prevent multiple processes from + # dumping the same model weights to numpy at the same time. + with get_lock(model_name_or_path, cache_dir): + if not os.path.exists(weight_names_file): + weight_names = [] + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + param_path = os.path.join(np_folder, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + weight_names.append(name) + with open(weight_names_file, "w") as f: + json.dump(weight_names, f) + + with open(weight_names_file, "r") as f: + weight_names = json.load(f) + + for name in weight_names: + param_path = os.path.join(np_folder, name) + with open(param_path, "rb") as f: + param = np.load(f) + yield name, torch.from_numpy(param) + + +def safetensors_weights_iterator( + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + for st_file in hf_weights_files: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) yield name, param + + +def pt_weights_iterator( + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model bin/pt files.""" + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + yield name, param del state - elif use_safetensors: - for st_file in hf_weights_files: - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param - else: - for bin_file in hf_weights_files: - state = torch.load(bin_file, map_location="cpu") - for name, param in state.items(): - yield name, param - del state - torch.cuda.empty_cache() + torch.cuda.empty_cache() def kv_cache_scales_loader( diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 30588aecdebe9..69162b0a92d65 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -19,7 +19,7 @@ # limitations under the License. 
"""Inference-only BaiChuan model compatible with HuggingFace weights.""" import math -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -40,9 +40,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -340,19 +339,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if name == "lm_head.weight": diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 40966ab33631a..14f325e624f41 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -17,7 +17,7 @@ # limitations under the License. """Inference-only BLOOM model compatible with HuggingFace weights.""" import math -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -35,9 +35,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -298,14 +297,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if name == "lm_head.weight": continue if not name.startswith("transformer."): diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 7b46ba306619a..3cdb7a7bca1c1 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -2,7 +2,7 @@ # Adapted from # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -22,9 +22,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding 
import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig @@ -370,14 +369,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_pos_emb.inv_freq" in name: continue if "word_embeddings" in name: diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index aa9b28b676e0b..d80969773e163 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -20,7 +20,7 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch import torch.utils.checkpoint @@ -41,10 +41,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -335,13 +334,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -352,8 +345,7 @@ def load_weights( ] params_dict = dict(self.named_parameters()) loaded_params = set() - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 49eb7f1b2c185..179094b8fd7aa 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -1,5 +1,5 @@ # coding=utf-8 -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch import torch.nn as nn @@ -18,10 +18,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import 
set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -391,20 +390,13 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [( "ws" if weight_name in ["w1", "v1"] else "w2s", f"experts.mlp.{weight_name}", ) for weight_name in ["w1", "v1", "w2"]] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: for param_name, weight_name in expert_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index abf4a462871b0..d476630ee6f11 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -23,16 +23,15 @@ # limitations under the License. """Inference-only DeciLM model compatible with HuggingFace weights.""" -from typing import Optional +from typing import Iterable, Optional, Tuple import torch from transformers import PretrainedConfig from vllm.config import LoRAConfig from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) class DeciLMForCausalLM(LlamaForCausalLM): @@ -65,11 +64,7 @@ def __init__( linear_method=linear_method, lora_config=lora_config) - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -79,8 +74,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index c7dd11d07e6da..46101a152ec0d 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Deepseek model.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -44,9 +44,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -316,6 +315,8 @@ def forward( class DeepseekModel(nn.Module): + fall_back_to_pt_during_load = False + def __init__( self, config: PretrainedConfig, @@ -395,11 +396,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -410,12 +407,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 4f1ebcd5fb43c..25ce239d14662 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -19,7 +19,7 @@ """PyTorch Falcon model.""" import math -from typing import List, Optional, Union +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -40,9 +40,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import RWConfig @@ -399,11 +398,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): total_num_heads = self.config.num_attention_heads if self.config.new_decoder_architecture: total_num_kv_heads = self.config.num_kv_heads @@ -413,8 +408,7 @@ def load_weights(self, total_num_kv_heads = total_num_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if name == "lm_head.weight": # Falcon uses tied embeddings. 
continue diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index fc1fc35570368..6d01537c5c344 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -15,7 +15,7 @@ # limitations under the License. """Inference-only Gemma model compatible with HuggingFace weights.""" from functools import lru_cache -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -36,9 +36,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput logger = init_logger(__name__) @@ -346,11 +345,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -361,8 +356,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) loaded_params = set() - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: for (param_name, shard_name, shard_id) in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 43f0d47fcb122..850050c7232d0 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -34,9 +34,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -239,14 +238,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. 
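Every model file touched in this patch receives the same mechanical change to load_weights, so a brief sketch of the new contract is worth spelling out once. The code below is illustrative only, not part of the patch: the toy iterator is a stand-in for the real iterators in model_loader/weight_utils.py, and the point is simply that load_weights() now consumes (name, tensor) pairs handed to it by the loader instead of downloading checkpoints itself.

    from typing import Iterable, Tuple

    import torch


    def toy_weights_iterator() -> Iterable[Tuple[str, torch.Tensor]]:
        # Stand-in for safetensors_weights_iterator / pt_weights_iterator /
        # np_cache_weights_iterator from model_loader/weight_utils.py.
        yield "transformer.wte.weight", torch.zeros(8, 4)
        yield "lm_head.weight", torch.zeros(8, 4)


    # Old, removed signature:
    #     model.load_weights(model_name_or_path, cache_dir, load_format, revision)
    # New signature, as driven by DefaultModelLoader._get_weights_iterator():
    #     model.load_weights(toy_weights_iterator())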
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index cec2d771adfa8..8278ba02514d5 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPTBigCode model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -35,9 +35,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -260,14 +259,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: continue if ".attn.bias" in name: diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 5660097652748..7a830d7f9c965 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-J model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -34,9 +34,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -248,11 +247,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -262,8 +257,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "attn.bias" in name or "attn.masked_bias" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 2f9e2171cf114..b946aed92ed35 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -34,9 +34,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -262,14 +261,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if ("attention.bias" in name or "attention.masked_bias" in name or "rotary_emb.inv_freq" in name): continue diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 6e9cbd3f9f43f..db1da8bdc4fb9 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -18,9 +18,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -274,19 +273,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w1", 0), ("gate_up_proj", "w3", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index a041b0c9a0452..e7ee749e824e4 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -20,7 +20,7 @@ """Inference-only Jais model compatible with HuggingFace weights.""" import math -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -36,9 +36,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from 
vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import JAISConfig @@ -303,16 +302,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c86e292e7df1a..016e3b039d1e8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -42,10 +42,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator, - kv_cache_scales_loader) from vllm.sequence import SamplerOutput from vllm.utils import is_hip @@ -376,11 +375,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -390,8 +385,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index c2571d0893c8d..314a2792bf167 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -13,10 +13,9 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import 
default_weight_loader from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput _KEYS_TO_MODIFY_MAPPING = { @@ -198,11 +197,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # only doing this for language model part for now. stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -213,8 +208,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 49eda9c9a8112..f0d72fafcaf70 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -22,7 +22,7 @@ # limitations under the License. """Inference-only MiniCPM model compatible with HuggingFace weights.""" import math -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -45,10 +45,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -472,11 +471,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -493,8 +488,7 @@ def load_weights(self, for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index ff552a9d86536..4d1755f2bbe63 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -43,10 +43,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -319,6 +318,8 @@ def forward( class MixtralForCausalLM(nn.Module): + fall_back_to_pt_during_load = False + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -393,11 +394,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -414,12 +411,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 1f0c0e912beea..acd13cc27f159 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import numpy as np import torch @@ -43,9 +43,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -327,6 +326,7 @@ def forward( class MixtralForCausalLM(nn.Module): + fall_back_to_pt_during_load = False def __init__( self, @@ -366,11 +366,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -379,12 +375,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index af4cdce29d085..340f63286739b 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -1,7 +1,7 @@ # coding=utf-8 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch import torch.nn as nn @@ -18,9 +18,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.mpt import MPTConfig @@ -284,14 +283,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 3513c72879102..b92003bc0e067 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -36,7 +36,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch # this model must need this dependency @@ -56,9 +56,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -348,16 +347,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: # attention if ".att" in name: name = name.replace(".att", ".attn.att") diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 3a640850662c0..89263166bca81 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OPT model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -35,9 +35,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -315,11 +314,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -327,8 +322,7 @@ def load_weights(self, ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: continue if name.startswith("decoder."): diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index c606ac027e9d9..bbb9fa5347cc8 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -4,7 +4,7 @@ # Copyright (c) OrionStar Inc. 
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -22,9 +22,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -280,11 +279,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -294,8 +289,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index e91624da90955..f974b78a0fbda 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -35,7 +35,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -53,9 +53,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -265,11 +264,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -278,8 +273,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 6213a2ded65ab..a77da7cb15984 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -4,7 +4,7 @@ # Copyright (c) Alibaba Cloud. 
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -23,9 +23,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -253,19 +252,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w2", 0), ("gate_up_proj", "w1", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 796e30e633e85..71b906e20ac19 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -42,9 +42,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -331,11 +330,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -345,8 +340,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if self.config.tie_word_embeddings and "lm_head.weight" in name: diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index f920b4f5a40c7..59908bc9ef26a 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch import torch.nn.functional as F @@ -46,9 +46,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -366,6 +365,8 @@ def forward( class Qwen2MoeForCausalLM(nn.Module): + fall_back_to_pt_during_load = False + def __init__( self, config: PretrainedConfig, @@ -404,11 +405,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -419,12 +416,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 651598b770f13..3e6c2db6f3c65 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -19,7 +19,7 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -37,9 +37,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -262,11 +261,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -276,8 +271,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py 
index 76e8e48673413..b90f3da141c2e 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Starcoder2 model.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -36,9 +36,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -274,11 +273,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -287,8 +282,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 7e9ce9e5c8e15..4e905390c2340 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Xverse model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -40,9 +40,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -331,11 +330,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -344,8 +339,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if ("rotary_emb.inv_freq" in name or "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name): diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 5d3d5801c960d..c98a673bfed4b 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,8 +1,10 @@ +import os from typing import Optional, Union from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) +from vllm.config import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizers import BaichuanTokenizer @@ -57,9 +59,26 @@ def get_tokenizer( tokenizer_mode: str = "auto", trust_remote_code: bool = False, tokenizer_revision: Optional[str] = None, + download_dir: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - """Gets a tokenizer for the given model name via Huggingface.""" + """Gets a tokenizer for the given model name via Huggingface/modelscope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + # Only set the tokenizer here, model will be downloaded on the workers. + if not os.path.exists(tokenizer_name): + tokenizer_path = snapshot_download( + model_id=tokenizer_name, + cache_dir=download_dir, + revision=tokenizer_revision, + # Ignore weights - we only need the tokenizer. 
+ ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"]) + tokenizer_name = tokenizer_path + if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError( diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 49e1ad5709f5d..d378e3a90e1e7 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -3,8 +3,8 @@ import torch from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -26,6 +26,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + load_config: LoadConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, @@ -36,6 +37,7 @@ def __init__( self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.lora_config = lora_config + self.load_config = load_config self.is_driver_worker = is_driver_worker # model_config can be None in tests/samplers/test_sampler.py. @@ -55,8 +57,10 @@ def __init__( self.model_config.dtype if model_config is not None else None) def load_model(self) -> None: - self.model = get_model(self.model_config, - self.device_config, + self.model = get_model(model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + vision_language_config=None, lora_config=self.lora_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 41341b063bed7..6610b9c4be876 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -5,8 +5,8 @@ import torch.distributed from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -117,6 +117,7 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, cache_config: CacheConfig, + load_config: LoadConfig, local_rank: int, rank: int, distributed_init_method: str, @@ -129,6 +130,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.cache_config = cache_config + self.load_config = load_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -141,6 +143,7 @@ def __init__( parallel_config, scheduler_config, device_config, + load_config=self.load_config, lora_config=self.lora_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7dbe14ead0976..42c06a1b19361 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,9 +9,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) -from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, TensorizerConfig, - VisionLanguageConfig) 
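Note on the ModelScope path: get_tokenizer() gains a download_dir parameter and, when VLLM_USE_MODELSCOPE is set, resolves the tokenizer via snapshot_download while ignoring weight files. A hedged usage sketch; the model id and cache path are placeholders, and setting the flag before importing vLLM assumes it is read from the environment at import time (that code lives in vllm/config.py, outside this excerpt):

    import os

    os.environ["VLLM_USE_MODELSCOPE"] = "True"  # must precede any vLLM import

    from vllm.transformers_utils.tokenizer import get_tokenizer

    tokenizer = get_tokenizer(
        "qwen/Qwen-7B",                        # hypothetical ModelScope id
        trust_remote_code=True,
        download_dir="/tmp/modelscope-cache",  # forwarded as the snapshot cache_dir
    )
    print(type(tokenizer).__name__)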
+from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed.device_communicators import (custom_all_reduce, pynccl_utils) @@ -108,17 +107,17 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + load_config: LoadConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, - tensorizer_config: Optional[TensorizerConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.lora_config = lora_config - self.tensorizer_config = tensorizer_config + self.load_config = load_config self.is_driver_worker = is_driver_worker # model_config can be None in tests/samplers/test_sampler.py. @@ -156,13 +155,13 @@ def __init__( def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( - self.model_config, - self.device_config, + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, lora_config=self.lora_config, vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - tensorizer_config=self.tensorizer_config, ) self.model_memory_usage = m.consumed_memory diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index fff721a80c204..f70a7193effeb 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -6,7 +6,7 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata -from vllm.model_executor.neuron_model_loader import get_neuron_model +from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (async_tensor_h2d, is_pin_memory_available, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 82491c6df6616..6a79285f60579 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -6,8 +6,8 @@ import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, TensorizerConfig, +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, @@ -38,12 +38,12 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, cache_config: CacheConfig, + load_config: LoadConfig, local_rank: int, rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, - tensorizer_config: Optional[TensorizerConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -55,7 +55,7 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config - self.tensorizer_config = tensorizer_config + self.load_config = load_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must 
have rank 0." @@ -70,11 +70,11 @@ def __init__( parallel_config, scheduler_config, device_config, + load_config=load_config, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, - tensorizer_config=tensorizer_config, ) # Uninitialized cache engine. Will be initialized by # initialize_cache.