[Core] Consolidate prompt arguments to LLM engines (vllm-project#4328)
Co-authored-by: Roger Wang <ywang@roblox.com>
2 people authored and blinkbear committed Jun 6, 2024
1 parent ac509cb commit b18e919
Showing 12 changed files with 41 additions and 176 deletions.
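The consolidation shown in the diff below replaces the parallel prompts, prompt_token_ids, and multi_modal_data arguments with a single prompt-inputs structure per request. A minimal, text-only sketch of the consolidated call, assuming the dict keys ("prompt", "multi_modal_data") and the positional generate() call used in tests/conftest.py below; the model name is illustrative only, not part of this commit:

from vllm import LLM, SamplingParams

# Hypothetical model choice for illustration.
llm = LLM(model="facebook/opt-125m")

# One structure per request instead of separate parallel arguments.
prompt_inputs = [
    {"prompt": "Hello, my name is", "multi_modal_data": None},
    {"prompt": "The capital of France is", "multi_modal_data": None},
]

outputs = llm.generate(prompt_inputs,
                       sampling_params=SamplingParams(temperature=0.0))
for out in outputs:
    print(out.prompt, out.outputs[0].text)
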
25 changes: 13 additions & 12 deletions tests/conftest.py
@@ -13,7 +13,7 @@
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
from vllm.inputs import TextPrompt
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.sequence import MultiModalData, SampleLogprobs

@@ -395,23 +395,24 @@ def generate(
images: Optional[torch.Tensor] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
assert len(prompts) == images.shape[0]

prompt_inputs: List[TextPrompt] = []
prompt_inputs: List[PromptInputs] = []
for i, prompt in enumerate(prompts):
prompt = TextPrompt(prompt=prompt)
if images is not None:
prompt["multi_modal_data"] = MultiModalData(
type=MultiModalData.Type.IMAGE,
data=images[i:i + 1],
)
image = None if images is None else images[i:i + 1]
mm_data = None if image is None else MultiModalData(
type=MultiModalData.Type.IMAGE,
data=image,
)

prompt_inputs.append(prompt)
prompt_inputs.append({
"prompt": prompt,
"multi_modal_data": mm_data,
})

req_outputs = self.model.generate(prompt_inputs,
sampling_params=sampling_params)

outputs: List[Tuple[List[List[int]], List[str]]] = []
outputs = []
for req_output in req_outputs:
prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids
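For the multimodal path, the updated helper above pairs each prompt with an optional MultiModalData slice instead of mutating a prompt object in place. A hedged sketch of that pattern as a standalone function (the function name is mine; the keys and the MultiModalData constructor follow the hunk above):

from typing import List, Optional

import torch

from vllm.sequence import MultiModalData

def build_prompt_inputs(prompts: List[str],
                        images: Optional[torch.Tensor] = None) -> List[dict]:
    # Mirrors the tests/conftest.py helper above: one dict per prompt, with
    # the image slice wrapped in MultiModalData when images are provided.
    prompt_inputs = []
    for i, prompt in enumerate(prompts):
        image = None if images is None else images[i:i + 1]
        mm_data = None if image is None else MultiModalData(
            type=MultiModalData.Type.IMAGE,
            data=image,
        )
        prompt_inputs.append({
            "prompt": prompt,
            "multi_modal_data": mm_data,
        })
    return prompt_inputs
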
2 changes: 2 additions & 0 deletions tests/core/test_block_manager.py
@@ -234,6 +234,7 @@ def test_append_slot_cow():
inputs={
"prompt": "one two three",
"prompt_token_ids": [1, 2, 3],
"multi_modal_data": None
},
block_size=block_size)

@@ -524,6 +525,7 @@ def test_sliding_window_multi_seq():
inputs={
"prompt": "one two three",
"prompt_token_ids": [0, 1, 2],
"multi_modal_data": None
},
block_size=block_size)
seq_group = SequenceGroup(request_id="1",
7 changes: 6 additions & 1 deletion tests/core/utils.py
@@ -25,6 +25,7 @@ def create_dummy_prompt(
inputs={
"prompt": prompt_str,
"prompt_token_ids": prompt_tokens,
"multi_modal_data": None,
},
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
@@ -102,7 +103,11 @@ def create_seq_group(
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
inputs={"prompt_token_ids": prompt_token_ids},
inputs={
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16,
)

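With multi_modal_data now part of the inputs dict, the test helpers above spell out every key explicitly rather than relying on defaults. A minimal sketch of that construction pattern, reusing the keys from the hunks above (the seq_id, token IDs, and block size are arbitrary test values):

from vllm.sequence import Sequence

# All three LLMInputs keys are supplied explicitly; multi_modal_data stays
# None for text-only tests, matching the updated helpers above.
seq = Sequence(
    seq_id=0,
    inputs={
        "prompt": "one two three",
        "prompt_token_ids": [1, 2, 3],
        "multi_modal_data": None,
    },
    block_size=16,
)

assert seq.prompt == "one two three"
assert seq.multi_modal_data is None
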
94 changes: 3 additions & 91 deletions tests/entrypoints/test_openai_server.py
@@ -184,26 +184,6 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
completion.choices[0].text) >= 5


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=None,
)
choice = completion.choices[0]
assert choice.logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
@@ -227,70 +207,6 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
assert len(choice.logprobs.top_logprobs[0]) <= 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=5,
)
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is not None
assert len(choice.logprobs.top_logprobs[0]) <= 6


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):

with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=6,
)
...
with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
stream = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=6,
stream=True,
)
async for chunk in stream:
...

# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
@@ -339,13 +255,9 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
1 change: 1 addition & 0 deletions tests/test_cache_block_hashing.py
@@ -74,6 +74,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
inputs={
"prompt": prompt,
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=block_size,
eos_token_id=tokenizer.tokenizer.eos_token_id,
57 changes: 1 addition & 56 deletions tests/test_utils.py
@@ -1,64 +1,9 @@
import asyncio
import sys
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
Tuple, TypeVar)

import pytest

from vllm.utils import deprecate_kwargs, merge_async_iterators
from vllm.utils import deprecate_kwargs

from .utils import error_on_warning

if sys.version_info < (3, 10):
if TYPE_CHECKING:
_AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
_AwaitableT_co = TypeVar("_AwaitableT_co",
bound=Awaitable[Any],
covariant=True)

class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):

def __anext__(self) -> _AwaitableT_co:
...

def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT":
return i.__anext__()


@pytest.mark.asyncio
async def test_merge_async_iterators():

async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
try:
while True:
yield f"item from iterator {idx}"
await asyncio.sleep(0.1)
except asyncio.CancelledError:
pass

iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
*iterators)

async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
async for idx, output in generator:
print(f"idx: {idx}, output: {output}")

task = asyncio.create_task(stream_output(merged_iterator))
await asyncio.sleep(0.5)
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task

for iterator in iterators:
try:
await asyncio.wait_for(anext(iterator), 1)
except StopAsyncIteration:
# All iterators should be cancelled and print this message.
print("Iterator was cancelled normally")
except (Exception, asyncio.CancelledError) as e:
raise AssertionError() from e


def test_deprecate_kwargs_always():

1 change: 1 addition & 0 deletions tests/tokenization/test_detokenize.py
@@ -126,6 +126,7 @@ def create_sequence(prompt_token_ids=None):
inputs={
"prompt": "<s>",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16,
)
2 changes: 2 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -307,6 +307,8 @@ class AsyncLLMEngine:
generate method when there are requests in the waiting queue. The generate
method yields the outputs from the :class:`LLMEngine` to the caller.
NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
4 changes: 2 additions & 2 deletions vllm/engine/llm_engine.py
@@ -71,8 +71,8 @@ class LLMEngine:
The :class:`~vllm.LLM` class wraps this class for offline batched inference
and the :class:`AsyncLLMEngine` class wraps this class for online serving.
The config arguments are derived from :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs`
class. For the comprehensive list of arguments, see :ref:`engine_args`.
Args:
model_config: The configuration related to the LLM model.
16 changes: 6 additions & 10 deletions vllm/entrypoints/llm.py
@@ -30,6 +30,12 @@ class LLM:
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see
:class:`~vllm.EngineArgs`.
Args:
model: The name or path of a HuggingFace Transformers model.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
@@ -276,11 +282,6 @@ def generate(
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if self.llm_engine.model_config.embedding_mode:
raise ValueError(
"LLM.generate() is only supported for generation models "
"(XForCausalLM).")

if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
@@ -425,11 +426,6 @@ def encode(
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if not self.llm_engine.model_config.embedding_mode:
raise ValueError(
"LLM.encode() is only supported for embedding models (XModel)."
)

if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
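The generate() and encode() docstrings above keep the old prompts, prompt_token_ids, and multi_modal_data keyword arguments working by routing them through _convert_v1_inputs, while flagging them as legacy. A hedged comparison of the two calling styles (model name, token IDs, and sampling settings are illustrative; the consolidated form follows tests/conftest.py):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # hypothetical model choice
params = SamplingParams(temperature=0.0, max_tokens=16)

# Legacy style: token IDs as a separate keyword argument. Per the diff above
# this is converted internally and may be deprecated in the future.
legacy_outputs = llm.generate(prompt_token_ids=[[1, 2, 3, 4]],
                              sampling_params=params)

# Consolidated style: one inputs structure per request.
new_outputs = llm.generate([{"prompt": "Hello, world",
                             "multi_modal_data": None}],
                           sampling_params=params)
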
4 changes: 2 additions & 2 deletions vllm/inputs.py
@@ -126,5 +126,5 @@ class TextTokensPrompt(TypedDict):

class LLMInputs(TypedDict):
prompt_token_ids: List[int]
prompt: NotRequired[Optional[str]]
multi_modal_data: NotRequired[Optional["MultiModalData"]]
prompt: Optional[str]
multi_modal_data: Optional["MultiModalData"]
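Dropping NotRequired makes every LLMInputs key mandatory, which is why the Sequence properties in the next hunk can index the dict directly instead of falling back to .get(). A sketch of the resulting shape (field names and types copied from the two lines above; the MultiModalData import is only needed for type checking):

from typing import TYPE_CHECKING, List, Optional, TypedDict

if TYPE_CHECKING:
    from vllm.sequence import MultiModalData

class LLMInputs(TypedDict):
    # Every key must now be present; None remains a valid value for
    # prompt and multi_modal_data.
    prompt_token_ids: List[int]
    prompt: Optional[str]
    multi_modal_data: Optional["MultiModalData"]

inputs: LLMInputs = {
    "prompt_token_ids": [1, 2, 3],
    "prompt": "one two three",
    "multi_modal_data": None,
}
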
4 changes: 2 additions & 2 deletions vllm/sequence.py
@@ -255,15 +255,15 @@ def __init__(

@property
def prompt(self) -> Optional[str]:
return self.inputs.get("prompt")
return self.inputs["prompt"]

@property
def prompt_token_ids(self) -> List[int]:
return self.inputs["prompt_token_ids"]

@property
def multi_modal_data(self) -> Optional["MultiModalData"]:
return self.inputs.get("multi_modal_data")
return self.inputs["multi_modal_data"]

@property
def lora_int_id(self) -> int:
