Skip to content

Commit

Permalink
fix paligemma inference
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed May 20, 2024
1 parent 7262679 commit 542229a
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 20 deletions.
25 changes: 19 additions & 6 deletions src/llamafactory/chat/hf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from transformers import GenerationConfig, TextIteratorStreamer

from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_TOKEN
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
Expand Down Expand Up @@ -55,14 +56,28 @@ def _process_args(
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = "<image>" + messages[0]["content"]
if (
processor is not None
and image is not None
and not hasattr(processor, "image_seq_length")
and IMAGE_TOKEN not in messages[0]["content"]
): # llava case
messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]

paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
pixel_values = None
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
)
if processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
if hasattr(processor, "image_seq_length"): # paligemma case
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)

Expand Down Expand Up @@ -122,10 +137,8 @@ def _process_args(
logits_processor=get_logits_processor(),
)

if processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
gen_kwargs["pixel_values"] = pixel_values.to(model.device)
if pixel_values is not None:
gen_kwargs["pixel_values"] = pixel_values

return gen_kwargs, prompt_length

Expand Down
29 changes: 18 additions & 11 deletions src/llamafactory/chat/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union

from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_TOKEN
from ..extras.logging import get_logger
from ..extras.misc import get_device_count, infer_optim_dtype
from ..extras.packages import is_vllm_available
Expand All @@ -17,7 +18,6 @@


if TYPE_CHECKING:
import torch
from numpy.typing import NDArray
from transformers.image_processing_utils import BaseImageProcessor

Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(
patch_size = config.vision_config.patch_size
self.image_feature_size = (image_size // patch_size) ** 2
engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
engine_args["image_feature_size"] = self.image_feature_size
if getattr(config, "is_yi_vl_derived_model", None):
Expand All @@ -92,14 +92,28 @@ async def _generate(
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]

if (
self.processor is not None
and image is not None
and not hasattr(self.processor, "image_seq_length")
and IMAGE_TOKEN not in messages[0]["content"]
): # llava case
messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"]

paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
)

if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None

prompt_length = len(prompt_ids)

use_beam_search: bool = self.generating_args["num_beams"] > 1
Expand Down Expand Up @@ -144,13 +158,6 @@ async def _generate(
skip_special_tokens=True,
)

if self.processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None

result_generator = self.model.generate(
prompt=None,
sampling_params=sampling_params,
Expand Down
11 changes: 8 additions & 3 deletions src/llamafactory/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
elif isinstance(slot, set):
if "bos_token" in slot:
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
Expand Down Expand Up @@ -325,9 +325,11 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
jinja_template += "{% endif %}"

jinja_template += "{% if message['role'] == 'user' %}"
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
jinja_template += "{{ " + user_message + " }}"

jinja_template += "{% elif message['role'] == 'assistant' %}"
assistant_message = _convert_slots_to_jinja(
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
Expand Down Expand Up @@ -614,6 +616,9 @@ def get_template_and_fix_tokenizer(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
efficient_eos=True,
force_system=True,
)


Expand Down
2 changes: 2 additions & 0 deletions src/llamafactory/extras/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

IGNORE_INDEX = -100

IMAGE_TOKEN = "<image>"

LAYERNORM_NAMES = {"norm", "ln"}

METHODS = ["full", "freeze", "lora"]
Expand Down

0 comments on commit 542229a

Please sign in to comment.