Skip to content

Commit

Permalink
[Core] Consolidate prompt arguments to LLM engines (vllm-project#4328)
Browse files Browse the repository at this point in the history
Co-authored-by: Roger Wang <ywang@roblox.com>
  • Loading branch information
DarkLight1337 and ywang96 authored May 28, 2024
1 parent 290f4ad commit 5ae5ed1
Show file tree
Hide file tree
Showing 43 changed files with 1,404 additions and 439 deletions.
9 changes: 6 additions & 3 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ steps:
mirror_hardwares: [amd]

commands:
# these tests have to be separated, because each one will allocate all posible GPU memory
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
- pytest -v -s entrypoints/test_server_oot_registration.py
- pytest -v -s test_inputs.py
- pytest -v -s entrypoints -m llm
- pytest -v -s entrypoints -m openai

- label: Examples Test
working_dir: "/vllm-workspace/examples"
Expand Down Expand Up @@ -110,6 +110,9 @@ steps:
mirror_hardwares: [amd]
command: pytest -v -s test_logits_processor.py

- label: Utils Test
command: pytest -v -s test_utils.py

- label: Worker Test
mirror_hardwares: [amd]
command: pytest -v -s worker
Expand Down
11 changes: 7 additions & 4 deletions benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import json
import time
from pathlib import Path
from typing import Optional
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS


Expand Down Expand Up @@ -48,7 +49,9 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
dummy_inputs: List[PromptStrictInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
Expand All @@ -59,13 +62,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
LLM Class
==========
=========

.. autoclass:: vllm.LLM
:members:
Expand Down
14 changes: 14 additions & 0 deletions docs/source/dev/offline_inference/llm_inputs.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
LLM Inputs
==========

.. autodata:: vllm.inputs.PromptStrictInputs

.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
:members:
:member-order: bysource

.. autoclass:: vllm.inputs.TokensPrompt
:show-inheritance:
:members:
:member-order: bysource
8 changes: 8 additions & 0 deletions docs/source/dev/offline_inference/offline_index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Offline Inference
=================================

.. toctree::
:maxdepth: 1

llm
llm_inputs
File renamed without changes.
11 changes: 3 additions & 8 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,6 @@ Documentation
getting_started/quickstart
getting_started/examples/examples_index

.. toctree::
:maxdepth: 1
:caption: Offline Inference

offline_inference/llm
offline_inference/sampling_params

.. toctree::
:maxdepth: 1
:caption: Serving
Expand Down Expand Up @@ -108,7 +101,9 @@ Documentation
.. toctree::
:maxdepth: 2
:caption: Developer Documentation


dev/sampling_params
dev/offline_inference/offline_index
dev/engine/engine_index
dev/kernel/paged_attention
dev/dockerfile/dockerfile
Expand Down
4 changes: 2 additions & 2 deletions docs/source/serving/openai_compatible_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ completion = client.chat.completions.create(
```

### Extra Parameters for Chat API
The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.

```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
Expand All @@ -65,7 +65,7 @@ The following extra parameters are supported:
```

### Extra Parameters for Completions API
The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.

```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
Expand Down
25 changes: 16 additions & 9 deletions examples/llava_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@ def run_llava_pixel_values():
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_pixel_values.pt")
image = torch.load("images/stop_sign_pixel_values.pt")

outputs = llm.generate({
"prompt":
prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
})

outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
Expand All @@ -46,11 +50,14 @@ def run_llava_image_features():
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_image_features.pt")

outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
image = torch.load("images/stop_sign_image_features.pt")

outputs = llm.generate({
"prompt":
prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
[tool.isort]
use_parentheses = true
skip_gitignore = true

[tool.pytest.ini_options]
markers = [
"skip_global_cleanup",
"llm: run tests for vLLM API only",
"openai: run tests for OpenAI API only",
]
2 changes: 1 addition & 1 deletion tests/async_engine/test_async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def step_async(self):
return [RequestOutput(
request_id=self.request_id)] if self.request_id else []

async def encode_request_async(self, *args, **kwargs):
async def process_model_inputs_async(self, *args, **kwargs):
pass

def generate(self, request_id):
Expand Down
2 changes: 1 addition & 1 deletion tests/async_engine/test_openapi_server_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def server():
ray.shutdown()


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def client():
client = openai.AsyncOpenAI(
base_url="http://localhost:8000/v1",
Expand Down
23 changes: 17 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.sequence import MultiModalData

Expand Down Expand Up @@ -402,12 +403,22 @@ def generate(
) -> List[Tuple[List[int], str]]:
if images is not None:
assert len(prompts) == images.shape[0]
req_outputs = self.model.generate(
prompts,
sampling_params=sampling_params,
multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
data=images)
if images is not None else None)

prompt_inputs: List[PromptInputs] = []
for i, prompt in enumerate(prompts):
image = None if images is None else images[i:i + 1]
mm_data = None if image is None else MultiModalData(
type=MultiModalData.Type.IMAGE,
data=image,
)

prompt_inputs.append({
"prompt": prompt,
"multi_modal_data": mm_data,
})

req_outputs = self.model.generate(prompt_inputs,
sampling_params=sampling_params)
outputs = []
for req_output in req_outputs:
prompt_str = req_output.prompt
Expand Down
15 changes: 12 additions & 3 deletions tests/core/test_block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,11 @@ def test_append_slot_cow():

# Allocate prompt to gpu block. There is one slot left in the block.
prompt = Sequence(seq_id=1,
prompt="one two three",
prompt_token_ids=[1, 2, 3],
inputs={
"prompt": "one two three",
"prompt_token_ids": [1, 2, 3],
"multi_modal_data": None
},
block_size=block_size)

# Fork the sequence, such that a COW will be required when we append a new
Expand Down Expand Up @@ -304,7 +307,13 @@ def test_sliding_window_multi_seq():

assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks

parent = Sequence(1, "one two three", [0, 1, 2], block_size)
parent = Sequence(seq_id=1,
inputs={
"prompt": "one two three",
"prompt_token_ids": [0, 1, 2],
"multi_modal_data": None
},
block_size=block_size)
seq_group = SequenceGroup(request_id="1",
seqs=[parent],
arrival_time=time.time(),
Expand Down
15 changes: 12 additions & 3 deletions tests/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ def create_dummy_prompt(
# and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
prompt = Sequence(int(request_id),
inputs={
"prompt": prompt_str,
"prompt_token_ids": prompt_tokens,
"multi_modal_data": None,
},
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[prompt],
arrival_time=time.time(),
Expand Down Expand Up @@ -51,8 +57,11 @@ def create_seq_group(
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
prompt="",
prompt_token_ids=prompt_token_ids,
inputs={
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/engine/test_skip_tokenizer_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str):
with pytest.raises(ValueError) as err:
llm.generate("abc", sampling_params)
assert "prompts must be None if" in str(err.value)
outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
sampling_params=sampling_params)
assert len(outputs) > 0
completions = outputs[0].outputs
Expand Down
4 changes: 4 additions & 0 deletions tests/entrypoints/openai/test_serving_chat.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import asyncio
from dataclasses import dataclass

import pytest

from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"

pytestmark = pytest.mark.openai


@dataclass
class MockModelConfig:
Expand Down
2 changes: 2 additions & 0 deletions tests/entrypoints/test_guided_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

pytestmark = pytest.mark.openai


def test_guided_logits_processors():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
Expand Down
Loading

0 comments on commit 5ae5ed1

Please sign in to comment.