[CI/Build] Avoid downloading all HF files in RemoteOpenAIServer (vl…
DarkLight1337 authored Aug 26, 2024
1 parent 0b76999 commit 029c71d
Showing 2 changed files with 27 additions and 15 deletions.
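For context: the test helper previously prefetched checkpoints with `huggingface_hub.snapshot_download`, which pulls every file in a model repo, including weight formats the server never loads. This commit switches to vLLM's own `DefaultModelLoader` so only the needed files are fetched. A minimal sketch of the difference, assuming `huggingface_hub` is installed; the model id is a placeholder, and the `allow_patterns` filter only approximates what the loader actually selects:

from huggingface_hub import snapshot_download

repo = "facebook/opt-125m"  # placeholder model id

# Old approach: download the entire repo, e.g. both *.bin and
# *.safetensors copies of the weights.
snapshot_download(repo)

# Spirit of the new approach: fetch only the files the engine will read.
# (The commit achieves this via DefaultModelLoader._prepare_weights,
# whose patterns depend on the configured load format.)
snapshot_download(repo, allow_patterns=["*.safetensors", "*.json", "*.txt"])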
40 changes: 26 additions & 14 deletions tests/utils.py
@@ -11,13 +11,14 @@

 import openai
 import requests
-from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 from typing_extensions import ParamSpec
 
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
@@ -60,39 +61,50 @@ class RemoteOpenAIServer:

     def __init__(self,
                  model: str,
-                 cli_args: List[str],
+                 vllm_serve_args: List[str],
                  *,
                  env_dict: Optional[Dict[str, str]] = None,
                  auto_port: bool = True,
                  max_wait_seconds: Optional[float] = None) -> None:
-        if not model.startswith("/"):
-            # download the model if it's not a local path
-            # to exclude the model download time from the server start time
-            snapshot_download(model)
         if auto_port:
-            if "-p" in cli_args or "--port" in cli_args:
-                raise ValueError("You have manually specified the port"
+            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
+                raise ValueError("You have manually specified the port "
                                  "when `auto_port=True`.")
 
-            cli_args = cli_args + ["--port", str(get_open_port())]
+            # Don't mutate the input args
+            vllm_serve_args = vllm_serve_args + [
+                "--port", str(get_open_port())
+            ]
 
         parser = FlexibleArgumentParser(
             description="vLLM's remote OpenAI server.")
         parser = make_arg_parser(parser)
-        args = parser.parse_args(cli_args)
+        args = parser.parse_args(["--model", model, *vllm_serve_args])
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)
 
+        # download the model before starting the server to avoid timeout
+        is_local = os.path.isdir(model)
+        if not is_local:
+            engine_args = AsyncEngineArgs.from_cli_args(args)
+            engine_config = engine_args.create_engine_config()
+            dummy_loader = DefaultModelLoader(engine_config.load_config)
+            dummy_loader._prepare_weights(engine_config.model_config.model,
+                                          engine_config.model_config.revision,
+                                          fall_back_to_pt=True)
+
         env = os.environ.copy()
         # the current process might initialize cuda,
         # to be safe, we should use spawn method
         env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
-        self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
-                                     env=env,
-                                     stdout=sys.stdout,
-                                     stderr=sys.stderr)
+        self.proc = subprocess.Popen(
+            ["vllm", "serve", model, *vllm_serve_args],
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
         max_wait_seconds = max_wait_seconds or 240
         self._wait_for_server(url=self.url_for("health"),
                               timeout=max_wait_seconds)
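With this change, the checkpoint is fetched before the subprocess is spawned, so the `max_wait_seconds` budget (default 240 s) covers only server startup rather than the download. A sketch of typical test usage, assuming the class's context-manager support and a `get_client()` helper defined elsewhere in tests/utils.py:

from tests.utils import RemoteOpenAIServer

MODEL = "facebook/opt-125m"  # placeholder model id

# Weights are downloaded in __init__, before "vllm serve" is launched,
# so _wait_for_server() no longer races against a slow download.
with RemoteOpenAIServer(MODEL, ["--max-model-len", "1024"]) as server:
    client = server.get_client()  # assumed to wrap openai.OpenAI
    out = client.completions.create(model=MODEL,
                                    prompt="Hello, my name is",
                                    max_tokens=5)
    print(out.choices[0].text)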
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -742,7 +742,7 @@ def from_cli_args(cls, args: argparse.Namespace):
         engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
         return engine_args
 
-    def create_engine_config(self, ) -> EngineConfig:
+    def create_engine_config(self) -> EngineConfig:
         # gguf file needs a specific model loader and doesn't use hf_repo
         if self.model.endswith(".gguf"):
             self.quantization = self.load_format = "gguf"
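The trailing-comma cleanup here is incidental, but the method itself is load-bearing for the first file: the new pre-download path calls it directly to derive the loader configuration. A sketch of that call chain, lifted from the diff above (the model id is a placeholder; `_prepare_weights` is a private vLLM helper):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.model_loader.loader import DefaultModelLoader

engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # placeholder
engine_config = engine_args.create_engine_config()

# Download only the weight files the configured load format needs,
# mirroring the engine's own startup path.
loader = DefaultModelLoader(engine_config.load_config)
loader._prepare_weights(engine_config.model_config.model,
                        engine_config.model_config.revision,
                        fall_back_to_pt=True)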
