Fix/async chat serving (#2727)
schoennenbeck authored May 3, 2024
1 parent 7e65477 commit f8e7add
Showing 5 changed files with 73 additions and 21 deletions.
tests/async_engine/test_chat_template.py: 25 changes (14 additions, 11 deletions)
@@ -60,12 +60,13 @@ class MockServingChat:
     tokenizer: MockTokenizer
 
 
-def test_load_chat_template():
+@pytest.mark.asyncio
+async def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=chatml_jinja_path)
+    await OpenAIServingChat._load_chat_template(
+        mock_serving_chat, chat_template=chatml_jinja_path)
 
     template_content = tokenizer.chat_template
 
@@ -76,26 +77,28 @@ def test_load_chat_template():
 {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501
 
 
-def test_no_load_chat_template_filelike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
     tokenizer = MockTokenizer()
 
     mock_serving_chat = MockServingChat(tokenizer)
 
     with pytest.raises(ValueError, match="looks like a file path"):
-        OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                              chat_template=template)
+        await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                    chat_template=template)
 
 
-def test_no_load_chat_template_literallike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
     tokenizer = MockTokenizer()
 
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)
     template_content = tokenizer.chat_template
 
     assert template_content == template
@@ -110,8 +113,8 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)
 
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
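All of the updated tests in this file follow the same change: because _load_chat_template is now a coroutine, each test is decorated with @pytest.mark.asyncio so the test body itself runs on an event loop and can await the call. A minimal illustration of that pattern in isolation (a sketch assuming pytest-asyncio is installed; load_template below is a stand-in, not vLLM code):

import asyncio

import pytest


async def load_template(template: str) -> str:
    # Stand-in for an async setup step, e.g. waiting for a tokenizer to load.
    await asyncio.sleep(0)
    return template


@pytest.mark.asyncio
async def test_load_template():
    # pytest-asyncio drives the event loop, so the test can await directly.
    assert await load_template("{{ messages }}") == "{{ messages }}"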
tests/entrypoints/openai/test_serving_chat.py (new file): 37 additions
@@ -0,0 +1,37 @@
+import asyncio
+from dataclasses import dataclass
+
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+
+MODEL_NAME = "openai-community/gpt2"
+CHAT_TEMPLATE = "Dummy chat template for testing {}"
+
+
+@dataclass
+class MockModelConfig:
+    tokenizer = MODEL_NAME
+    trust_remote_code = False
+    tokenizer_mode = "auto"
+    max_model_len = 100
+    tokenizer_revision = None
+
+
+@dataclass
+class MockEngine:
+
+    async def get_model_config(self):
+        return MockModelConfig
+
+
+async def _async_serving_chat_init():
+    serving_completion = OpenAIServingChat(MockEngine(),
+                                           served_model_names=[MODEL_NAME],
+                                           response_role="assistant",
+                                           chat_template=CHAT_TEMPLATE)
+    return serving_completion
+
+
+def test_async_serving_chat_init():
+    serving_completion = asyncio.run(_async_serving_chat_init())
+    assert serving_completion.tokenizer is not None
+    assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
tests/entrypoints/test_openai_server.py: 2 changes (1 addition, 1 deletion)
@@ -150,7 +150,7 @@ def server(zephyr_lora_files):
     ray.shutdown()
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture(scope="module")
 def client():
     client = openai.AsyncOpenAI(
         base_url="http://localhost:8000/v1",
vllm/entrypoints/openai/serving_chat.py: 12 changes (9 additions, 3 deletions)
@@ -1,3 +1,4 @@
+import asyncio
 import codecs
 import time
 from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
@@ -40,9 +41,11 @@ def __init__(self,
                  chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         await_post_init=self._load_chat_template(
+                             chat_template=chat_template))
 
         self.response_role = response_role
-        self._load_chat_template(chat_template)
 
     def _parse_chat_message_content(
         self,
@@ -356,7 +359,10 @@ async def chat_completion_full_generator(
 
         return response
 
-    def _load_chat_template(self, chat_template: Optional[str]):
+    async def _load_chat_template(self, chat_template: Optional[str]):
+        while self.tokenizer is None:
+            # Give the parent class time to load the tokenizer
+            await asyncio.sleep(0.1)
         tokenizer = self.tokenizer
 
         if chat_template is not None:
vllm/entrypoints/openai/serving_engine.py: 18 changes (12 additions, 6 deletions)
@@ -2,7 +2,7 @@
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 
 from pydantic import Field
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -29,8 +29,11 @@ class LoRAModulePath:
 
 class OpenAIServing:
 
-    def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]],
+                 await_post_init: Optional[Awaitable[Any]] = None):
         self.engine = engine
         self.served_model_names = served_model_names
         if lora_modules is None:
@@ -56,12 +59,12 @@ def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
         if event_loop is not None and event_loop.is_running():
             # If the current is instanced by Ray Serve,
             # there is already a running event loop
-            event_loop.create_task(self._post_init())
+            event_loop.create_task(self._post_init(await_post_init))
         else:
             # When using single vLLM without engine_use_ray
-            asyncio.run(self._post_init())
+            asyncio.run(self._post_init(await_post_init))
 
-    async def _post_init(self):
+    async def _post_init(self, await_post_init):
         engine_model_config = await self.engine.get_model_config()
         self.max_model_len = engine_model_config.max_model_len
 
@@ -73,6 +76,9 @@ async def _post_init(self):
             trust_remote_code=engine_model_config.trust_remote_code,
             truncation_side="left")
 
+        if await_post_init is not None:
+            await await_post_init
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
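Taken together, the two server-side files show the shape of the fix: OpenAIServingChat hands its now-async _load_chat_template coroutine to the base class as await_post_init, and OpenAIServing awaits that hook inside _post_init only after the tokenizer has been created. Below is a minimal, self-contained sketch of the same pattern; the class and attribute names (BaseServing, ChatServing) are illustrative stand-ins, not vLLM's actual API.

import asyncio
from typing import Any, Awaitable, Optional


class BaseServing:
    """Stand-in for OpenAIServing: runs async setup from a synchronous __init__."""

    def __init__(self, await_post_init: Optional[Awaitable[Any]] = None):
        self.tokenizer = None
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None
        if loop is not None and loop.is_running():
            # Already inside an event loop (e.g. Ray Serve): just schedule setup.
            loop.create_task(self._post_init(await_post_init))
        else:
            # Plain synchronous construction: drive setup to completion now.
            asyncio.run(self._post_init(await_post_init))

    async def _post_init(self, await_post_init):
        await asyncio.sleep(0)      # stand-in for fetching the model config
        self.tokenizer = object()   # stand-in for loading the tokenizer
        if await_post_init is not None:
            # The subclass hook runs only after the tokenizer exists.
            await await_post_init


class ChatServing(BaseServing):
    """Stand-in for OpenAIServingChat."""

    def __init__(self, chat_template: Optional[str] = None):
        # Pass the coroutine object itself; the base class awaits it later.
        super().__init__(
            await_post_init=self._load_chat_template(chat_template))

    async def _load_chat_template(self, chat_template):
        while self.tokenizer is None:
            # Give the base class time to finish loading the tokenizer.
            await asyncio.sleep(0.1)
        self.chat_template = chat_template


if __name__ == "__main__":
    serving = ChatServing(chat_template="{{ messages }}")
    print(serving.chat_template)

The 0.1-second polling loop mirrors the real change: when construction happens inside an already-running loop, the base class only schedules _post_init as a task, so the chat-template hook has to wait until the tokenizer actually appears.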
