
Add LLaVA support #775

Closed · wants to merge 2 commits
2 changes: 2 additions & 0 deletions vllm/__init__.py
@@ -7,11 +7,13 @@
from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.entrypoints.mllm import MLLM

__version__ = "0.1.3"

__all__ = [
"LLM",
"MLLM",
"SamplingParams",
"RequestOutput",
"CompletionOutput",
176 changes: 176 additions & 0 deletions vllm/engine/mllm_engine.py
@@ -0,0 +1,176 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import base64
import time
from io import BytesIO
import requests
from PIL import Image

from vllm import LLMEngine, SamplingParams
from typing import List, Optional
from vllm.logger import init_logger
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.core.scheduler import Scheduler
from vllm.engine.ray_utils import DeviceID, ray
from vllm.sequence import Sequence, SequenceGroup
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import Counter
from vllm.worker.worker import MWorker

logger = init_logger(__name__)

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


class MLLMEngine(LLMEngine):
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
distributed_init_method: str,
stage_devices: List[List[DeviceID]],
log_stats: bool,
) -> None:
logger.info(
"Initializing an LLM engine with config: "
f"model={model_config.model!r}, "
f"tokenizer={model_config.tokenizer!r}, "
f"tokenizer_mode={model_config.tokenizer_mode}, "
f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, "
f"use_dummy_weights={model_config.use_dummy_weights}, "
f"download_dir={model_config.download_dir!r}, "
f"use_np_weights={model_config.use_np_weights}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.

self.model_config = model_config
self.cache_config = cache_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
self._verify_args()

self.tokenizer = get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code)
self.seq_counter = Counter()

# Create the parallel GPU workers.
self.workers: List[MWorker] = []
assert len(stage_devices) == 1, "Only support one stage for now."
for rank, node_resource, _ in stage_devices[0]:
worker_cls = MWorker
if self.parallel_config.worker_use_ray:
worker_cls = ray.remote(
num_cpus=0,
num_gpus=1,
resources={node_resource: 1e-3},
)(worker_cls).remote

worker = worker_cls(
model_config,
parallel_config,
scheduler_config,
rank,
distributed_init_method,
)
self.workers.append(worker)
# Profile the memory usage and initialize the cache.
self._init_cache()

# Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)

def add_request(
self,
request_id: str,
prompt: Optional[str],
image: Optional[dict] = None,
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None
) -> None:
"""Add a request to the engine's request pool.

The request is added to the request pool and will be processed by the
scheduler as `engine.step()` is called. The exact scheduling policy is
determined by the scheduler.

Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
image: An optional image source dict (or list of dicts), each with
"image_src" and "src_type" ("url", "local", or "base64").
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current time.
"""

mm_use_im_start_end = self.workers[0].model.model.vision_tower[0].config.use_im_start_end
image_size = self.workers[0].model.model.vision_tower[0].config.image_size
patch_size = self.workers[0].model.model.vision_tower[0].config.patch_size
image_token_len = int((image_size / patch_size) ** 2)

if arrival_time is None:
arrival_time = time.time()
if image:
if mm_use_im_start_end:
image_tokens = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
else:
image_tokens = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
prompt += image_tokens
image_data = self._load_image(image)
else:
image_data = None
prompt_token_ids = self.tokenizer.encode(prompt)

# Create the sequences.
block_size = self.cache_config.block_size
seqs: List[Sequence] = []
for _ in range(sampling_params.best_of):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, image_data=image_data)
seqs.append(seq)

# Create the sequence group.
seq_group = SequenceGroup(request_id, seqs, sampling_params,
arrival_time)

# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)

def _load_image(self, image_srcs):
images = []
image_srcs = image_srcs if isinstance(image_srcs, list) else [image_srcs]
for image_src_i in image_srcs:
image_file = image_src_i.get("image_src")
src_type = image_src_i.get("src_type")

if src_type == "url":
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
elif src_type == "local":
image = Image.open(image_file).convert('RGB')
elif src_type == "base64":
image = Image.open(BytesIO(base64.b64decode(image_file))).convert('RGB')
else:
raise ValueError(f"Unsupported src_type: {src_type!r}; expected 'url', 'local', or 'base64'.")
image_tensor = self.workers[0].model.model.image_processor(image, return_tensors='pt')['pixel_values'][0]
images.append(image_tensor.half().cuda())
return images

def initialize_vision_tokenizer(self):
self._run_workers(
"initialize_vision_tokenizer",
tokenizer=self.tokenizer
)
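
For reference, _load_image above accepts one source dict per image (or a list of them), keyed by "image_src" and "src_type". A minimal sketch of the three accepted shapes; the URL and file path below are placeholders, not files from this PR:

# Hypothetical image source dicts understood by MLLMEngine._load_image.
url_image = {"image_src": "https://example.com/cat.jpg", "src_type": "url"}
local_image = {"image_src": "/data/images/cat.jpg", "src_type": "local"}
base64_image = {"image_src": "<base64-encoded image bytes>", "src_type": "base64"}

Any other src_type value is rejected, and an empty dict is treated as "no image" by add_request.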
159 changes: 159 additions & 0 deletions vllm/entrypoints/mllm.py
@@ -0,0 +1,159 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
from typing import Optional, Union, List

from vllm import LLM, SamplingParams, RequestOutput, EngineArgs
from tqdm import tqdm

from vllm.engine.mllm_engine import MLLMEngine, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from vllm.utils import Counter


class MLLM(LLM):
def __init__(
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
seed: int = 0,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
CLIP_MODEL_MAP = {"openai/clip-vit-large-patch14": f"{os.path.abspath(model)}/clip-vit-large-patch14"}

engine_args = EngineArgs(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
seed=seed,
**kwargs,
)
self.mllm_engine = MLLMEngine.from_engine_args(engine_args)
self.mllm_engine.initialize_vision_tokenizer()
self.request_counter = Counter()

def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
images: Optional[Union[dict, List[dict]]] = None,
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.

NOTE: This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.

Args:
prompts: A list of prompts to generate completions for.
images: A list of image source dicts, one per prompt, each with
"image_src" and "src_type" ("url", "local", or "base64"). Use an
empty dict ({}) as a placeholder for prompts without an image.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.

Returns:
A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts.
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")

if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if isinstance(images, dict):
# Convert a single image source dict to a list.
images = [images]

assert prompts is not None and images is not None and \
len(prompts) == len(images), (
"The number of images must match the number of prompts: got "
f"{len(images) if images is not None else 0} images for "
f"{len(prompts) if prompts is not None else 0} prompts. If a "
"prompt has no image, pass an empty dict ({}) as a placeholder.")
if prompts is not None and prompt_token_ids is not None:
if len(prompts) != len(prompt_token_ids):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()

# Add requests to the engine.
if prompts is not None:
num_requests = len(prompts)
else:
num_requests = len(prompt_token_ids)

for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
if prompt_token_ids is None:
token_ids = None
else:
token_ids = prompt_token_ids[i]
image = images[i]
self._add_request(prompt, sampling_params, token_ids, image)

result = self._run_engine(use_tqdm)
return result

def _prompt_image(self, prompts: List[str], images: List[dict], is_only_prompts=False) -> Union[List[str], None]:
assert len(prompts) == len(images), (
"The number of images must match the number of prompts: got "
f"{len(images)} images for {len(prompts)} prompts. If a prompt has "
"no image, pass an empty dict ({}) as a placeholder.")
if is_only_prompts:
results = []
for prompt, image in zip(prompts, images):
if image:
img_data = image.get("image_src")
img_type = image.get("src_type")
img_prompt = " ".join(
[DEFAULT_IM_START_TOKEN, img_type, DEFAULT_IMAGE_PATCH_TOKEN, img_data, DEFAULT_IM_END_TOKEN])
prompt += img_prompt
results.append(prompt)
return results

def _add_request(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
image: Optional[dict] = None
) -> None:
request_id = str(next(self.request_counter))

self.mllm_engine.add_request(request_id, prompt, image, sampling_params,
prompt_token_ids)


def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.mllm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests, desc="Processed prompts")
# Run the engine.
outputs: List[RequestOutput] = []
while self.mllm_engine.has_unfinished_requests():
step_outputs = self.mllm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
pbar.update(1)
if use_tqdm:
pbar.close()
# Sort the outputs by request ID. This is necessary because some
# requests may finish earlier than requests that arrived before them.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs
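
Putting the entry point together, usage mirrors the existing LLM class, with one image source dict per prompt. A minimal sketch, assuming a local LLaVA checkpoint at ./llava-7b (hypothetical path) with the clip-vit-large-patch14 tower stored alongside it, as CLIP_MODEL_MAP in __init__ expects:

from vllm import MLLM, SamplingParams

# Hypothetical checkpoint directory; the CLIP vision tower is expected to
# live next to it (see CLIP_MODEL_MAP in MLLM.__init__).
mllm = MLLM(model="./llava-7b")

prompts = ["Describe this picture.", "What is the capital of France?"]
images = [
    {"image_src": "https://example.com/cat.jpg", "src_type": "url"},
    {},  # empty dict as the placeholder for a text-only prompt
]

sampling_params = SamplingParams(temperature=0.2, max_tokens=128)
outputs = mllm.generate(prompts, images, sampling_params)
for output in outputs:
    print(output.request_id, output.outputs[0].text)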
2 changes: 2 additions & 0 deletions vllm/model_executor/model_loader.py
@@ -7,6 +7,7 @@

from vllm.config import ModelConfig
from vllm.model_executor.models import * # pylint: disable=wildcard-import
from vllm.model_executor.models.llava import LlavaLlamaForCausalLM
from vllm.model_executor.weight_utils import initialize_dummy_weights

# TODO(woosuk): Lazy-load the model classes.
@@ -22,6 +23,7 @@
"InternLMForCausalLM": InternLMForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
"LlavaLlamaForCausalLM": LlavaLlamaForCausalLM,
"MPTForCausalLM": MPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel,
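
For context, the model loader resolves the model class from the architectures field of the checkpoint's Hugging Face config, so the registry entry above is only used if the LLaVA checkpoint's config.json advertises LlavaLlamaForCausalLM. An illustrative check with AutoConfig; the actual lookup happens inside the loader and is not shown in this diff, and the path below is hypothetical:

from transformers import AutoConfig

# A LLaVA checkpoint should list the architecture name registered above.
config = AutoConfig.from_pretrained("./llava-7b", trust_remote_code=True)
print("LlavaLlamaForCausalLM" in getattr(config, "architectures", []))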
3 changes: 2 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -209,8 +209,9 @@ def forward(
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
hidden_states = inputs_embeds if inputs_embeds is not None else self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
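
The new inputs_embeds argument lets a multimodal wrapper skip the embed_tokens lookup and feed precomputed embeddings, for example text embeddings whose <im_patch> positions have been replaced by projected vision features, straight into the decoder layers. A minimal sketch of the call pattern, assuming llama_model is the LlamaModel instance and multimodal_embeds is such a precomputed tensor (both hypothetical names; the remaining arguments come from the usual worker path):

# multimodal_embeds: [num_tokens, hidden_size], produced by the vision tower
# and projector and spliced over the <im_patch> token positions (hypothetical).
hidden_states = llama_model(
    input_ids=input_ids,          # still passed, but no longer used for the lookup
    positions=positions,
    kv_caches=kv_caches,
    input_metadata=input_metadata,
    cache_events=cache_events,
    inputs_embeds=multimodal_embeds,
)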