
add llava model support #2153

Closed
wants to merge 6 commits
8 changes: 8 additions & 0 deletions Dockerfile
@@ -71,6 +71,14 @@ COPY vllm vllm
EXPOSE 8000
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]


FROM vllm-base AS vllm-llava

COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.llava_server"]

# openai api server alternative
FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
1 change: 1 addition & 0 deletions requirements-rocm.txt
@@ -13,3 +13,4 @@ fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.
aioprometheus[starlette]
pillow # Required for image processing.
1 change: 1 addition & 0 deletions requirements.txt
@@ -12,3 +12,4 @@ fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.
aioprometheus[starlette]
pillow # Required for image processing.
2 changes: 2 additions & 0 deletions vllm/__init__.py
@@ -5,13 +5,15 @@
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster
from vllm.entrypoints.llm import LLM
from vllm.entrypoints.llava_llm import LLaVA
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

__version__ = "0.2.6"

__all__ = [
"LLM",
"LLaVA",
"SamplingParams",
"RequestOutput",
"CompletionOutput",
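For orientation, here is a minimal offline-inference sketch of how the newly exported LLaVA class might be used. The vllm.entrypoints.llava_llm module is not part of this excerpt, so the constructor argument, the prompt format, and the images keyword below are assumptions modeled on the existing LLM entrypoint and the engine's add_request signature, not a confirmed API.

# Hypothetical usage sketch; assumes LLaVA mirrors vllm.entrypoints.llm.LLM
# and forwards an `images` list to the engine's add_request. The model name
# and the <image> placeholder are illustrative, not confirmed by this PR.
from PIL import Image

from vllm import LLaVA, SamplingParams

llm = LLaVA(model="llava-hf/llava-1.5-7b-hf")
image = Image.open("example.jpg")
params = SamplingParams(temperature=0.2, max_tokens=128)

outputs = llm.generate("USER: <image>\nDescribe this picture. ASSISTANT:",
                       params,
                       images=[image])
print(outputs[0].outputs[0].text)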
139 changes: 139 additions & 0 deletions vllm/engine/async_llava_engine.py
@@ -0,0 +1,139 @@
from vllm.engine.llava_engine import LLaVAEngine
from vllm.engine.async_llm_engine import AsyncLLMEngine, _AsyncLLMEngine, AsyncStream, AsyncEngineDeadError
import asyncio
import time
from typing import (List, Optional, Type, AsyncIterator)
from PIL import Image
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

logger = init_logger(__name__)


class _AsyncLLaVAEngine(LLaVAEngine, _AsyncLLMEngine):

async def step_async(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copy. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.

        This override exists to pass runner_method down to the model runner so
        that it knows it is serving a LLaVA model. It will no longer be needed
        once execute_llava_model is merged into execute_model.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

# Execute the model.
output = (await self._run_workers_async(
"execute_model",
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
runner_method="execute_llava_model",
)) if not scheduler_outputs.is_empty() else []

return self._process_model_outputs(output, scheduler_outputs)


class AsyncLLaVAEngine(AsyncLLMEngine):

_engine_class: Type[_AsyncLLaVAEngine] = _AsyncLLaVAEngine

async def add_request(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
images: Optional[List[Image.Image]] = None) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
shortened_token_ids = prompt_token_ids
if self.max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:self.max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self.
max_log_len]
logger.info(f"Received request {request_id}: "
f"prompt: {shortened_prompt!r}, "
f"sampling params: {sampling_params}, "
f"prompt token ids: {shortened_token_ids}."
f"images: {0 if images is None else len(images)}")

if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
else:
raise AsyncEngineDeadError(
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")

stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
images=images)

return stream

async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
images: Optional[List[Image.Image]] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.

Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.

Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
images: A list of PIL images for the prompt. It supports multiple
images, although most llava models are trained with only one image.

Yields:
The output `RequestOutput` objects from the LLMEngine for the
request.
"""
# Preprocess the request.
# This should not be used for logging, as it is monotonic time.
arrival_time = time.monotonic()

try:
stream = await self.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
images=images)

async for request_output in stream:
yield request_output
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the
# request.
self._abort(request_id)
raise e
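To illustrate the streaming path above, here is a minimal sketch of driving AsyncLLaVAEngine.generate directly. It assumes the engine can be built with the from_engine_args helper inherited from AsyncLLMEngine; the model path and the <image> placeholder in the prompt are illustrative.

# Minimal streaming sketch for AsyncLLaVAEngine.generate (assumes construction
# via the inherited from_engine_args; model path and prompt are illustrative).
import asyncio
import uuid

from PIL import Image

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llava_engine import AsyncLLaVAEngine
from vllm.sampling_params import SamplingParams


async def main() -> None:
    engine = AsyncLLaVAEngine.from_engine_args(
        AsyncEngineArgs(model="llava-hf/llava-1.5-7b-hf"))
    params = SamplingParams(temperature=0.2, max_tokens=128)
    image = Image.open("example.jpg")

    final = None
    # generate() yields incremental RequestOutput objects for this request.
    async for output in engine.generate(
            "USER: <image>\nDescribe this picture. ASSISTANT:",
            params,
            request_id=str(uuid.uuid4()),
            images=[image]):
        final = output
    print(final.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())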
118 changes: 118 additions & 0 deletions vllm/engine/llava_engine.py
@@ -0,0 +1,118 @@
from vllm.engine.llm_engine import LLMEngine
from transformers import CLIPImageProcessor
import time
from functools import partial
from typing import List, Optional

from vllm.engine.ray_utils import ray
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup)
from PIL import Image
import numpy as np


class LLaVAEngine(LLMEngine):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.image_processor = CLIPImageProcessor.from_pretrained(
self.model_config.tokenizer)

def add_request(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
images: Optional[List[Image.Image]] = None,
) -> None:
"""Add a request to the engine's request pool.

The request is added to the request pool and will be processed by the
scheduler as `engine.step()` is called. The exact scheduling policy is
determined by the scheduler.

Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
images: A list of PIL images for the prompt. It supports multiple
images, although most llava models are trained with only one image.
"""
if arrival_time is None:
arrival_time = time.monotonic()
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(prompt)

# process images
extra_data = None
if images is not None and len(images) > 0:
pixel_values = self.image_processor(
images, return_tensors="pt")['pixel_values']
extra_data = {'pixel_values': pixel_values}
else:
pixel_values = None

        # Validate the input and expand each image token to the number of
        # tokens per image, so the scheduler can allocate the proper resources.
num_workers = len(self.workers)
        # Randomly select a worker.
worker = self.workers[np.random.randint(num_workers)]
if self.parallel_config.worker_use_ray:
execute_model_methord = partial(worker.execute_method.remote,
'execute_model_methord')
else:
execute_model_methord = worker.execute_model_methord
outputs = execute_model_methord('prepare_promt', prompt_token_ids,
pixel_values)
if self.parallel_config.worker_use_ray:
outputs = ray.get(outputs)
processed_token_ids = outputs
prompt_token_ids = processed_token_ids.tolist()

# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
seq = Sequence(seq_id,
prompt,
prompt_token_ids,
block_size,
extra_data=extra_data)

# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time)

# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)

def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.

This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copy. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

# Execute the model.
output = self._run_workers(
"execute_model",
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
runner_method="execute_llava_model",
) if not scheduler_outputs.is_empty() else []

return self._process_model_outputs(output, scheduler_outputs)
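For reference, the image preprocessing that add_request performs can be reproduced standalone with the same transformers call; the engine loads the processor from self.model_config.tokenizer, whereas the checkpoint name below is only illustrative.

# Mirrors the preprocessing in LLaVAEngine.add_request: PIL images in, a
# batched pixel_values tensor out (the checkpoint name is illustrative).
from PIL import Image
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
images = [Image.open("example.jpg")]

pixel_values = processor(images, return_tensors="pt")["pixel_values"]
# Shape is (num_images, 3, height, width), e.g. (1, 3, 224, 224) for this
# checkpoint; the engine attaches this tensor to the Sequence via extra_data.
print(pixel_values.shape)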