[Core] Refactor model loading code (#4097)
Yard1 authored Apr 16, 2024
1 parent 0543476 commit 69e1d2f
Showing 67 changed files with 1,064 additions and 973 deletions.
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -92,7 +92,7 @@ steps:
   parallelism: 4
 
 - label: Tensorizer Test
-  command: apt-get install curl libsodium23 && pytest -v -s tensorizer
+  command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
 
 - label: Metrics Test
   command: pytest -v -s metrics
4 changes: 2 additions & 2 deletions examples/fp8/extract_scales.py
@@ -11,7 +11,7 @@
 from vllm.model_executor.layers.quantization.schema import QuantParamSchema
 
 
-# Adapted from vllm/model_executor/weight_utils.py
+# Adapted from vllm/model_executor/model_loader/weight_utils.py
 # The main differences are that we add the NPZ format and simplify
 # its functionality drastically for our purposes (e.g. we assume that
 # the quantized model exists locally and there is no need to download it)
@@ -71,7 +71,7 @@ def _prepare_hf_weights(
     return hf_weights_files, use_safetensors
 
 
-# Adapted from vllm/model_executor/weight_utils.py
+# Adapted from vllm/model_executor/model_loader/weight_utils.py
 def _hf_tensorfile_iterator(filename: str, load_format: str,
                             use_safetensors: bool):
     if load_format == "npz":
2 changes: 1 addition & 1 deletion examples/tensorize_vllm_model.py
@@ -16,8 +16,8 @@
 from vllm.distributed import initialize_model_parallel
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
+from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
 from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.tensorizer_loader import TensorizerArgs
 
 # yapf conflicts with isort for this docstring
 # yapf: disable
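This hunk is the visible edge of the commit's module move: the tensorizer utilities now live inside the new `model_loader` package rather than at the top of `model_executor`. A one-line sketch of updating a caller, assuming only the two import paths shown in this diff:

```python
# Pre-commit location, kept as a comment for contrast:
# from vllm.model_executor.tensorizer_loader import TensorizerArgs

# Location introduced by this refactor:
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
```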
10 changes: 5 additions & 5 deletions tests/lora/conftest.py
@@ -153,11 +153,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
     get_model_old = get_model
 
-    def get_model_patched(model_config, device_config, **kwargs):
-        return get_model_old(model_config,
-                             device_config,
-                             lora_config=LoRAConfig(max_loras=4,
-                                                    max_lora_rank=8))
+    def get_model_patched(*, model_config, device_config, **kwargs):
+        kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8)
+        return get_model_old(model_config=model_config,
+                             device_config=device_config,
+                             **kwargs)
 
     with patch("vllm.worker.model_runner.get_model", get_model_patched):
         engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
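The patched helper above is the commit's new calling convention in miniature: `get_model` becomes keyword-only and extra arguments are forwarded through `**kwargs`, so new parameters such as `load_config` flow through the patch untouched. A self-contained sketch of the pattern; the stub below is illustrative, not vLLM's real loader:

```python
def get_model(*, model_config, device_config, **kwargs):
    # Stand-in for the real loader: keyword-only parameters mean a new
    # argument can be added without breaking positional call sites.
    return {"model_config": model_config,
            "device_config": device_config, **kwargs}


def get_model_patched(*, model_config, device_config, **kwargs):
    # Override one keyword and forward the rest unchanged, so the patch
    # survives future signature growth (e.g. the new load_config).
    kwargs["lora_config"] = "LoRAConfig(max_loras=4, max_lora_rank=8)"
    return get_model(model_config=model_config,
                     device_config=device_config,
                     **kwargs)


print(get_model_patched(model_config="llama-2-7b", device_config="cuda",
                        load_config="dummy"))
```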
10 changes: 6 additions & 4 deletions tests/lora/test_worker.py
@@ -3,8 +3,8 @@
 import tempfile
 from unittest.mock import patch
 
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+                         ModelConfig, ParallelConfig, SchedulerConfig)
 from vllm.lora.models import LoRAMapping
 from vllm.lora.request import LoRARequest
 from vllm.worker.worker import Worker
@@ -18,12 +18,14 @@ def test_worker_apply_lora(sql_lora_files):
             "meta-llama/Llama-2-7b-hf",
             tokenizer_mode="auto",
             trust_remote_code=False,
-            download_dir=None,
-            load_format="dummy",
             seed=0,
             dtype="float16",
             revision=None,
         ),
+        load_config=LoadConfig(
+            download_dir=None,
+            load_format="dummy",
+        ),
         parallel_config=ParallelConfig(1, 1, False),
         scheduler_config=SchedulerConfig(32, 32, 32),
         device_config=DeviceConfig("cuda"),
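This hunk shows the heart of the refactor: `download_dir` and `load_format` leave `ModelConfig` and move into a new `LoadConfig` that is handed to the worker separately. A hedged sketch of the split, assuming only the constructor arguments visible in this diff:

```python
from vllm.config import LoadConfig, ModelConfig

# Model identity, tokenizer, and dtype stay in ModelConfig...
model_config = ModelConfig(
    "meta-llama/Llama-2-7b-hf",
    tokenizer_mode="auto",
    trust_remote_code=False,
    seed=0,
    dtype="float16",
    revision=None,
)

# ...while weight-loading concerns now travel in their own config object.
load_config = LoadConfig(
    download_dir=None,
    load_format="dummy",  # "dummy" skips real weight loading in tests
)
```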
2 changes: 1 addition & 1 deletion tests/model_executor/weight_utils.py
@@ -3,7 +3,7 @@
 import huggingface_hub.constants
 import pytest
 
-from vllm.model_executor.weight_utils import enable_hf_transfer
+from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer
 
 
 def test_hf_transfer_auto_activation():
4 changes: 0 additions & 4 deletions tests/quantization/test_autogptq_marlin_configs.py
@@ -36,8 +36,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None:
         model_path,
         tokenizer_mode="auto",
         trust_remote_code=False,
-        download_dir=None,
-        load_format="dummy",
         seed=0,
         dtype="float16",
         revision=None,
@@ -49,8 +47,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None:
         model_path,
         tokenizer_mode="auto",
         trust_remote_code=False,
-        download_dir=None,
-        load_format="dummy",
         seed=0,
         dtype="float16",
         revision=None,
14 changes: 12 additions & 2 deletions tests/samplers/test_sampler.py
@@ -32,7 +32,12 @@ def _prepare_test(
                              1e-2,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(fake_logits)
-    model_runner = ModelRunner(None, None, None, None, None)
+    model_runner = ModelRunner(model_config=None,
+                               parallel_config=None,
+                               scheduler_config=None,
+                               device_config=None,
+                               load_config=None,
+                               lora_config=None)
     return input_tensor, fake_logits, sampler, model_runner
 
 
@@ -591,7 +596,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
                              device=input_tensor.device,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(fake_logits)
-    model_runner = ModelRunner(None, None, None, None, None)
+    model_runner = ModelRunner(model_config=None,
+                               parallel_config=None,
+                               scheduler_config=None,
+                               device_config=None,
+                               load_config=None,
+                               lora_config=None)
 
     generation_model = GenerationMixin()
     generation_config = GenerationConfig(top_k=top_k,
1 change: 1 addition & 0 deletions tests/spec_decode/utils.py
@@ -118,6 +118,7 @@ def create_worker(cls: type,
         scheduler_config=engine_config.scheduler_config,
         device_config=engine_config.device_config,
         cache_config=engine_config.cache_config,
+        load_config=engine_config.load_config,
         local_rank=0,
         rank=0,
         distributed_init_method=distributed_init_method,
File renamed without changes.
@@ -16,8 +16,8 @@
 from vllm.distributed import initialize_model_parallel
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
+from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
 from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.tensorizer_loader import TensorizerArgs
 
 # yapf conflicts with isort for this docstring
 # yapf: disable
@@ -74,7 +74,7 @@ def parse_args():
         "extremely quickly. Tensor encryption and decryption is "
         "also supported, although libsodium must be installed to "
         "use it.")
-    parser = EngineArgs.add_cli_args(parser)
+    parser = TensorizerArgs.add_cli_args(EngineArgs.add_cli_args(parser))
     subparsers = parser.add_subparsers(dest='command')
 
     serialize_parser = subparsers.add_parser(
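The last hunk composes two argument groups onto one parser: `TensorizerArgs.add_cli_args` wraps `EngineArgs.add_cli_args`, so the script accepts engine flags and tensorizer flags together. A generic sketch of that composition idiom with plain argparse; the class bodies and flag names below are illustrative stand-ins, not vLLM's actual options:

```python
import argparse


class EngineArgs:
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Engine-level flags (illustrative).
        parser.add_argument("--model", default="facebook/opt-125m")
        return parser


class TensorizerArgs:
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Serialization flags layered on top of whatever is already defined.
        parser.add_argument("--serialized-directory", default=None)
        return parser


# Each add_cli_args returns the parser it received, so the calls nest cleanly.
parser = TensorizerArgs.add_cli_args(
    EngineArgs.add_cli_args(argparse.ArgumentParser()))
args = parser.parse_args(["--model", "meta-llama/Llama-2-7b-hf"])
print(args.model, args.serialized_directory)
```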